# xref: /llvm-project/mlir/test/python/dialects/vector.py (revision 2ee5586ac7d8424b51790b143dbc6e2105bf99bc)
# RUN: %PYTHON %s | FileCheck %s
24cd1b66dSMatthias Springer
34cd1b66dSMatthias Springerfrom mlir.ir import *
44cd1b66dSMatthias Springerimport mlir.dialects.builtin as builtin
523aa5a74SRiver Riddleimport mlir.dialects.func as func
64cd1b66dSMatthias Springerimport mlir.dialects.vector as vector
74cd1b66dSMatthias Springer
8f9008e63STobias Hieta
def run(f):
    """Run the test body ``f`` immediately and hand it back unchanged.

    Intended for use as a decorator: applying ``@run`` executes the test at
    module import time. A blank line followed by ``TEST: <function name>`` is
    printed first so each test's output can be located by its label. The body
    is then invoked with a fresh ``Context`` and ``Location.unknown()`` both
    entered, so the IR builders it calls can pick them up implicitly.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        # The location is created only after the context is active.
        with Location.unknown():
            f()
    return f
144cd1b66dSMatthias Springer
15f9008e63STobias Hieta
# CHECK-LABEL: TEST: testPrintOp
@run
def testPrintOp():
    """Build ``print_vector(vector<12x5xf32>)`` around ``vector.print``.

    The function is generated with ``func.FuncOp.from_py_func``. Its Python
    body emits one ``vector.PrintOp`` on the incoming argument. The module is
    then printed for FileCheck to verify.
    """
    module = Module.create()
    # Types do not depend on the insertion point; build them up front.
    element_type = F32Type.get()
    argument_type = VectorType.get((12, 5), element_type)

    with InsertionPoint(module.body):

        # The Python function name becomes the symbol name of the func op.
        @func.FuncOp.from_py_func(argument_type)
        def print_vector(vec):
            # The print op is returned as-is; the expected output below shows
            # the generated function ending in a bare return.
            return vector.PrintOp(source=vec)

    # CHECK-LABEL: func @print_vector(
    # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
    #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
    #       CHECK:   return
    #       CHECK: }
    print(module)
326981e5ecSAlex Zinenko
336981e5ecSAlex Zinenko
# CHECK-LABEL: TEST: testTransferReadOp
@run
def testTransferReadOp():
    """Emit a masked and an unmasked ``vector.transfer_read`` and print them.

    The generated ``transfer_read`` function takes:
      * a 2-D f32 memref whose dimensions are both dynamic,
      * an index (used for both read coordinates),
      * an f32 padding value,
      * an i1 mask vector shaped like the 2x3 result vector.
    Its body reads a ``vector<2x3xf32>`` twice, using the identity permutation
    map and a ``[False, False]`` per-dimension flag list (presumably the
    in-bounds flags). The first read passes the mask and the second does not.
    """
    module = Module.create()

    f32 = F32Type.get()
    dynamic = ShapedType.get_dynamic_size()
    result_type = VectorType.get([2, 3], f32)
    source_type = MemRefType.get([dynamic, dynamic], f32)
    mask_type = VectorType.get(result_type.shape, IntegerType.get_signless(1))
    permutation = AffineMapAttr.get(AffineMap.get_identity(result_type.rank))
    arg_types = [source_type, IndexType.get(), f32, mask_type]

    with InsertionPoint(module.body):
        fn = func.FuncOp("transfer_read", (arg_types, []))
        with InsertionPoint(fn.add_entry_block()):
            source, idx, pad, mask = fn.arguments

            def emit_read(**extra):
                # Fresh index and flag lists for every op that is built.
                return vector.TransferReadOp(
                    result_type,
                    source,
                    [idx, idx],
                    permutation,
                    pad,
                    [False, False],
                    **extra,
                )

            emit_read(mask=mask)  # masked read
            emit_read()  # unmasked read; the mask argument stays unused
            func.ReturnOp([])

    # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
    # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
    # CHECK-NOT: %[[MASK]]
    print(module)
7392233062Smax
7492233062Smax
# CHECK-LABEL: TEST: testBitEnumCombiningKind
@run
def testBitEnumCombiningKind():
    """Build a ``vector.reduction`` using the ``CombiningKind`` bit-enum.

    Generates ``reduction(vector<16xf32>) -> f32``, whose body is a single
    ``vector.reduction`` with ``vector.CombiningKind.ADD``. The reduction op
    is returned from the Python body, so its result becomes the function's
    return value (hence the ``-> f32`` in the signature).
    """
    module = Module.create()
    with InsertionPoint(module.body):
        f32 = F32Type.get()
        vector_type = VectorType.get([16], f32)

        @func.FuncOp.from_py_func(vector_type)
        def reduction(arg):
            return vector.ReductionOp(f32, vector.CombiningKind.ADD, arg)

    # Capture the reduction result in a FileCheck variable instead of
    # hard-coding its SSA name, and verify that it is the returned value.
    # CHECK: func.func @reduction(%[[VEC:.*]]: vector<16xf32>) -> f32 {
    # CHECK: %[[RED:.*]] = vector.reduction <add>, %[[VEC]] : vector<16xf32> into f32
    # CHECK: return %[[RED]] : f32
    print(module)
91