# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import memref


def run(f):
    """Execute test function `f` right away and print the IR it built.

    A fresh context and an unknown location are made current, a new module is
    created, and `f` is called with the insertion point set to the module body.
    The test name is printed before the call and the module after it, so the
    FileCheck directives in each test can match against the output. `f` is
    returned unchanged, which lets this be used as a decorator.
    """
    with Context():
        with Location.unknown():
            top_module = Module.create()
            with InsertionPoint(top_module.body):
                print("\nTEST:", f.__name__)
                f()
            print(top_module)
    return f


def _sequence_matching(op_name):
    """Build a propagating `transform.SequenceOp` targeting ops named `op_name`.

    The op is created at whatever insertion point is current when this helper
    is called.
    """
    target_type = transform.OperationType.get(op_name)
    return transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        target_type,
    )


@run
def testMemRefAllocaToAllocOpCompact():
    """Build `MemRefAllocaToGlobalOp` from the target only (no result types)."""
    seq = _sequence_matching("memref.alloca")
    with InsertionPoint(seq.body):
        memref.MemRefAllocaToGlobalOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: (!transform.op<"memref.alloca">)
    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)


@run
def testMemRefAllocaToAllocOpTyped():
    """Build `MemRefAllocaToGlobalOp` with both result types given explicitly."""
    seq = _sequence_matching("memref.alloca")
    get_global_type = transform.OperationType.get("memref.get_global")
    global_type = transform.OperationType.get("memref.global")
    with InsertionPoint(seq.body):
        memref.MemRefAllocaToGlobalOp(
            get_global_type,
            global_type,
            seq.bodyTarget,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)


@run
def testMemRefMultiBufferOpCompact():
    """Build `MemRefMultiBufferOp` from a target and a factor only."""
    seq = _sequence_matching("memref.alloc")
    with InsertionPoint(seq.body):
        memref.MemRefMultiBufferOp(seq.bodyTarget, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpCompact
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op


@run
def testMemRefMultiBufferOpTyped():
    """Build `MemRefMultiBufferOp` with an explicit result type."""
    seq = _sequence_matching("memref.alloc")
    alloc_type = transform.OperationType.get("memref.alloc")
    with InsertionPoint(seq.body):
        memref.MemRefMultiBufferOp(alloc_type, seq.bodyTarget, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpTyped
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.op<"memref.alloc">


@run
def testMemRefMultiBufferOpAttributes():
    """Build `MemRefMultiBufferOp` with the `skip_analysis` flag set."""
    seq = _sequence_matching("memref.alloc")
    with InsertionPoint(seq.body):
        memref.MemRefMultiBufferOp(seq.bodyTarget, 4, skip_analysis=True)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpAttributes
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: skip_analysis
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op