# xref: /llvm-project/mlir/test/python/dialects/tensor.py (revision 537b2aa264c5a9879a80289c8d123b39e520eb15)
# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.arith as arith
5import mlir.dialects.func as func
6import mlir.dialects.tensor as tensor
7from mlir.extras import types as T
8
9
def run(f):
    """Print a ``TEST:`` label for *f*, call it, and hand *f* back unchanged.

    Used as a decorator, so every test body executes as soon as it is
    defined and its output follows the label that FileCheck anchors on.
    """
    label = f.__name__
    print("\nTEST:", label)
    f()
    return f
14
15
# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
    """Query both extents of a fully dynamic 2-D f32 tensor via ``tensor.dim``."""
    with Context(), Location.unknown():
        module = Module.create()
        elem_ty = F32Type.get()
        idx_ty = IndexType.get()
        dyn = ShapedType.get_dynamic_size()
        arg_ty = RankedTensorType.get((dyn, dyn), elem_ty)
        with InsertionPoint(module.body):
            #      CHECK: func @tensor_static_dim
            # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
            #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
            #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
            #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
            #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
            #      CHECK:   return %[[D0]], %[[D1]]
            @func.FuncOp.from_py_func(arg_ty)
            def tensor_static_dim(src):
                # Create both index constants before any dim op, so the IR
                # keeps the constants ahead of the tensor.dim ops (the
                # CHECK-DAG group above must match before D0).
                zero, one = (arith.ConstantOp(idx_ty, pos) for pos in (0, 1))
                dims = [tensor.DimOp(src, c) for c in (zero, one)]
                return [d.result for d in dims]

        print(module)
46
47
# CHECK-LABEL: TEST: testEmptyOp
@run
def testEmptyOp():
    """Build ``tensor.empty`` with static, dynamic, mixed and rank-0 shapes.

    Integer entries in the size list become static dims; SSA values
    (function arguments here) become dynamic ``?`` dims.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        elem_ty = F32Type.get()
        idx_ty = IndexType.get()
        with InsertionPoint(module.body):
            # CHECK-LABEL: func @static_sizes
            # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
            @func.FuncOp.from_py_func()
            def static_sizes():
                return tensor.EmptyOp([3, 4], elem_ty)

            # CHECK-LABEL: func @dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
            @func.FuncOp.from_py_func(idx_ty, idx_ty)
            def dynamic_sizes(rows, cols):
                return tensor.EmptyOp([rows, cols], elem_ty)

            # CHECK-LABEL: func @mixed_static_dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
            @func.FuncOp.from_py_func(idx_ty)
            def mixed_static_dynamic_sizes(rows):
                return tensor.EmptyOp([rows, 4], elem_ty)

            # CHECK-LABEL: func @zero_d
            # CHECK: %0 = tensor.empty() : tensor<f32>
            @func.FuncOp.from_py_func()
            def zero_d():
                return tensor.EmptyOp([], elem_ty)

    # Printed after both context managers have exited.
    print(module)
80
81
# CHECK-LABEL: TEST: testInferTypesInsertSlice
@run
def testInferTypesInsertSlice():
    """Build ``tensor.insert_slice`` without passing an explicit result type.

    All offsets, sizes and strides are supplied as static
    ``DenseI64ArrayAttr`` values; the dynamic operand lists are empty.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        elem_ty = F32Type.get()
        unit_2d = RankedTensorType.get((1, 1), elem_ty)
        with InsertionPoint(module.body):
            # CHECK: func @f
            # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
            # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
            @func.FuncOp.from_py_func(unit_2d, unit_2d)
            def f(src, dst):
                static_offsets = DenseI64ArrayAttr.get([0, 0])
                static_sizes = DenseI64ArrayAttr.get([1, 1])
                # NOTE(review): zero strides are what the CHECK line expects;
                # unit strides would be the conventional choice — confirm intent.
                static_strides = DenseI64ArrayAttr.get([0, 0])
                # No dynamic offsets/sizes/strides: every entry is static.
                op = tensor.InsertSliceOp(
                    src,
                    dst,
                    [],
                    [],
                    [],
                    static_offsets,
                    static_sizes,
                    static_strides,
                )
                return [op.result]

    print(module)
111
112
# CHECK-LABEL: TEST: testFromElementsOp
@run
def testFromElementsOp():
    """Pack the same two f32 scalars into tensors of three different shapes.

    Each op is printed on its own (generic form) as soon as it is built;
    the module itself is never printed.
    """
    with Context(), Location.unknown():
        module = Module.create()
        elem_ty = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func()
            def default_builder():
                # CHECK: %[[C0:.*]] = "arith.constant
                # CHECK-SAME: value = 0.000000e+00 : f32
                zero = arith.ConstantOp(elem_ty, 0.0)
                print(zero)
                # CHECK: %[[C1:.*]] = "arith.constant
                # CHECK-SAME: value = 1.000000e+00 : f32
                one = arith.ConstantOp(elem_ty, 1.0)
                print(one)

                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
                for shape in ((2,), (2, 1), (1, 2)):
                    result_ty = RankedTensorType.get(shape, elem_ty)
                    print(tensor.FromElementsOp(result_ty, [zero, one]))
143
144
# CHECK-LABEL: TEST: testGenerateRegionOp
@run
def testGenerateRegionOp():
    """Build ``tensor.generate`` through the decorator helper and print it.

    The result type ``tensor<?x3x?xindex>`` has two dynamic dims, whose
    extents come from ``dynamic_extents`` (the constants 1 and 2). The
    region receives one index per dimension and yields their sum.
    """
    S = ShapedType.get_dynamic_size()
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: %[[ONE:.*]] = arith.constant 1 : index
            # CHECK: %[[TWO:.*]] = arith.constant 2 : index
            one = arith.constant(T.index(), 1)
            two = arith.constant(T.index(), 2)

            @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
            def generate_one(i: T.index(), j: T.index(), k: T.index()):
                ij = arith.addi(i, j)
                ijk = arith.addi(ij, k)
                return ijk

            # The decorator replaces the function with the op's result Value.
            assert (
                isinstance(generate_one, Value)
                and generate_one.owner.name == "tensor.generate"
            )

        # Each FileCheck variable is bound exactly once. The block arguments
        # get their own names (I/J/K) instead of re-binding the name used for
        # the constant operand.
        # CHECK:         %[[GENERATED:.*]] = tensor.generate
        # CHECK-SAME:    %[[ONE]],
        # CHECK-SAME:    %[[TWO]] {
        # CHECK:         ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index):
        # CHECK:           %[[IJ:.*]] = arith.addi %[[I]], %[[J]] : index
        # CHECK:           %[[IJK:.*]] = arith.addi %[[IJ]], %[[K]] : index
        # CHECK:           tensor.yield %[[IJK]] : index
        # CHECK:         } : tensor<?x3x?xindex>
        print(module)
177