xref: /llvm-project/mlir/test/python/dialects/tensor.py (revision 537b2aa264c5a9879a80289c8d123b39e520eb15)
1d115a48eSMaheshRavishankar# RUN: %PYTHON %s | FileCheck %s
2d115a48eSMaheshRavishankar
3d115a48eSMaheshRavishankarfrom mlir.ir import *
4d115a48eSMaheshRavishankarimport mlir.dialects.arith as arith
536550692SRiver Riddleimport mlir.dialects.func as func
6d115a48eSMaheshRavishankarimport mlir.dialects.tensor as tensor
7*537b2aa2SMaksim Leventalfrom mlir.extras import types as T
8d115a48eSMaheshRavishankar
9d115a48eSMaheshRavishankar
def run(f):
    """Execute ``f`` immediately as a test case and return it unchanged.

    Meant to be applied as a decorator: it prints a blank line followed by
    ``TEST: <function name>`` (the banner the label lines in this file match
    on), then calls ``f`` with no arguments.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
14d115a48eSMaheshRavishankar
15d115a48eSMaheshRavishankar
# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
    """Build a function that returns both dims of a ?x?xf32 tensor via tensor.dim."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        index = IndexType.get()
        dyn = ShapedType.get_dynamic_size()
        arg_type = RankedTensorType.get((dyn, dyn), f32)
        with InsertionPoint(module.body):

            #      CHECK: func @tensor_static_dim
            # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
            #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
            #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
            #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
            #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
            #      CHECK:   return %[[D0]], %[[D1]]
            @func.FuncOp.from_py_func(arg_type)
            def tensor_static_dim(tensor_arg):
                # Create both index constants before any dim op, so the
                # constants precede the dim ops in the printed IR.
                indices = [arith.ConstantOp(index, i) for i in (0, 1)]
                return [tensor.DimOp(tensor_arg, idx).result for idx in indices]

        print(module)
4681ca5aa4SMatthias Springer
4781ca5aa4SMatthias Springer
# CHECK-LABEL: TEST: testEmptyOp
@run
def testEmptyOp():
    """Exercise tensor.EmptyOp with static, dynamic, mixed, and 0-d shapes."""
    with Context(), Location.unknown():
        module = Module.create()
        element_type = F32Type.get()
        index_type = IndexType.get()
        with InsertionPoint(module.body):
            # Every size is a Python int: tensor.empty takes no operands.
            # CHECK-LABEL: func @static_sizes
            # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
            @func.FuncOp.from_py_func()
            def static_sizes():
                return tensor.EmptyOp([3, 4], element_type)

            # Every size is an index SSA value, so each dim prints as `?`.
            # CHECK-LABEL: func @dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
            @func.FuncOp.from_py_func(index_type, index_type)
            def dynamic_sizes(rows, cols):
                return tensor.EmptyOp([rows, cols], element_type)

            # Mixing a value and an int: only the value becomes an operand.
            # CHECK-LABEL: func @mixed_static_dynamic_sizes
            # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
            @func.FuncOp.from_py_func(index_type)
            def mixed_static_dynamic_sizes(rows):
                return tensor.EmptyOp([rows, 4], element_type)

            # An empty size list yields a rank-0 tensor.
            # CHECK-LABEL: func @zero_d
            # CHECK: %0 = tensor.empty() : tensor<f32>
            @func.FuncOp.from_py_func()
            def zero_d():
                return tensor.EmptyOp([], element_type)

    print(module)
80ee308c99SJacques Pienaar
81ee308c99SJacques Pienaar
# CHECK-LABEL: TEST: testInferTypesInsertSlice
@run
def testInferTypesInsertSlice():
    """Build a tensor.insert_slice without passing a result type and print it."""
    with Context(), Location.unknown():
        module = Module.create()
        tile_type = RankedTensorType.get((1, 1), F32Type.get())
        with InsertionPoint(module.body):

            # CHECK: func @f
            # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
            # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
            @func.FuncOp.from_py_func(tile_type, tile_type)
            def f(source, dest):
                zeros = DenseI64ArrayAttr.get([0, 0])
                ones = DenseI64ArrayAttr.get([1, 1])
                # The three empty lists are the dynamic operands; the printed
                # bracket groups come from the three static attributes.
                inserted = tensor.InsertSliceOp(
                    source, dest, [], [], [], zeros, ones, zeros
                )
                return [inserted.result]

    print(module)
1110a02f76dSmax
1120a02f76dSmax
# CHECK-LABEL: TEST: testFromElementsOp
@run
def testFromElementsOp():
    """Print tensor.from_elements ops of several shapes built from two f32 constants."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func()
            def default_builder():
                # CHECK: %[[C0:.*]] = "arith.constant
                # CHECK-SAME: value = 0.000000e+00 : f32
                zero = arith.ConstantOp(f32, 0.0)
                print(zero)
                # CHECK: %[[C1:.*]] = "arith.constant
                # CHECK-SAME: value = 1.000000e+00 : f32
                one = arith.ConstantOp(f32, 1.0)
                print(one)

                # The same two scalars are packed into 2, 2x1 and 1x2 tensors,
                # printed in that order.
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
                for shape in ((2,), (2, 1), (1, 2)):
                    packed = tensor.FromElementsOp(
                        RankedTensorType.get(shape, f32), [zero, one]
                    )
                    print(packed)
143*537b2aa2SMaksim Levental
144*537b2aa2SMaksim Levental
# CHECK-LABEL: TEST: testGenerateRegionOp
@run
def testGenerateRegionOp():
    """Build a tensor.generate op with the decorator form and check its region.

    The decorated Python function becomes the op's body: its three index
    parameters are the block arguments and its return value is yielded.
    """
    S = ShapedType.get_dynamic_size()
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
            # CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
            one = arith.constant(T.index(), 1)
            two = arith.constant(T.index(), 2)

            # `one` and `two` supply the extents of the two dynamic (S) dims.
            @tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
            def generate_one(i: T.index(), j: T.index(), k: T.index()):
                ij = arith.addi(i, j)
                ijk = arith.addi(ij, k)
                return ijk

            # The decorator replaces the function with the op's result Value.
            assert (
                isinstance(generate_one, Value)
                and generate_one.owner.name == "tensor.generate"
            )

        # Block arguments and intermediate sums get their own capture names
        # (ARG_*, IJ, IJK), so they no longer overwrite the VAL_1 capture of
        # the extent constant defined above.
        # CHECK:         %[[GENERATED:.*]] = tensor.generate
        # CHECK-SAME:    %[[VAL_0]],
        # CHECK-SAME:    %[[VAL_1]] {
        # CHECK:         ^bb0(%[[ARG_I:.*]]: index, %[[ARG_J:.*]]: index, %[[ARG_K:.*]]: index):
        # CHECK:           %[[IJ:.*]] = arith.addi %[[ARG_I]], %[[ARG_J]] : index
        # CHECK:           %[[IJK:.*]] = arith.addi %[[IJ]], %[[ARG_K]] : index
        # CHECK:           tensor.yield %[[IJK]] : index
        # CHECK:         } : tensor<?x3x?xindex>
        print(module)
177