# RUN: %PYTHON %s | FileCheck %s
"""Python-binding tests for the MLIR `tensor` dialect.

Every `@run`-decorated function below executes at import time, builds a small
module (or individual ops) and prints it. The FileCheck directives in the
comments next to each test describe the expected printed IR.
"""

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
import mlir.dialects.tensor as tensor
from mlir.extras import types as T


def run(f):
    """Run test function `f` immediately, preceded by a ``TEST:`` label line.

    Intended as a decorator: the label printed here is what each test's
    ``-LABEL`` directive anchors on. `f` is returned unchanged so its name
    stays bound at module level.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
    with Context(), Location.unknown():
        module = Module.create()
        f32Type = F32Type.get()
        indexType = IndexType.get()
        with InsertionPoint(module.body):
            # Query both dimensions of a fully dynamic 2-D tensor using
            # constant (statically known) dimension indices.
            # CHECK: func @tensor_static_dim
            # CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
            # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
            # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
            # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
            # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
            # CHECK: return %[[D0]], %[[D1]]
            @func.FuncOp.from_py_func(
                RankedTensorType.get(
                    (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
                    f32Type,
                )
            )
            def tensor_static_dim(t):
                c0 = arith.ConstantOp(indexType, 0)
                c1 = arith.ConstantOp(indexType, 1)
                d0 = tensor.DimOp(t, c0)
                d1 = tensor.DimOp(t, c1)
                return [d0.result, d1.result]

        print(module)


# CHECK-LABEL: TEST: testEmptyOp
@run
def testEmptyOp():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # Sizes may be Python ints (static extents), index Values (dynamic
            # extents, passed as operands), a mix of both, or empty (0-d).
            # CHECK-LABEL: func @static_sizes
            # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
            @func.FuncOp.from_py_func()
            def static_sizes():
                return tensor.EmptyOp([3, 4], f32)

            # CHECK-LABEL: func @dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
            @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
            def dynamic_sizes(d0, d1):
                return tensor.EmptyOp([d0, d1], f32)

            # CHECK-LABEL: func @mixed_static_dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
            @func.FuncOp.from_py_func(IndexType.get())
            def mixed_static_dynamic_sizes(d0):
                return tensor.EmptyOp([d0, 4], f32)

            # CHECK-LABEL: func @zero_d
            # CHECK: %0 = tensor.empty() : tensor<f32>
            @func.FuncOp.from_py_func()
            def zero_d():
                return tensor.EmptyOp([], f32)

        print(module)


# CHECK-LABEL: TEST: testInferTypesInsertSlice
@run
def testInferTypesInsertSlice():
    with Context(), Location.unknown():
        module = Module.create()
        f32Type = F32Type.get()
        with InsertionPoint(module.body):
            # No result type is passed to InsertSliceOp. The three empty lists
            # are the (absent) dynamic operands; the three DenseI64ArrayAttrs
            # supply the static values printed as the bracketed triplet
            # [0, 0] [1, 1] [0, 0] in the expected output.
            # CHECK: func @f
            # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
            # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32>
            @func.FuncOp.from_py_func(
                RankedTensorType.get((1, 1), f32Type),
                RankedTensorType.get((1, 1), f32Type),
            )
            def f(source, dest):
                d0 = tensor.InsertSliceOp(
                    source,
                    dest,
                    [],
                    [],
                    [],
                    DenseI64ArrayAttr.get([0, 0]),
                    DenseI64ArrayAttr.get([1, 1]),
                    DenseI64ArrayAttr.get([0, 0]),
                )
                return [d0.result]

        print(module)


# CHECK-LABEL: TEST: testFromElementsOp
@run
def testFromElementsOp():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            # Each op is printed on its own (generic, quoted-name form) while
            # the function body is still being built; the same two elements
            # are packed into three differently shaped result tensors.
            @func.FuncOp.from_py_func()
            def default_builder():
                c0 = arith.ConstantOp(f32, 0.0)
                # CHECK: %[[C0:.*]] = "arith.constant
                # CHECK-SAME: value = 0.000000e+00 : f32
                print(c0)
                c1 = arith.ConstantOp(f32, 1.0)
                # CHECK: %[[C1:.*]] = "arith.constant
                # CHECK-SAME: value = 1.000000e+00 : f32
                print(c1)

                t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
                print(t)

                t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
                print(t)

                t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
                print(t)


# CHECK-LABEL: TEST: testGenerateRegionOp
@run
def testGenerateRegionOp():
    S = ShapedType.get_dynamic_size()
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: %[[ONE:.*]] = arith.constant 1 : index
            # CHECK: %[[TWO:.*]] = arith.constant 2 : index
            one = arith.constant(T.index(), 1)
            two = arith.constant(T.index(), 2)

            # The two dynamic dimensions (`S`) of the result take their extents
            # from `one` and `two`; the body receives one index per result
            # dimension and yields i + j + k.
            @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
            def generate_one(i: T.index(), j: T.index(), k: T.index()):
                ij = arith.addi(i, j)
                ijk = arith.addi(ij, k)
                return ijk

            # The decorator replaces the Python function with the result Value
            # of the tensor.generate op it built.
            assert (
                isinstance(generate_one, Value)
                and generate_one.owner.name == "tensor.generate"
            )

        # Captures for the region's block arguments use names distinct from the
        # constants above, so no FileCheck variable is silently redefined.
        # CHECK: %[[GENERATED:.*]] = tensor.generate
        # CHECK-SAME: %[[ONE]],
        # CHECK-SAME: %[[TWO]] {
        # CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index):
        # CHECK: %[[IJ:.*]] = arith.addi %[[I]], %[[J]] : index
        # CHECK: %[[IJK:.*]] = arith.addi %[[IJ]], %[[K]] : index
        # CHECK: tensor.yield %[[IJK]] : index
        # CHECK: } : tensor<?x3x?xindex>
        print(module)