xref: /llvm-project/mlir/test/python/dialects/ml_program.py (revision 4eee9ef9768b1335800878b8f0b7aa3e549e41dc)
1# RUN: %PYTHON %s | FileCheck %s
2# This is just a smoke test that the dialect is functional.
3
4from mlir.ir import *
5from mlir.dialects import ml_program, arith, builtin
6
7
def constructAndPrintInModule(f):
    """Run test builder `f` inside a fresh module and print the result.

    Prints a "TEST:" banner with the name of `f`, then, under a new
    Context and an unknown Location, creates an empty module, invokes
    `f()` with the insertion point set to the module body, and prints
    the module. Returns `f` unchanged so this can be applied as a
    decorator.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            top_module = Module.create()
            # Ops built by `f` are inserted into the body of `top_module`.
            with InsertionPoint(top_module.body):
                f()
            print(top_module)
    return f
16
17
18# CHECK-LABEL: testFuncOp
@constructAndPrintInModule
def testFuncOp():
    """Build an ml_program function `foobar` that returns its first argument.

    The function type takes one signed 32-bit integer and yields one signed
    32-bit integer; the entry block returns block argument 0.
    """
    si32 = IntegerType.get_signed(32)
    # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
    func_op = ml_program.FuncOp(name="foobar", type=([si32], [si32]))
    entry = func_op.add_entry_block()
    with InsertionPoint(entry):
        # CHECK: ml_program.return
        ml_program.ReturnOp([entry.arguments[0]])
29
30
31# CHECK-LABEL: testGlobalStoreOp
@constructAndPrintInModule
def testGlobalStoreOp():
    """Store an f32 constant into a global that lives in a nested module.

    Creates an arith constant, a nested builtin module named "symbol1" with
    "public" visibility containing a mutable f32 global "symbol2", and then a
    global_store that addresses the global via the ["symbol1", "symbol2"]
    symbol path.
    """
    f32 = F32Type.get()
    # CHECK: %cst = arith.constant 4.242000e+01 : f32
    const_op = arith.ConstantOp(value=42.42, result=f32)

    nested = builtin.ModuleOp()
    nested.sym_name = StringAttr.get("symbol1")
    nested.sym_visibility = StringAttr.get("public")
    # CHECK: module @symbol1 attributes {sym_visibility = "public"} {
    # CHECK:   ml_program.global public mutable @symbol2 : f32
    # CHECK: }
    with InsertionPoint(nested.body):
        ml_program.GlobalOp("symbol2", f32, is_mutable=True)
    # CHECK: ml_program.global_store @symbol1::@symbol2 = %cst : f32
    ml_program.GlobalStoreOp(["symbol1", "symbol2"], const_op)
47