// RUN: mlir-opt %s --transform-interpreter --split-input-file -verify-diagnostics | FileCheck %s

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<()[s0] -> (s0 + 42)>

// CHECK-LABEL: @one_d_static
// CHECK-SAME: %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_LOW]]
  // CHECK: outs(%[[OUT_SLICE_LOW]]
  // CHECK: linalg.index 0
  // CHECK: func.call @elem
  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
  //
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_HIGH]]
  // CHECK: outs(%[[OUT_SLICE_HIGH]]
  // CHECK: %[[IDX:.+]] = linalg.index 0
  // CHECK: affine.apply #[[$ADD_42_MAP]]()[%[[IDX]]]
  // CHECK: func.call @elem
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<100xf32>

  // CHECK: return %[[RES]]
  return %0 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK-LABEL: @one_d_static_overflow
// CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
  // Folding is sufficiently powerful to detect the static overflow and avoid
  // the splitting altogether.
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK: ins(%[[IN]]
  // CHECK: outs(%[[OUT]]
  // CHECK: linalg.index 0
  // CHECK: func.call @elem
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> index

// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<()[s0] -> (s0, 100)>
// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>

// CHECK-LABEL: @dynamic
func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[SPLIT:.+]] = call @get_size
  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]
  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_LOW]]
  // CHECK: outs(%[[OUT_SLICE_LOW]]
  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
  //
  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_HIGH]]
  // CHECK: outs(%[[OUT_SLICE_HIGH]]
  // CHECK: %[[SPLIT_HIGH_4:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_4]]] [1]
  %0 = func.call @get_size() : () -> index
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    %5 = arith.addf %3, %4 : f32
    linalg.yield %5 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 :
        (!transform.any_op) -> !transform.any_op
    %t = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
    %1:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK-LABEL: @two_d
func.func @two_d(%arg0: tensor<10x34xf32>,
                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
  // Check the overall structure: split along the dimension 0, and then split
  // the second half only along the dimension 1.
  // CHECK: %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
  // CHECK: %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
  // CHECK: %[[RES_1:.+]] = linalg.generic
  // CHECK-SAME: ins(%[[IN_1]] : tensor<4x34xf32>)
  // CHECK-SAME: outs(%[[OUT_1]] : tensor<4x34xf32>)
  // CHECK: %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
  //
  // CHECK: %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
  // CHECK: %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
  // Note that the nested `extract_slice` ops below take their slices directly
  // from the first-level slices (`IN_2`, `OUT_2`); the slice-of-slice chains
  // are not folded back to the original tensors.
  // CHECK: %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK: %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
  // CHECK: %[[RES_21:.+]] = linalg.generic
  // CHECK-SAME: ins(%[[IN_21]] : tensor<6x16xf32>)
  // CHECK-SAME: outs(%[[OUT_21]] : tensor<6x16xf32>)
  // CHECK: %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
  //
  // CHECK: %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK: %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
  // CHECK: %[[RES_22:.+]] = linalg.generic
  // CHECK-SAME: ins(%[[IN_22]] : tensor<6x18xf32>)
  // CHECK-SAME: outs(%[[OUT_22]] : tensor<6x18xf32>)
  // CHECK: %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
  // CHECK: %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%arg0: tensor<10x34xf32>)
  outs(%arg1: tensor<10x34xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10x34xf32>
  return %0 : tensor<10x34xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
    // expected-error @below {{expects either a dynamic or a static split point to be provided}}
    %0 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{dynamic split point}}
  %0 = func.call @get_size() : () -> i64
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.return"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{only applies to structured ops}}
    transform.structured.split %0 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  return %arg0 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{dimension 1 does not exist in target op}}
    transform.structured.split %0 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    linalg.yield %0 : f32
  } -> tensor<100xf32>
  return %0 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{splitting does not produce the second part for a subset of targets}}
    // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
    %1 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

func.func @split_one_but_not_other(
    %arg0: tensor<100xf32>, %arg1: tensor<100xf32>,
    %arg2: tensor<200xf32>, %arg3: tensor<200xf32>)
    -> (tensor<100xf32>, tensor<200xf32>) {
  // expected-note @below {{first target with no second part}}
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%arg4: f32, %arg5: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<100xf32>

  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg2: tensor<200xf32>) outs(%arg3: tensor<200xf32>) {
  ^bb0(%arg4: f32, %arg5: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<200xf32>

  return %0, %1 : tensor<100xf32>, tensor<200xf32>
}