# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg

from mlir.dialects.linalg.opdsl.lang import *

T1 = TV.T1
T2 = TV.T2


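# pooling_poly: a parameterized pooling op. O[n, oh, ow, c] reduces (default:
# signed max) the cast input values over the kh x kw window; the shape-only
# operand K contributes its window dimensions via index_dims, and the strides
# and dilations attributes scale the input indexing.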
@linalg_structured_op
def pooling_poly(
    I=TensorDef(T1, S.N, S.H, S.W, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
        cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
    )


with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):

        # Pooling indexing maps.
        # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
        # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
        # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
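        # With strides=[2, 4] and dilations=[1, 2] (used by every call below), the
        # input access is I[n, oh * 2 + kh, ow * 4 + kw * 2, c], i.e. d1 * 2 + d3
        # and d2 * 4 + d4 * 2 in the input map above; K is indexed only by the
        # window dimensions (d3, d4).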

        # CHECK-LABEL: @test_f32i32_max_pooling
        # CHECK: linalg.generic
        # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
        # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
        # CHECK-NEXT:   %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
        # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
        # CHECK-NEXT: -> tensor<1x2x4x1xi32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((1, 4, 16, 1), f32),
            RankedTensorType.get((2, 2), f32),
            RankedTensorType.get((1, 2, 4, 1), i32),
        )
        def test_f32i32_max_pooling(input, shape, init_result):
            return pooling_poly(
                input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
            )

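        # Overriding the reduce/cast attributes selects the unsigned lowering:
        # TypeFn.cast_unsigned emits arith.fptoui and BinaryFn.max_unsigned emits
        # arith.maxui.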
        # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
        # CHECK:   = arith.fptoui
        # CHECK:   = arith.maxui
        @func.FuncOp.from_py_func(
            RankedTensorType.get((1, 4, 16, 1), f32),
            RankedTensorType.get((2, 2), f32),
            RankedTensorType.get((1, 2, 4, 1), i32),
        )
        def test_f32i32_max_unsigned_pooling(input, shape, init_result):
            return pooling_poly(
                input,
                shape,
                outs=[init_result],
                reduce=BinaryFn.max_unsigned,
                cast=TypeFn.cast_unsigned,
                strides=[2, 4],
                dilations=[1, 2],
            )

        # CHECK-LABEL: @test_f32f32_max_pooling
        # CHECK: linalg.generic
        # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT:   %[[MAX:.+]] = arith.maximumf %[[OUT]], %[[IN:.+]] : f32
        # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
        # CHECK-NEXT: -> tensor<1x2x4x1xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((1, 4, 16, 1), f32),
            RankedTensorType.get((2, 2), f32),
            RankedTensorType.get((1, 2, 4, 1), f32),
        )
        def test_f32f32_max_pooling(input, shape, init_result):
            return pooling_poly(
                input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
            )

        # CHECK-LABEL: @test_f32i32_min_pooling
        # CHECK:   = arith.fptosi
        # CHECK:   = arith.minsi
        @func.FuncOp.from_py_func(
            RankedTensorType.get((1, 4, 16, 1), f32),
            RankedTensorType.get((2, 2), f32),
            RankedTensorType.get((1, 2, 4, 1), i32),
        )
        def test_f32i32_min_pooling(input, shape, init_result):
            return pooling_poly(
                input,
                shape,
                outs=[init_result],
                reduce=BinaryFn.min_signed,
                strides=[2, 4],
                dilations=[1, 2],
            )

        # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
        # CHECK:   = arith.fptoui
        # CHECK:   = arith.minui
        @func.FuncOp.from_py_func(
            RankedTensorType.get((1, 4, 16, 1), f32),
            RankedTensorType.get((2, 2), f32),
            RankedTensorType.get((1, 2, 4, 1), i32),
        )
        def test_f32i32_min_unsigned_pooling(input, shape, init_result):
            return pooling_poly(
                input,
                shape,
                outs=[init_result],
                reduce=BinaryFn.min_unsigned,
                cast=TypeFn.cast_unsigned,
                strides=[2, 4],
                dilations=[1, 2],
            )

        # CHECK-LABEL: @test_f32f32_min_pooling
        # CHECK:   = arith.minimumf
        @func.FuncOp.from_py_func(
            RankedTensorType.get((1, 4, 16, 1), f32),
            RankedTensorType.get((2, 2), f32),
            RankedTensorType.get((1, 2, 4, 1), f32),
        )
        def test_f32f32_min_pooling(input, shape, init_result):
            return pooling_poly(
                input,
                shape,
                outs=[init_result],
                reduce=BinaryFn.min_signed,
                strides=[2, 4],
                dilations=[1, 2],
            )


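# Print the generated IR so the RUN line above can pipe it through FileCheck.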
print(module)