xref: /llvm-project/mlir/test/python/dialects/ml_program.py (revision 4eee9ef9768b1335800878b8f0b7aa3e549e41dc)
18b7e85f4SStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
28b7e85f4SStella Laurenzo# This is just a smoke test that the dialect is functional.
38b7e85f4SStella Laurenzo
48b7e85f4SStella Laurenzofrom mlir.ir import *
5*4eee9ef9Smaxfrom mlir.dialects import ml_program, arith, builtin
68b7e85f4SStella Laurenzo
78b7e85f4SStella Laurenzo
def constructAndPrintInModule(f):
    """Run ``f`` immediately to populate a fresh module, then print that module.

    Intended for use as a decorator. The wrapped test body executes at
    decoration time inside a new MLIR ``Context``, with ``Location.unknown()``
    as the default location and the insertion point placed at the body of a
    freshly created ``Module``. A ``TEST: <name>`` banner is printed first so
    that the per-test label directives in this file have something to anchor
    on. The module is printed while the context is still active. The original
    function object is returned unchanged.
    """
    banner_name = f.__name__
    print("\nTEST:", banner_name)
    with Context(), Location.unknown():
        new_module = Module.create()
        body_ip = InsertionPoint(new_module.body)
        with body_ip:
            f()
        # Printing needs the context, so it stays inside the outer `with`.
        print(new_module)
    return f
168b7e85f4SStella Laurenzo
178b7e85f4SStella Laurenzo
# CHECK-LABEL: testFuncOp
@constructAndPrintInModule
def testFuncOp():
    """Build ``ml_program.func @foobar`` of type ``(si32) -> si32``.

    The function's entry block simply returns its single block argument.
    """
    si32 = IntegerType.get_signed(32)
    # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
    func_op = ml_program.FuncOp(name="foobar", type=([si32], [si32]))
    entry = func_op.add_entry_block()
    with InsertionPoint(entry):
        # CHECK: ml_program.return
        ml_program.ReturnOp([entry.arguments[0]])
29*4eee9ef9Smax
30*4eee9ef9Smax
# CHECK-LABEL: testGlobalStoreOp
@constructAndPrintInModule
def testGlobalStoreOp():
    """Store an f32 constant into a global declared inside a nested module.

    A public ``builtin.module`` named ``symbol1`` is created. It holds a
    ``mutable`` ``ml_program.global`` named ``symbol2``. An outer-level
    ``ml_program.global_store`` then writes the constant through the symbol
    path ``["symbol1", "symbol2"]``. The expected output shows this path
    printed as the nested reference ``@symbol1::@symbol2``.
    """
    f32 = F32Type.get()

    # CHECK: %cst = arith.constant 4.242000e+01 : f32
    stored = arith.ConstantOp(value=42.42, result=f32)

    inner = builtin.ModuleOp()
    inner.sym_name = StringAttr.get("symbol1")
    inner.sym_visibility = StringAttr.get("public")
    # CHECK: module @symbol1 attributes {sym_visibility = "public"} {
    # CHECK:   ml_program.global public mutable @symbol2 : f32
    # CHECK: }
    with InsertionPoint(inner.body):
        # is_mutable=True is what yields the `mutable` keyword in the output.
        ml_program.GlobalOp("symbol2", f32, is_mutable=True)

    # CHECK: ml_program.global_store @symbol1::@symbol2 = %cst : f32
    ml_program.GlobalStoreOp(["symbol1", "symbol2"], stored)
47