# RUN: %PYTHON %s | FileCheck %s

from mlir.dialects import arith, func, linalg, tensor, memref
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *


def run(f):
    """Print a "TEST: <name>" label for ``f``, call it, and return it unchanged.

    Used as a decorator so every test below executes when the file is run;
    the printed label lets FileCheck anchor each test's expected output.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testFill
@run
def testFill():
    """Build linalg.fill into a dynamically sized tensor and into a memref."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK-LABEL: func @fill_tensor
            # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
            # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
            # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
            @func.FuncOp.from_py_func(
                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
            )
            def fill_tensor(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.0), result=f32
                ).result
                # On tensors the fill produces a new value, which is returned.
                return linalg.fill(zero, outs=[out])

            # CHECK-LABEL: func @fill_buffer
            # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
            # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
            # CHECK-NEXT: return
            @func.FuncOp.from_py_func(
                MemRefType.get((12, ShapedType.get_dynamic_size()), f32)
            )
            def fill_buffer(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.0), result=f32
                ).result
                # On memrefs the fill writes in place; nothing is returned.
                linalg.fill(zero, outs=[out])

        print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
    """Print elemwise_unary (default attrs) and elemwise_binary (explicit attrs)."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
            )
            def named_form(lhs, rhs):
                init_result = tensor.EmptyOp([4, 8], f32)
                # Check for the named form with custom format
                # CHECK: linalg.elemwise_unary
                # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
                # CHECK-SAME: fun = #linalg.unary_fn<exp>
                # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                # No fun/cast passed here: the printed attributes are the defaults.
                unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
                # CHECK: linalg.elemwise_binary
                # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
                # CHECK-SAME: fun = #linalg.binary_fn<mul>
                # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                # CHECK: return
                binary_result = linalg.elemwise_binary(
                    lhs,
                    rhs,
                    outs=[init_result.result],
                    fun=BinaryFn.mul,
                    cast=TypeFn.cast_unsigned,
                )
                return unary_result, binary_result

        print(module)


# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
    """Build linalg.transpose / linalg.broadcast on tensors and memrefs.

    Each op is built twice: once through the op class followed by an explicit
    ``fill_builtin_region`` call, and once through the snake_case helper.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
            op1 = tensor.EmptyOp([1, 13], f32)
            op2 = tensor.EmptyOp([13, 1], f32)
            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
            op3 = linalg.TransposeOp(
                result=[RankedTensorType.get((13, 1), f32)],
                input=op1,
                init=op2,
                permutation=[1, 0],
            )
            # The op-class path is followed by an explicit region fill.
            linalg.fill_builtin_region(op3.operation)

            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((1, 13), f32),
                MemRefType.get((13, 1), f32),
            )
            def transpose_op(op1, op2):
                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
                # Memref form: no result types, the op writes into `init`.
                op3 = linalg.TransposeOp(
                    result=[],
                    input=op1,
                    init=op2,
                    permutation=[1, 0],
                )
                linalg.fill_builtin_region(op3.operation)
                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
            op1 = tensor.EmptyOp([16], f32)
            op2 = tensor.EmptyOp([16, 64], f32)
            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
            op3 = linalg.BroadcastOp(
                result=[RankedTensorType.get((16, 64), f32)],
                input=op1,
                init=op2,
                dimensions=[1],
            )
            linalg.fill_builtin_region(op3.operation)

            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
            op4 = tensor.EmptyOp([64], f32)
            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])

            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((16,), f32),
                MemRefType.get((16, 64), f32),
                MemRefType.get((64,), f32),
            )
            def broadcast_op(op1, op2, op3):
                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
                op4 = linalg.BroadcastOp(
                    result=[],
                    input=op1,
                    init=op2,
                    dimensions=[1],
                )
                linalg.fill_builtin_region(op4.operation)
                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])

        print(module)


# CHECK-LABEL: TEST: testGenericOp
@run
def testGenericOp():
    """Build linalg.generic via the decorator form on tensors and memrefs.

    Covers a single-result tensor generic, a two-result tensor generic, and
    a memref generic, then verifies and prints the module.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        memref_t = MemRefType.get([10, 10], f32)
        with InsertionPoint(module.body):
            id_map_1 = AffineMap.get_identity(2)
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
            x = tensor.empty((16, 16), f32)
            y = tensor.empty((16, 16), f32)

            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK: linalg.yield %in : f32
            # CHECK: } -> tensor<16x16xf32>
            @linalg.generic(
                [x],
                [y],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                # Body block arguments are the scalar elements, not the tensors.
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                return a

            # With one tensor output the decorator rebinds `f` to the single
            # result value of the generic op.
            assert isinstance(f, Value)
            assert isinstance(f.type, RankedTensorType)

            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
            z = tensor.empty((16, 16, 16), f32)

            # Minor identity of 3 dims onto 2 results: reads the 2-D input
            # inside the 3-D iteration space.
            minor_id = AffineMap.get_minor_identity(3, 2)
            id_map_2 = AffineMap.get_identity(3)

            # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
            # CHECK: linalg.yield %in, %out : f32, f32
            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
            @linalg.generic(
                [x],
                [z, z],
                [minor_id, id_map_2, id_map_2],
                [
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                ],
            )
            def g(a, b, c):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                assert isinstance(c, Value)
                assert isinstance(c.type, F32Type)
                return a, b

            # With two tensor outputs the decorator yields the result list.
            assert isinstance(g, OpResultList)
            assert len(g) == 2
            assert isinstance(g[0].type, RankedTensorType)
            assert isinstance(g[1].type, RankedTensorType)

            # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
            # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
            # memref_t is fully static, so both operand lists (presumably the
            # dynamic sizes and symbol operands) are empty.
            xx = memref.alloc(memref_t, [], [])
            yy = memref.alloc(memref_t, [], [])

            # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK: linalg.yield %in : f32
            # CHECK: }
            @linalg.generic(
                [xx],
                [yy],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                return a

        module.operation.verify()
        print(module)