# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg

from mlir.dialects.linalg.opdsl.lang import *

T1 = TV.T1
T2 = TV.T2


# Polymorphic fill from a scalar operand: `value` (element type T1) is
# converted to the output element type U with TypeFn.cast_signed and stored
# into the output O. The expected indexing maps below pair a result-less map
# for the scalar with the identity map for O, i.e. the scalar is broadcast to
# every output element.
@linalg_structured_op
def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
    O[None] = TypeFn.cast_signed(U, value)


# Same as fill_poly, except the fill value is read from a tensor operand I
# (element type T1) rather than a scalar; the test below passes a rank-0
# tensor for it.
@linalg_structured_op
def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
    O[None] = TypeFn.cast_signed(U, I[None])


with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

        # Fill indexing maps.
        # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
        # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
        # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
        # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
        # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

        # Scalar fill of a rank-0 tensor: both operands use the empty map and
        # there are no loops.
        # CHECK-LABEL: @test_fill_0d
        # CHECK: linalg.generic
        # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
        # CHECK-SAME: iterator_types = []
        @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
        def test_fill_0d(value, init_result):
            return fill_poly(value, outs=[init_result])

        # Scalar fill of a 2-d tensor: two parallel loops.
        # CHECK-LABEL: @test_fill_2d
        # CHECK: linalg.generic
        # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
        # CHECK-SAME: iterator_types = ["parallel", "parallel"]
        @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
        def test_fill_2d(value, init_result):
            return fill_poly(value, outs=[init_result])

        # Rank-0 tensor fill of a 3-d tensor: three parallel loops.
        # CHECK-LABEL: @test_fill_rank_zero_3d
        # CHECK: linalg.generic
        # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
        @func.FuncOp.from_py_func(
            RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32)
        )
        # The parameter is named `fill_input` rather than `input` to avoid
        # shadowing the builtin; argument names do not appear in the IR.
        def test_fill_rank_zero_3d(fill_input, init_result):
            return fill_rank_zero_poly(fill_input, outs=[init_result])


print(module)