# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith


def run(f):
    """Print a FileCheck label for `f`, call it, and return it unchanged.

    Used as a decorator: the "TEST: <name>" line (preceded by a blank line)
    is what each test's `CHECK-LABEL` directive anchors on.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Check op construction with attributes and attribute access/mutation."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        # Setting an optional/unit attribute to None/False removes it.
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build an op from plain Python values for every attribute builder kind."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        # CHECK: python_test.attributes_op
        op = test.AttributesOp(
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            x_affinemap=AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            x_affinemaparr=[AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
            x_bool=True,  # CHECK-DAG: x_bool = true
            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            x_df32arr=[23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            x_df64arr=[25, 26],
            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            x_di64arr=[1, 2],
            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            # CHECK-DAG: x_dictarr = [{a = false}]
            x_dictarr=[{"a": BoolAttr.get(False)}],
            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            x_f32arr=[2.0, 3.0],
            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            # CHECK-DAG: x_f64elems = dense<[3.952530e-323, 7.905050e-323]> : tensor<2xf64>
            x_f64elems=[8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            x_flatsymrefarr=["symbol1", "symbol2"],
            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
            x_i1=0,  # CHECK-DAG: x_i1 = false
            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xsi32>
            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xsi64>
            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            x_idxelems=[11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            x_idxlistarr=[[13], [14, 15]],
            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            x_symrefarr=["flatsym", ["deep", "sym"]],
            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
            x_unit=True,  # CHECK-DAG: x_unit
        )
        op.verify()
        op.print(use_local_scope=True)


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Check InferTypeOpInterface on an op instance and on an op class."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # The interface can also be obtained "statically" from the op class;
        # query *that* object so class-level inference is actually exercised.
        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check result types derived from traits rather than explicit arguments."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)

            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that an omitted optional operand reads back as None."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):
            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {op1.input is None}")

            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {op2.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check construction and casting rules of the custom TestAttr class."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        b = test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            # A bare `raise` here would only produce "No active exception to
            # reraise"; fail with a message naming the missing exception.
            raise AssertionError(
                "expected ValueError when casting a non-TestAttr attribute"
            )

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError(
                "expected TypeError when casting a non-MLIR object to TestAttr"
            )

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when calling TestAttr with two arguments"
            )


# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check construction, typeid access and casting rules of TestType."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            # A bare `raise` here would only produce "No active exception to
            # reraise"; fail with a message naming the missing exception.
            raise AssertionError("expected ValueError when casting i8 to TestType")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError(
                "expected TypeError when casting a non-MLIR object to TestType"
            )

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when calling TestType with two arguments"
            )


# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    """Check a user subclass of the custom TestTensorValue value class."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        i8 = IntegerType.get_signless(8)

        class Tensor(test.TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())

            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Check InferShapedTypeOpInterface for ranked and unranked operands."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            resultType = UnrankedTensorType.get(i32)
            operandTypes = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operandTypes, [resultType])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        iface = InferShapedTypeOpInterface(ranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[ranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        iface = InferShapedTypeOpInterface(unranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[unranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    """Check registration, rejection and replacement of Python type casters."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        a = test.TestType.get()
        assert a.typeid is not None

        b = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(b)
        # CHECK: TestType(!python_test.test_type)
        print(repr(b))

        c = test.TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(c)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(c))

        # CHECK: Type caster is already registered
        try:

            def type_caster(pytype):
                return test.TestIntegerRankedTensorType(pytype)

            register_type_caster(c.typeid, type_caster)
        except RuntimeError as e:
            print(e)

        def type_caster(pytype):
            return RankedTensorType(pytype)

        # python_test dialect registers a caster for RankedTensorType in its extension (pybind) module.
        # So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
        register_type_caster(c.typeid, type_caster, replace=True)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(d.type))

        def type_caster(pytype):
            return test.TestIntegerRankedTensorType(pytype)

        register_type_caster(c.typeid, type_caster, replace=True)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(d.type))


# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    """Check result-type inference that depends on which operands are given."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            i64 = IntegerType.get_signless(64)
            zero = arith.ConstantOp(i64, 0)

            one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            # CHECK: i32
            print(one_operand.result.type)

            two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            # CHECK: f32
            print(two_operands.result.type)