# xref: /llvm-project/mlir/test/python/dialects/vector.py (revision 2ee5586ac7d8424b51790b143dbc6e2105bf99bc)
# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.builtin as builtin
5import mlir.dialects.func as func
6import mlir.dialects.vector as vector
7
8
def run(f):
    """Invoke a test function immediately and return it unchanged.

    Meant to be applied as a decorator. It first prints a ``TEST: <name>``
    banner (the label that each test's FileCheck expectations are anchored
    on), then calls ``f()`` with a fresh ``Context`` and
    ``Location.unknown()`` both entered as context managers.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    # Nested form of `with Context(), Location.unknown():` -- the location is
    # created and entered only after the context is active.
    with Context():
        with Location.unknown():
            f()
    return f
14
15
# CHECK-LABEL: TEST: testPrintOp
@run
def testPrintOp():
    """Emit a function that prints its 12x5 f32 vector argument via vector.print."""
    vec_type = VectorType.get((12, 5), F32Type.get())
    mod = Module.create()
    with InsertionPoint(mod.body):

        @func.FuncOp.from_py_func(vec_type)
        def print_vector(value):
            # The Python function name becomes the symbol (@print_vector)
            # matched below; the expected signature has no result types.
            return vector.PrintOp(source=value)

    # CHECK-LABEL: func @print_vector(
    # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
    #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
    #       CHECK:   return
    #       CHECK: }
    print(mod)
32
33
# CHECK-LABEL: TEST: testTransferReadOp
@run
def testTransferReadOp():
    """Emit two vector.transfer_read ops over a dynamically shaped 2-D memref.

    The first read carries a mask operand and the second does not; the
    expectations below verify the mask shows up only on the first read.
    """
    f32 = F32Type.get()
    dynamic = ShapedType.get_dynamic_size()
    vec_type = VectorType.get([2, 3], f32)
    buffer_type = MemRefType.get([dynamic, dynamic], f32)
    mask_type = VectorType.get(vec_type.shape, IntegerType.get_signless(1))
    permutation = AffineMapAttr.get(AffineMap.get_identity(vec_type.rank))
    arg_types = [buffer_type, IndexType.get(), f32, mask_type]

    mod = Module.create()
    with InsertionPoint(mod.body):
        fn = func.FuncOp("transfer_read", (arg_types, []))
        with InsertionPoint(fn.add_entry_block()):
            source, index, pad, mask = fn.arguments

            def emit_read(**kwargs):
                # Both reads share every operand; only the optional mask
                # keyword differs between the two call sites.
                return vector.TransferReadOp(
                    vec_type,
                    source,
                    [index, index],
                    permutation,
                    pad,
                    [False, False],
                    **kwargs,
                )

            emit_read(mask=mask)
            emit_read()
            func.ReturnOp([])

    # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
    # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
    # CHECK-NOT: %[[MASK]]
    print(mod)
73
74
# CHECK-LABEL: TEST: testBitEnumCombiningKind
@run
def testBitEnumCombiningKind():
    """Check that vector.CombiningKind.ADD is accepted by vector.ReductionOp.

    The enum value is expected to print as ``<add>`` in the op's assembly,
    and the reduction result is what the generated function returns.
    """
    module = Module.create()
    with InsertionPoint(module.body):
        f32 = F32Type.get()
        vector_type = VectorType.get([16], f32)

        @func.FuncOp.from_py_func(vector_type)
        def reduction(arg):
            # Returning the op hands its single f32 result back to
            # from_py_func, which makes it the function's return value.
            return vector.ReductionOp(f32, vector.CombiningKind.ADD, arg)

    # The result's SSA name is captured into a pattern variable rather than
    # hard-coded as "%0", so the test does not depend on printer numbering.
    # CHECK: func.func @reduction(%[[VEC:.*]]: vector<16xf32>) -> f32 {
    # CHECK: %[[RES:.*]] = vector.reduction <add>, %[[VEC]] : vector<16xf32> into f32
    # CHECK: return %[[RES]] : f32
    print(module)
91