# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
import mlir.dialects.tensor as tensor
from mlir.extras import types as T


def run(f):
    """Print a banner naming ``f``, call ``f`` immediately, and return it.

    Used as a decorator, so each test below executes as soon as this script
    is loaded. The printed "TEST: <name>" line is what the per-test label
    directives anchor on.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
    """Build ``tensor.dim`` ops that query both extents of a ``?x?xf32`` tensor."""
    with Context(), Location.unknown():
        module = Module.create()
        f32Type = F32Type.get()
        indexType = IndexType.get()
        with InsertionPoint(module.body):

            # NOTE(review): despite its name, the function's argument is fully
            # dynamic (?x?xf32); the name is kept because the expected IR
            # below refers to it.
            @func.FuncOp.from_py_func(
                RankedTensorType.get(
                    (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
                    f32Type,
                )
            )
            # CHECK: func @tensor_static_dim
            # CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
            # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
            # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
            # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
            # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
            # CHECK: return %[[D0]], %[[D1]]
            def tensor_static_dim(t):
                # Index constants select which dimension tensor.dim queries.
                c0 = arith.ConstantOp(indexType, 0)
                c1 = arith.ConstantOp(indexType, 1)
                d0 = tensor.DimOp(t, c0)
                d1 = tensor.DimOp(t, c1)
                return [d0.result, d1.result]

        print(module)


# CHECK-LABEL: TEST: testEmptyOp
@run
def testEmptyOp():
    """Exercise ``tensor.EmptyOp`` with static, dynamic, mixed and 0-d shapes.

    Each entry in the size list is either a Python int (static extent) or an
    index-typed SSA value (dynamic extent), as the expected IR below shows.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK-LABEL: func @static_sizes
            # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
            @func.FuncOp.from_py_func()
            def static_sizes():
                return tensor.EmptyOp([3, 4], f32)

            # CHECK-LABEL: func @dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
            @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
            def dynamic_sizes(d0, d1):
                return tensor.EmptyOp([d0, d1], f32)

            # CHECK-LABEL: func @mixed_static_dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
            @func.FuncOp.from_py_func(IndexType.get())
            def mixed_static_dynamic_sizes(d0):
                return tensor.EmptyOp([d0, 4], f32)

            # CHECK-LABEL: func @zero_d
            # CHECK: %0 = tensor.empty() : tensor<f32>
            @func.FuncOp.from_py_func()
            def zero_d():
                # An empty size list yields a rank-0 tensor.
                return tensor.EmptyOp([], f32)

        print(module)


# CHECK-LABEL: TEST: testInferTypesInsertSlice
@run
def testInferTypesInsertSlice():
    """Build ``tensor.InsertSliceOp`` without passing an explicit result type."""
    with Context(), Location.unknown():
        module = Module.create()
        f32Type = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((1, 1), f32Type),
                RankedTensorType.get((1, 1), f32Type),
            )
            # CHECK: func @f
            # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
            # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32>
            def f(source, dest):
                # The three empty lists are presumably the dynamic
                # offset/size/stride operands; every position is supplied
                # statically via the three array attributes instead, which
                # print as "[0, 0] [1, 1] [0, 0]" in the expected IR.
                d0 = tensor.InsertSliceOp(
                    source,
                    dest,
                    [],
                    [],
                    [],
                    DenseI64ArrayAttr.get([0, 0]),
                    DenseI64ArrayAttr.get([1, 1]),
                    DenseI64ArrayAttr.get([0, 0]),
                )
                return [d0.result]

        print(module)


# CHECK-LABEL: TEST: testFromElementsOp
@run
def testFromElementsOp():
    """Build ``tensor.FromElementsOp`` results of several shapes from two scalars.

    Individual ops are printed (in generic form, per the expected output)
    instead of the whole module.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func()
            def default_builder():
                c0 = arith.ConstantOp(f32, 0.0)
                # CHECK: %[[C0:.*]] = "arith.constant
                # CHECK-SAME: value = 0.000000e+00 : f32
                print(c0)
                c1 = arith.ConstantOp(f32, 1.0)
                # CHECK: %[[C1:.*]] = "arith.constant
                # CHECK-SAME: value = 1.000000e+00 : f32
                print(c1)

                # The same two elements are packed into 1-d and 2-d result types.
                t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
                print(t)

                t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
                print(t)

                t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
                print(t)


# CHECK-LABEL: TEST: testGenerateRegionOp
@run
def testGenerateRegionOp():
    """Build ``tensor.generate`` with its region body written as a Python function.

    The result type has two dynamic extents, supplied by ``dynamic_extents``.
    The decorated function's parameters become the region's index arguments,
    and its return value is what the region yields.
    """
    S = ShapedType.get_dynamic_size()
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: %[[ONE:.*]] = arith.constant 1 : index
            # CHECK: %[[TWO:.*]] = arith.constant 2 : index
            one = arith.constant(T.index(), 1)
            two = arith.constant(T.index(), 2)

            @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
            def generate_one(i: T.index(), j: T.index(), k: T.index()):
                ij = arith.addi(i, j)
                ijk = arith.addi(ij, k)
                return ijk

            # The decorator replaces the function with the op's result value.
            assert (
                isinstance(generate_one, Value)
                and generate_one.owner.name == "tensor.generate"
            )

        # Region values get their own pattern names so they are not confused
        # with the constants captured above.
        # CHECK: %[[GENERATED:.*]] = tensor.generate
        # CHECK-SAME: %[[ONE]],
        # CHECK-SAME: %[[TWO]] {
        # CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index):
        # CHECK: %[[IJ:.*]] = arith.addi %[[I]], %[[J]] : index
        # CHECK: %[[IJK:.*]] = arith.addi %[[IJ]], %[[K]] : index
        # CHECK: tensor.yield %[[IJK]] : index
        # CHECK: } : tensor<?x3x?xindex>
        print(module)