# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false

import gc
from mlir.ir import *
from mlir.dialects import func


def run(f):
    """Execute test function `f` immediately and return it (used as a decorator).

    A "TEST: <function name>" header is printed before the call; the label
    directives placed above each test in this file match on that header.
    After `f` returns, a garbage collection pass is forced and the function
    asserts that `Context._get_live_count()` reports zero.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-trip a Value through its capsule and compare with the original."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        value = Operation.create("custom.op1", results=[i32]).result
        value_capsule = value._CAPIPtr
        # The capsule's repr is expected to carry the quoted capsule name.
        assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule)
        value2 = Value._CAPICreate(value_capsule)
        assert value2 == value


# CHECK-LABEL: TEST: testOpResultOwner
@run
def testOpResultOwner():
    """Assert that the `owner` of an operation's result compares equal to the op."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        op = Operation.create("custom.op1", results=[i32])
        assert op.result.owner == op


# CHECK-LABEL: TEST: testBlockArgOwner
@run
def testBlockArgOwner():
    """Assert that the `owner` of a block argument compares equal to its block."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      return
    }""",
        ctx,
    )
    # NOTE(review): this local `func` shadows the imported `func` dialect
    # module for the rest of this function; the module is not used here.
    func = module.body.operations[0]
    block = func.regions[0].blocks[0]
    assert block.arguments[0].owner == block


# CHECK-LABEL: TEST: testValueIsInstance
@run
def testValueIsInstance():
    """Check `BlockArgument.isinstance` and `OpResult.isinstance` on both kinds.

    The block argument must be accepted by `BlockArgument.isinstance` and
    rejected by `OpResult.isinstance`; the op result must behave the other
    way round.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      %0 = "some_dialect.some_op"() : () -> f64
      return
    }""",
        ctx,
    )
    # NOTE(review): this local `func` shadows the imported `func` dialect
    # module for the rest of this function; the module is not used here.
    func = module.body.operations[0]
    assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
    assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])

    # First operation of the entry block, i.e. "some_dialect.some_op".
    op = func.regions[0].blocks[0].operations[0]
    assert not BlockArgument.isinstance(op.results[0])
    assert OpResult.isinstance(op.results[0])


# CHECK-LABEL: TEST: testValueHash
@run
def testValueHash():
    """Assert that a value hashes the same when reached through different paths.

    The block argument is compared with the first operand of the op that
    consumes it, and the op's result with the operand of the `return`.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) -> f32 {
      %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
      return %0 : f32
    }""",
        ctx,
    )

    # Single-element unpacking: the module body must hold exactly one op.
    # NOTE(review): this local `func` shadows the imported `func` dialect
    # module for the rest of this function; the module is not used here.
    [func] = module.body.operations
    block = func.entry_block
    # Two-element unpacking: the entry block must hold exactly two ops.
    op, ret = block.operations
    assert hash(block.arguments[0]) == hash(op.operands[0])
    assert hash(op.result) == hash(ret.operands[0])


# CHECK-LABEL: TEST: testValueUses
@run
def testValueUses():
    """Iterate `value.uses` for a value consumed by two ops and print each use."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            # Both consumers share the name "custom.op2" and use `value` as
            # their only operand.
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in [op1, op2]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")


# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run
def testValueReplaceAllUsesWith():
    """Replace every use of `value` with `value2` and inspect both use lists."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            value2 = Operation.create("custom.op3", results=[i32]).results[0]
            value.replace_all_uses_with(value2)

    # The original value must have no remaining uses.
    assert len(list(value.uses)) == 0

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in [op1, op2]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")


# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
@run
def testValueReplaceAllUsesWithExcept():
    """Replace uses of `value` with `value2`, passing a single op as exception."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            op1 = Operation.create("custom.op1", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            value2 = Operation.create("custom.op3", results=[i32]).results[0]
            # `op1` is passed as the exception (a single operation, not a list).
            value.replace_all_uses_except(value2, op1)

    # Exactly one use of the original value is expected to remain.
    assert len(list(value.uses)) == 1

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in [op2]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in [op1]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")


# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
@run
def testValueReplaceAllUsesWithMultipleExceptions():
    """Replace uses of `value` with `value2`, passing a list of excepted ops."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            op1 = Operation.create("custom.op1", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            op3 = Operation.create("custom.op3", operands=[value])
            value2 = Operation.create("custom.op4", results=[i32]).results[0]

            # Replace all uses of `value` with `value2`, except for `op1` and `op2`.
            value.replace_all_uses_except(value2, [op1, op2])

    # After replacement, only `op3` should use `value2`, while `op1` and `op2` should still use `value`.
    assert len(list(value.uses)) == 2
    assert len(list(value2.uses)) == 1

    # CHECK: Use owner: "custom.op3"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in [op3]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in [op1, op2]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")


# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
    """Exercise `Value.get_name()` in several configurations.

    Names are printed with default arguments, with an explicit `AsmState`,
    with `use_local_scope=True`, for block arguments, and finally for a value
    whose owning operation has been detached from its parent.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(value)

            value2 = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
            print(value2)

            # A function taking two i32 arguments and returning nothing; its
            # entry block is created explicitly with matching argument types.
            topFn = func.FuncOp("test", ([i32, i32], []))
            entry_block = Block.create_at_start(topFn.operation.regions[0], [i32, i32])

            with InsertionPoint(entry_block):
                value3 = Operation.create("custom.op3", results=[i32]).results[0]
                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
                print(value3)
                value4 = Operation.create("custom.op4", results=[i32]).results[0]
                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
                print(value4)
                func.ReturnOp([])

        # Names printed by get_name() must agree with the names captured from
        # the Value(...) printouts above.
        # CHECK: %[[VAL1]]
        print(value.get_name())
        # CHECK: %[[VAL2]]
        print(value2.get_name())
        # CHECK: %[[VAL3]]
        print(value3.get_name())
        # CHECK: %[[VAL4]]
        print(value4.get_name())

        print("With AsmState")
        # CHECK-LABEL: With AsmState
        # The state is built from the function op with local scope enabled.
        state = AsmState(topFn.operation, use_local_scope=True)
        # CHECK: %0
        print(value3.get_name(state=state))
        # CHECK: %1
        print(value4.get_name(state=state))

        print("With use_local_scope")
        # CHECK-LABEL: With use_local_scope
        # CHECK: %0
        print(value3.get_name(use_local_scope=True))
        # CHECK: %1
        print(value4.get_name(use_local_scope=True))

        # CHECK: %[[ARG0:.*]]
        print(entry_block.arguments[0].get_name())
        # CHECK: %[[ARG1:.*]]
        print(entry_block.arguments[1].get_name())

        # CHECK: module {
        # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
        # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
        # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
        # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
        # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
        # CHECK:     return
        # CHECK:   }
        # CHECK: }
        print(module)

        # Detach the op defining `value2` from the module, then ask for the
        # value's name again.
        value2.owner.detach_from_parent()
        # CHECK: %0
        print(value2.get_name())


# CHECK-LABEL: TEST: testValueSetType
@run
def testValueSetType():
    """Change a value's type with `set_type` and print the value and its owner."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        i64 = IntegerType.get_signless(64)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(value)

            value.set_type(i64)
            # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
            print(value)

            # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
            print(value.owner)


# CHECK-LABEL: TEST: testValueCasters
@run
def testValueCasters():
    """Exercise `register_value_caster` for integer-typed values.

    A caster is registered for `IntegerType.static_typeid` that wraps values
    in local no-op subclasses; the test then prints op results, operands and
    function block arguments. Afterwards it attempts a second registration for
    the same type id (handling `RuntimeError`), and finally registers a
    replacement caster with `replace=True`.
    """

    # The three wrappers below only forward construction to the base class and
    # substitute their own class name for `Value.__name__` in the string form.
    class NOPResult(OpResult):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPResult.__name__)

    class NOPValue(Value):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPValue.__name__)

    class NOPBlockArg(BlockArgument):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)

    @register_value_caster(IntegerType.static_typeid)
    def cast_int(v) -> Value:
        # Announce every invocation; the expected-output comments below rely
        # on these "in caster" lines.
        print("in caster", v.__class__.__name__)
        # The more specific classes are tested before the plain Value case.
        if isinstance(v, OpResult):
            return NOPResult(v)
        if isinstance(v, BlockArgument):
            return NOPBlockArg(v)
        elif isinstance(v, Value):
            return NOPValue(v)

    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            values = Operation.create("custom.op1", results=[i32, i32]).results
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", values[0].result_number, values[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", values[1].result_number, values[1])

            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("results slice", values[:1][0].result_number, values[:1][0])

            value0, value1 = values
            # NOTE(review): the prints below index `values` again instead of
            # reusing value0/value1. Presumably deliberate, since the expected
            # output has a caster line before each result line - confirm
            # before changing.
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", value0.result_number, values[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", value1.result_number, values[1])

            op1 = Operation.create("custom.op2", operands=[value0, value1])
            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
            print(op1)

            # CHECK: in caster Value
            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 0", op1.operands[0])
            # CHECK: in caster Value
            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 1", op1.operands[1])

            # CHECK: in caster BlockArgument
            # CHECK: in caster BlockArgument
            @func.FuncOp.from_py_func(i32, i32)
            def reduction(arg0, arg1):
                # CHECK: as func arg 0 NOPBlockArg
                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
                # CHECK: as func arg 1 NOPBlockArg
                print("as func arg", arg1.arg_number, arg1.__class__.__name__)

            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
            print(
                "args slice",
                reduction.func_op.arguments[:1][0].arg_number,
                reduction.func_op.arguments[:1][0],
            )

    # Registering a second caster for the same type id without `replace=True`;
    # the expected output below is the resulting error message.
    try:

        @register_value_caster(IntegerType.static_typeid)
        def dont_cast_int_shouldnt_register(v):
            ...

    except RuntimeError as e:
        # CHECK: Value caster is already registered: {{.*}}cast_int
        print(e)

    # Same type id again, this time with `replace=True`.
    @register_value_caster(IntegerType.static_typeid, replace=True)
    def dont_cast_int(v) -> OpResult:
        # Returns its argument unchanged after printing it.
        assert isinstance(v, OpResult)
        print("don't cast", v.result_number, v)
        return v

    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
            new_value = Operation.create("custom.op1", results=[i32]).result
            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
            print("result", new_value.result_number, new_value)

            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
            new_value = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
            print("result", new_value.result_number, new_value)