# RUN: %PYTHON %s | FileCheck %s

from mlir.dialects import arith, func, linalg, tensor, memref
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *


def run(f):
    """Print a test banner, execute the test, and return it.

    Returning ``f`` lets this double as a decorator, so each test both
    registers and runs at import time; the banner line is matched by the
    CHECK-LABEL directives below.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testFill
@run
def testFill():
    """Build linalg.fill on both tensor and memref (buffer) outputs."""
    # NOTE(review): previously bound `as ctx` without using it; dropped for
    # consistency with the other tests in this file.
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK-LABEL: func @fill_tensor
            # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
            # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
            # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
            @func.FuncOp.from_py_func(
                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
            )
            def fill_tensor(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.0), result=f32
                ).result
                # On tensors, fill yields a new value which must be returned.
                return linalg.fill(zero, outs=[out])

            # CHECK-LABEL: func @fill_buffer
            # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
            # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
            # CHECK-NEXT: return
            @func.FuncOp.from_py_func(
                MemRefType.get((12, ShapedType.get_dynamic_size()), f32)
            )
            def fill_buffer(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.0), result=f32
                ).result
                # On buffers, fill mutates in place and produces no result.
                linalg.fill(zero, outs=[out])

        print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
    """Check the custom (named) assembly form of elementwise linalg ops."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
            )
            def named_form(lhs, rhs):
                init_result = tensor.EmptyOp([4, 8], f32)
                # Check for the named form with custom format
                # CHECK: linalg.elemwise_unary
                # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
                # CHECK-SAME: fun = #linalg.unary_fn<exp>
                # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
                # CHECK: linalg.elemwise_binary
                # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
                # CHECK-SAME: fun = #linalg.binary_fn<mul>
                # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                # CHECK: return
                binary_result = linalg.elemwise_binary(
                    lhs,
                    rhs,
                    outs=[init_result.result],
                    fun=BinaryFn.mul,
                    cast=TypeFn.cast_unsigned,
                )
                return unary_result, binary_result

        print(module)


# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
    """Exercise transpose/broadcast via both the raw op classes (with
    fill_builtin_region) and the high-level builder helpers, on tensors
    and on memrefs."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
            op1 = tensor.EmptyOp([1, 13], f32)
            op2 = tensor.EmptyOp([13, 1], f32)
            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
            op3 = linalg.TransposeOp(
                result=[RankedTensorType.get((13, 1), f32)],
                input=op1,
                init=op2,
                permutation=[1, 0],
            )
            # The raw op class leaves the body region empty; populate it.
            linalg.fill_builtin_region(op3.operation)

            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((1, 13), f32),
                MemRefType.get((13, 1), f32),
            )
            def transpose_op(op1, op2):
                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
                op3 = linalg.TransposeOp(
                    result=[],
                    input=op1,
                    init=op2,
                    permutation=[1, 0],
                )
                linalg.fill_builtin_region(op3.operation)
                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
            op1 = tensor.EmptyOp([16], f32)
            op2 = tensor.EmptyOp([16, 64], f32)
            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
            op3 = linalg.BroadcastOp(
                result=[RankedTensorType.get((16, 64), f32)],
                input=op1,
                init=op2,
                dimensions=[1],
            )
            linalg.fill_builtin_region(op3.operation)

            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
            op4 = tensor.EmptyOp([64], f32)
            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])

            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((16,), f32),
                MemRefType.get((16, 64), f32),
                MemRefType.get((64,), f32),
            )
            def broadcast_op(op1, op2, op3):
                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
                op4 = linalg.BroadcastOp(
                    result=[],
                    input=op1,
                    init=op2,
                    dimensions=[1],
                )
                linalg.fill_builtin_region(op4.operation)
                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])

        print(module)


# CHECK-LABEL: TEST: testGenericOp
@run
def testGenericOp():
    """Exercise the linalg.generic decorator builder: single-result and
    multi-result tensor forms, plus the zero-result memref form."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        memref_t = MemRefType.get([10, 10], f32)
        with InsertionPoint(module.body):
            id_map_1 = AffineMap.get_identity(2)
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
            x = tensor.empty((16, 16), f32)
            y = tensor.empty((16, 16), f32)

            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK:   linalg.yield %in : f32
            # CHECK: } -> tensor<16x16xf32>
            @linalg.generic(
                [x],
                [y],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                # Block arguments are the element-typed values of ins/outs.
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                return a

            # The decorator replaces `f` with the op's single result.
            assert isinstance(f, Value)
            assert isinstance(f.type, RankedTensorType)

            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
            z = tensor.empty((16, 16, 16), f32)

            minor_id = AffineMap.get_minor_identity(3, 2)
            id_map_2 = AffineMap.get_identity(3)

            # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
            # CHECK:   linalg.yield %in, %out : f32, f32
            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
            @linalg.generic(
                [x],
                [z, z],
                [minor_id, id_map_2, id_map_2],
                [
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                ],
            )
            def g(a, b, c):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                assert isinstance(c, Value)
                assert isinstance(c.type, F32Type)
                return a, b

            # With two outs, the decorator yields the op's result list.
            assert isinstance(g, OpResultList)
            assert len(g) == 2
            assert isinstance(g[0].type, RankedTensorType)
            assert isinstance(g[1].type, RankedTensorType)

            # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
            # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
            xx = memref.alloc(memref_t, [], [])
            yy = memref.alloc(memref_t, [], [])

            # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK:   linalg.yield %in : f32
            # CHECK: }
            @linalg.generic(
                [xx],
                [yy],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                return a

        module.operation.verify()
        print(module)