# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg

from mlir.dialects.linalg.opdsl.lang import *

# This tests miscellaneous features of the emitter that are not tested by the
# fill, matmul, convolution, or pooling tests. The features include:
# - constants defined in the body
# - fixed/predefined types
# - some math/arith functions, including abs, ceil, exp, floor, log, and negf
# - custom op names.
#
# The module built below is printed at the end of the script; the command at
# the top of this file pipes that output into FileCheck, which matches it
# against the directives in the comments.


@linalg_structured_op
def test_const(O=TensorDef(F32, S.M, S.N, output=True)):
    """Writes the sum of two body constants, each cast to F32, into O."""
    O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
        F32, const(2.3283064e-10)
    )


@linalg_structured_op
def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
    """Sets O[m, n] to m + n, with both iteration indices cast to I32."""
    O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
        I32, index(D.n)
    )


@linalg_structured_op
def elemwise_unary_poly(
    I=TensorDef(T),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Applies the unary function `fun` to I after casting it to U.

    `fun` defaults to exp and `cast` to cast_signed; the tests below override
    `fun` to exercise the other unary functions.
    """
    O[None] = fun(cast(U, I[None]))


@linalg_structured_op(op_name="custom_op_name")
def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
    """Copies I into O; registered under an op name ("custom_op_name") that
    differs from the Python function name."""
    O[D.n] = I[D.n]


with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    c32 = ComplexType.get(f32)
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):
        # Each function below is created inside module.body by
        # func.FuncOp.from_py_func and calls one of the opdsl ops above; the
        # directives preceding it verify the emitted op body.

        # CHECK-LABEL: @test_f32_const
        # CHECK-DAG: %[[CST0:.+]] = arith.constant 42 : i64
        # CHECK-DAG: %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
        # CHECK-DAG: %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
        # CHECK-DAG: %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
        # CHECK-DAG: %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
        # CHECK-NEXT: linalg.yield %[[SUM]] : f32
        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
        def test_f32_const(init_result):
            return test_const(outs=[init_result])

        # CHECK-LABEL: @test_i32_index
        # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
        # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
        # CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
        # CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
        # CHECK-DAG: %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
        # CHECK-NEXT: linalg.yield %[[SUM]] : i32
        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
        def test_i32_index(init_result):
            return test_index(outs=[init_result])

        # CHECK-LABEL: @test_f32_elemwise_exp
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[EXP]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_exp(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)

        # CHECK-LABEL: @test_f32_elemwise_log
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[LOG]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_log(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)

        # CHECK-LABEL: @test_f32_elemwise_abs
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[ABS:.+]] = math.absf %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[ABS]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_abs(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)

        # CHECK-LABEL: @test_f32_elemwise_ceil
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[CEIL:.+]] = math.ceil %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[CEIL]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_ceil(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)

        # CHECK-LABEL: @test_f32_elemwise_floor
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[FLOOR:.+]] = math.floor %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[FLOOR]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_floor(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)

        # CHECK-LABEL: @test_f32_elemwise_neg
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[NEG:.+]] = arith.negf %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[NEG]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_neg(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)

        # The same UnaryFn.negf is expected to emit complex.neg when the element
        # type is complex<f32>.
        # CHECK-LABEL: @test_c32_elemwise_neg
        # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
        # CHECK-NEXT: %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
        # CHECK-NEXT: linalg.yield %[[NEG]] : complex<f32>
        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
        )
        def test_c32_elemwise_neg(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)

        # Just check that we don't assert out on name mismatch.
        # CHECK-LABEL: @test_non_default_op_name
        @func.FuncOp.from_py_func(
            RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32)
        )
        def test_non_default_op_name(input, init_result):
            return non_default_op_name(input, outs=[init_result])


print(module)