xref: /llvm-project/mlir/test/python/dialects/shape.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import numpy as np
5import mlir.dialects.func as func
6import mlir.dialects.shape as shape
7
8
def run(f):
    """Run the test function ``f`` immediately and return it unchanged.

    Intended for use as a decorator. Before calling ``f`` it prints a blank
    line followed by ``TEST: <function name>``, which gives the FileCheck
    label for each test a unique line to anchor on. Returning ``f`` keeps the
    decorated module-level name bound to the function instead of ``None``.
    """
    print("\nTEST:", f.__name__)
    f()
    return f
13
14
# CHECK-LABEL: TEST: testConstShape
@run
def testConstShape():
    """Build a func exercising the shape dialect's constant builders and print it.

    The body creates a witness constant, size constants (from a Python int and
    from an index-typed IntegerAttr), shape constants (from a Python list and
    from a DenseElementsAttr), and a meet of two shapes. The printed module is
    matched by the FileCheck directives that follow the builder code.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            # A dynamic dimension prints as `?`, giving tensor<12x?xf32>.
            @func.FuncOp.from_py_func(
                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
            )
            def const_shape_tensor(arg):
                # `arg` is the single tensor argument declared above; it is
                # intentionally unused by this test.
                shape.ConstWitnessOp(False)
                # Size constants: one from a plain int, one from an attribute.
                shape.ConstSizeOp(30)
                shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
                # Shape constants: one from a list, one from a dense attribute.
                x = shape.ConstShapeOp([1, 2])
                shape.MeetOp(x, x, error="impossible")
                return shape.ConstShapeOp(
                    DenseElementsAttr.get(
                        np.array([3, 4], dtype=np.int64), type=IndexType.get()
                    )
                )

        # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
        # CHECK-DAG: shape.const_witness false
        # CHECK-DAG: shape.const_size 30
        # CHECK-DAG: shape.const_size 40
        # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
        # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
        # CHECK-DAG: shape.meet{{.*}}"impossible"
        print(module)
45