// RUN: mlir-opt --transform-interpreter --scf-for-loop-canonicalization --canonicalize --split-input-file %s | FileCheck %s
// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s --check-prefix=NOCANON

// This implements a 2D multisize tiling with target sizes [3, 10].
//
// Variant 1: the tile sizes and split points are carried as operation handles
// (!transform.any_op), so they only become constants after canonicalization.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Dimension 0: %1#0 is the low tile size, %1#1 the high tile size and
    // %1#2 the split point (same result order as the named results of the
    // nested multitile_sizes below).
    %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
    %split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
    %2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Tile each half along dimension 0 with its own tile size.
    %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
    // Repeat the split-then-tile scheme along dimension 1 for every tiled
    // linalg op obtained above.
    transform.foreach %5 : !transform.any_op {
    ^bb0(%inner_linalg: !transform.any_op):
      %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
      %split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
      %inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      // A tile size of 0 leaves dimension 0 untiled at this level.
      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    }
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// Without canonicalization, tile sizes are computed dynamically as affine maps.
// NOCANON-LABEL: @two_d
// NOCANON-COUNT-8: affine.apply
// NOCANON:         scf.for

// CHECK-LABEL: @two_d
// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
func.func @two_d(%arg0: tensor<10x34xf32>,
                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%arg0: tensor<10x34xf32>)
  outs(%arg1: tensor<10x34xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10x34xf32>

  // 2D multi-size tiling should produce four quadrants with sizes
  //   (2, 8), (2, 9), (3, 8), (3, 9)
  // respectively, and in this order.
  // Check the full code for the first quadrant, the data flow for the second
  // quadrant and only the overall code structure for the remaining quadrants.
  // The canonicalizer is able to recover static shapes of the four
  // linalg.generic instances, use those to differentiate the quadrants.

  // CHECK:      %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
  // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
  // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
  // CHECK:        %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]

  // CHECK:        %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
  // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
  // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
  // CHECK:          scf.yield %[[RESPARTIAL]]

  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
  // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
  // CHECK-COUNT-2:  tensor.extract_slice
  // CHECK:          linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
  // CHECK:          tensor.insert_slice
  // CHECK:          scf.yield
  // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
  // CHECK:        %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
  // CHECK:        scf.yield %[[INSERTED_3]]

  // CHECK:        tensor.insert_slice
  // CHECK:        tensor.extract_slice
  // CHECK:        scf.for
  // CHECK-COUNT-2:  tensor.extract_slice
  // CHECK:          scf.for
  // CHECK-COUNT-2:    tensor.extract_slice
  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
  // CHECK:            tensor.insert_slice
  // CHECK:            scf.yield
  // CHECK:          tensor.insert_slice
  // CHECK:          tensor.extract_slice
  // CHECK:          scf.for
  // CHECK-COUNT-2:    tensor.extract_slice
  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
  // CHECK:            tensor.insert_slice
  // CHECK:            scf.yield
  // CHECK-COUNT-2:  tensor.insert_slice
  // CHECK:          scf.yield
  // CHECK:        %[[RESULT:.+]] = tensor.insert_slice
  // CHECK:        return %[[RESULT]]

  return %0 : tensor<10x34xf32>
}

// -----

// Variant 2: same 2D multisize tiling with target sizes [3, 10], but the tile
// sizes and split points are carried as !transform.param<i64> values instead
// of operation handles.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Sizes for both dimensions are computed up front on the original op.
    %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.param<i64>
    %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.param<i64>
    %split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
    %2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
    %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
    %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
    // NOTE(review): presumably replicates the dimension-1 parameters once per
    // payload op in %5 so that foreach below can iterate over all four
    // operands in lockstep -- confirm against the transform.replicate docs.
    %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
    transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
    ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
      %split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
      %inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      // A tile size of 0 leaves dimension 0 untiled at this level.
      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
    }
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// Even without canonicalization, tile sizes can be computed statically thanks
// to parameters.
// NOCANON-LABEL: @two_d_param
// NOCANON-NOT: affine.apply
// NOCANON:     scf.for

// CHECK-LABEL: @two_d_param
// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
func.func @two_d_param(%arg0: tensor<10x34xf32>,
                       %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%arg0: tensor<10x34xf32>)
  outs(%arg1: tensor<10x34xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10x34xf32>

  // The expected IR after canonicalization has the same structure as for the
  // handle-based variant: four quadrants with sizes
  //   (2, 8), (2, 9), (3, 8), (3, 9)
  // in this order.

  // CHECK:      %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
  // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
  // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
  // CHECK:        %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]

  // CHECK:        %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
  // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
  // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
  // CHECK:          scf.yield %[[RESPARTIAL]]

  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
  // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
  // CHECK-COUNT-2:  tensor.extract_slice
  // CHECK:          linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
  // CHECK:          tensor.insert_slice
  // CHECK:          scf.yield
  // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
  // CHECK:        %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
  // CHECK:        scf.yield %[[INSERTED_3]]

  // CHECK:        tensor.insert_slice
  // CHECK:        tensor.extract_slice
  // CHECK:        scf.for
  // CHECK-COUNT-2:  tensor.extract_slice
  // CHECK:          scf.for
  // CHECK-COUNT-2:    tensor.extract_slice
  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
  // CHECK:            tensor.insert_slice
  // CHECK:            scf.yield
  // CHECK:          tensor.insert_slice
  // CHECK:          tensor.extract_slice
  // CHECK:          scf.for
  // CHECK-COUNT-2:    tensor.extract_slice
  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
  // CHECK:            tensor.insert_slice
  // CHECK:            scf.yield
  // CHECK-COUNT-2:  tensor.insert_slice
  // CHECK:          scf.yield
  // CHECK:        %[[RESULT:.+]] = tensor.insert_slice
  // CHECK:        return %[[RESULT]]

  return %0 : tensor<10x34xf32>
}