// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s

// Tests for transform.structured.fuse. Each case below is an independent
// input (cases are divided by the split-input-file separator) made of a payload
// function plus the transform sequence that is applied to it.

// Tile the elemwise_binary op with sizes [32, 32] and fuse its elemwise_unary
// producer: both ops are expected inside a two-deep scf.for nest whose result
// is returned.
// CHECK-LABEL: func.func @fuse_unary
func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  // CHECK: %[[RES:.*]] = scf.for
  // CHECK:   scf.for
  // CHECK:     linalg.elemwise_unary
  // CHECK:     linalg.elemwise_binary
  // CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Same payload and fusion as the previous case, followed by peeling of the
// outermost generated loop (%loops#0): two loop nests are expected, the second
// one consuming the result of the first.
// CHECK-LABEL: func.func @fuse_unary
func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  // CHECK: %[[PARTIAL_RES:.*]] = scf.for
  // CHECK:   scf.for
  // CHECK:     linalg.elemwise_unary
  // CHECK:     linalg.elemwise_binary
  // CHECK: %[[RES:.*]] = scf.for {{.*}}%[[PARTIAL_RES]]
  // CHECK:   scf.for
  // CHECK:     linalg.elemwise_unary
  // CHECK:     linalg.elemwise_binary
  // CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Fuse linalg.fill into a tiled reduction generic using tile sizes [5, 0, 7]
// with interchange [0, 2, 1], then tile the fused op again with sizes [0, 4].
// Expected: two outer loops stepping by 5 and 7, the fill on the output slice
// inside them, and an innermost loop stepping by 4 around the generic.
// CHECK-LABEL: func.func @interchange_reduction
//  CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {
  %five = arith.constant 5.0 : f32
  %init = tensor.empty() : tensor<12x25xf32>

// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK: %[[RES:.*]] = scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
// CHECK:   scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C7]] iter_args(%[[FOR_ARG1:.+]] = %[[FOR_ARG0]])
// CHECK:     %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], 0, %[[IV1]]]
// CHECK:     %[[OUT_SLICE1:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]]
// CHECK:     %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE1]] : tensor<?x?xf32>)
// CHECK:     scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]])
// CHECK:       %[[IN_SLICE:.+]] = tensor.extract_slice %[[OUT_SLICE0]]
// CHECK:       %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[FOR_ARG2]][0, 0]
// CHECK:       linalg.generic {{.+}} ins(%[[IN_SLICE]] : tensor<?x?x?xf32>) outs(%[[OUT_SLICE2]] : tensor<?x?xf32>)
// CHECK: return %[[RES]]

  %fill = linalg.fill ins(%five : f32) outs(%init : tensor<12x25xf32>) -> tensor<12x25xf32>
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>],
    iterator_types = ["parallel", "reduction", "parallel"]
  } ins(%input : tensor<12x7x25xf32>) outs(%fill : tensor<12x25xf32>) {
  ^bb0(%arg0: f32, %arg1: f32):
    %2 = arith.addf %arg0, %arg1 : f32
    linalg.yield %2 : f32
  } -> tensor<12x25xf32>
  func.return %0 : tensor<12x25xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %2, %loops_2 = transform.structured.tile_using_for %1 tile_sizes [0, 4]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// A tensor.unpack producer is expected to be fused into the loop nest created
// by tiling its elemwise_unary consumer with sizes [16, 32].
// CHECK-LABEL: func.func @unpack_elemwise
// CHECK: %[[RES:.*]] = scf.for
// CHECK:   scf.for
// CHECK:     tensor.unpack
// CHECK:     linalg.elemwise_unary
// CHECK: return %[[RES]]
func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf32>) -> tensor<128x384xf32> {
  %0 = tensor.empty() : tensor<128x384xf32>
  %1 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
      : tensor<16x48x8x8xf32> -> tensor<128x384xf32>
  %2 = linalg.elemwise_unary ins(%1: tensor<128x384xf32>)
                             outs(%arg1: tensor<128x384xf32>) -> tensor<128x384xf32>
  return %2 : tensor<128x384xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// A tensor.pack producer is expected to be fused into the loop nest when only
// the two outer dimensions of the consumer are tiled (sizes [3, 5, 0, 0]).
// CHECK-LABEL: func.func @pack_elemwise
// CHECK: %[[RES:.*]] = scf.for
// CHECK:   scf.for
// CHECK:     tensor.pack
// CHECK:     linalg.elemwise_unary
// CHECK: return %[[RES]]
func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
  %0 = tensor.empty() : tensor<16x48x8x8xf32>
  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
      : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
  return %2 : tensor<16x48x8x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Negative counterpart of the previous case: with tile sizes [3, 5, 2, 0] the
// tensor.pack is expected to stay ahead of the loop nest instead of being
// fused. NOTE(review): presumably because a packed inner-tile dimension is
// tiled here; confirm against the fusion implementation.
// CHECK-LABEL: func.func @nofuse_pack_elemwise
// CHECK: tensor.pack
// CHECK: %[[RES:.*]] = scf.for
// CHECK:   scf.for
// CHECK:     linalg.elemwise_unary
// CHECK: return %[[RES]]
func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
  %0 = tensor.empty() : tensor<16x48x8x8xf32>
  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
      : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
  return %2 : tensor<16x48x8x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// The elemwise_unary producer reaches the tiled consumer through a
// tensor.extract_slice; with apply_cleanup = true it is still expected to end
// up inside the loop nest.
// CHECK-LABEL: func.func @fuse_through_slice
func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  // CHECK: %[[RES:.*]] = scf.for
  // CHECK:   scf.for
  // CHECK:     linalg.elemwise_unary
  // CHECK:     linalg.elemwise_binary
  // CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %2 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.yield
  }
}

// -----

// Like @fuse_through_slice, but the producer result flows through a chain of
// tensor.cast and tensor.extract_slice ops before reaching the tiled consumer.
// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain
func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  // CHECK: %[[RES:.*]] = scf.for
  // CHECK:   scf.for
  // CHECK:     linalg.elemwise_unary
  // CHECK:     linalg.elemwise_binary
  // CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>)
                             outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32>
  %1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32>
  %2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32>
  %3 = tensor.cast %2 : tensor<98x98xf32> to tensor<?x?xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %5 = linalg.elemwise_binary ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %5 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.yield
  }
}

// -----

// Slices of %arg0 that do not feed the fused computation (%slice1, %slice2)
// are expected to be left in place ahead of the loop nest and returned
// unchanged next to the fused result.
// CHECK-LABEL: func.func @fuse_unrelated_slices
func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<10x10xf32>) {

  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
  // CHECK: %[[RES:.*]] = scf.for
  // CHECK:   scf.for
  // CHECK:     linalg.elemwise_unary
  // CHECK:     linalg.elemwise_binary
  // CHECK: return %[[RES]], %[[SLICE2]]
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.yield
  }
}