# RUN: %PYTHON %s | FileCheck %s


from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import memref


def run(f):
    """Execute the test body ``f`` against a fresh module and print the result.

    Used as a decorator, so ``f`` runs at import time. It runs inside a new
    ``Context`` with ``Location.unknown()``, with the insertion point at the
    body of an empty ``Module``. A ``TEST: <name>`` banner is printed before
    the body runs; the FileCheck label directives in each test anchor on it.
    The populated module is printed afterwards for the per-test directives to
    match against.

    Returns ``f`` unchanged so the decorated name stays bound at module level.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f()
        print(module)
    return f


@run
def testMemRefAllocaToAllocOpCompact():
    """Build alloca_to_global from only the target handle (default result types)."""
    # NOTE(review): the name says "AllocaToAlloc" but the op exercised is
    # transform.memref.alloca_to_global. The name is printed in the TEST banner
    # and matched by the label below, so it is deliberately left unchanged.
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("memref.alloca"),
    )
    with InsertionPoint(sequence.body):
        memref.MemRefAllocaToGlobalOp(sequence.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: (!transform.op<"memref.alloca">)
    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)


@run
def testMemRefAllocaToAllocOpTyped():
    """Build alloca_to_global with explicit get_global / global result handle types."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("memref.alloca"),
    )
    with InsertionPoint(sequence.body):
        # Positional form: (get_global result type, global result type, target).
        memref.MemRefAllocaToGlobalOp(
            transform.OperationType.get("memref.get_global"),
            transform.OperationType.get("memref.global"),
            sequence.bodyTarget,
        )
        transform.YieldOp()
    # The operand type is verified here as well (as in the Compact variant),
    # so a typed build cannot silently change the target handle type.
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: (!transform.op<"memref.alloca">)
    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)


@run
def testMemRefMultiBufferOpCompact():
    """Build multibuffer from a target handle and factor (default result type)."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("memref.alloc"),
    )
    with InsertionPoint(sequence.body):
        memref.MemRefMultiBufferOp(sequence.bodyTarget, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpCompact
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
@run
def testMemRefMultiBufferOpTyped():
    """Build multibuffer with an explicit result handle type as the first argument."""
    alloc_handle = transform.OperationType.get("memref.alloc")
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloc_handle
    )
    with InsertionPoint(sequence.body):
        # Positional form: (result type, target, factor).
        memref.MemRefMultiBufferOp(alloc_handle, sequence.bodyTarget, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpTyped
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.op<"memref.alloc">


@run
def testMemRefMultiBufferOpAttributes():
    """Build multibuffer with the optional skip_analysis flag set."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("memref.alloc"),
    )
    target = sequence.bodyTarget
    factor = 4
    with InsertionPoint(sequence.body):
        memref.MemRefMultiBufferOp(target, factor, skip_analysis=True)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpAttributes
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: skip_analysis
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op