# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith

test.register_python_test_dialect(get_dialect_registry())


def run(f):
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    with Context() as ctx, Location.unknown():
        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: python_test.attributed_op {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: python_test.attributed_op {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    with Context() as ctx, Location.unknown():
        # CHECK: python_test.attributes_op
        op = test.AttributesOp(
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            x_affinemap=AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            x_affinemaparr=[AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
            x_bool=True,  # CHECK-DAG: x_bool = true
            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            x_df32arr=[23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            x_df64arr=[25, 26],
            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            x_di64arr=[1, 2],
            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            # CHECK-DAG: x_dictarr = [{a = false}]
            x_dictarr=[{"a": BoolAttr.get(False)}],
            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            x_f32arr=[2.0, 3.0],
            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            # CHECK-DAG: x_f64elems = dense<[3.952530e-323, 7.905050e-323]> : tensor<2xf64>
            x_f64elems=[8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            x_flatsymrefarr=["symbol1", "symbol2"],
            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
            x_i1=0,  # CHECK-DAG: x_i1 = false
            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xi32>
            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xi64>
            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            x_idxelems=[11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            x_idxlistarr=[[13], [14, 15]],
            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            x_symrefarr=["flatsym", ["deep", "sym"]],
            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
            x_unit=True,  # CHECK-DAG: x_unit
        )
        op.verify()
        op.print(use_local_scope=True)


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)

            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {op1.input is None}")

            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {op2.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    with Context() as ctx:
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        b = test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            raise

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise


@run
def testCustomType():
    with Context() as ctx:
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            raise

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            raise


# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    with Context() as ctx, Location.unknown():
        i8 = IntegerType.get_signless(8)

        class Tensor(test.TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())

            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid

            d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
            print(d)
            # CHECK: TestTensorValue
            print(repr(d))


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            resultType = UnrankedTensorType.get(i32)
            operandTypes = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operandTypes, [resultType])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        iface = InferShapedTypeOpInterface(ranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[ranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        iface = InferShapedTypeOpInterface(unranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[unranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    with Context() as ctx, Location.unknown():
        a = test.TestType.get()
        assert a.typeid is not None

        b = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(b)
        # CHECK: TestType(!python_test.test_type)
        print(repr(b))

        c = test.TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(c)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(c))

        # CHECK: Type caster is already registered
        try:

            @register_type_caster(c.typeid)
            def type_caster(pytype):
                return test.TestIntegerRankedTensorType(pytype)

        except RuntimeError as e:
            print(e)

        # The python_test dialect registers a caster for RankedTensorType in its
        # extension (pybind) module. This caster replaces it (successfully); the
        # original caster is restored further below.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return RankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(d.type))

        # Restore the original caster registered by the python_test extension.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return test.TestIntegerRankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(d.type))


# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i64 = IntegerType.get_signless(64)
            zero = arith.ConstantOp(i64, 0)

            one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            # CHECK: i32
            print(one_operand.result.type)

            two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            # CHECK: f32
            print(two_operands.result.type)