# RUN: %PYTHON %s | FileCheck %s
"""Tests for the Python bindings of the ``python_test`` dialect.

Every function decorated with ``@run`` executes at import time. Its printed
output is matched by the FileCheck directives written in the ``#`` comments,
so the relative order of prints and directive comments below is significant.
"""

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith


def run(f):
    """Print a ``TEST:`` label for *f*, call it immediately, and return it.

    The printed label is what each test's label directive anchors on.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise ``AttributedOp`` construction and attribute access/mutation."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        # Assigning None/False must remove the attribute, not store a value.
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build an ``AttributesOp`` from plain Python values and print it.

    Only the test label is matched here; the print must simply succeed.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        op = test.AttributesOp(
            x_bool=True,
            x_i16=1,
            x_i32=2,
            x_i64=3,
            x_si16=-1,
            x_si32=-2,
            x_f32=1.5,
            x_f64=2.5,
            x_str="x_str",
            x_i32_array=[1, 2, 3],
            x_i64_array=[4, 5, 6],
            x_f32_array=[1.5, -2.5, 3.5],
            x_f64_array=[4.5, 5.5, -6.5],
            x_i64_dense=[1, 2, 3, 4, 5, 6],
        )
        print(op)


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Check ``InferTypeOpInterface`` on an op instance and on an op class.

    Also checks that ops which do not implement the interface are rejected.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # CHECK: [Type(i32), Type(i64)]
        # Query the interface obtained from the op *class* (not the instance),
        # so the static code path is actually exercised by this check.
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check result types derived from op traits and from attributes."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)

            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that an omitted optional operand reads back as ``None``."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):

            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {op1.input is None}")

            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {op2.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check creation of, and casting to, the custom ``TestAttr`` class."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        b = test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            raise AssertionError("expected ValueError casting UnitAttr to TestAttr")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError casting an int to TestAttr")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError("expected TypeError on TestAttr(42, 56)")


# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check creation of, casting to, and type ids of the custom ``TestType``."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)
        else:
            raise AssertionError("expected AttributeError for TestType.static_typeid")

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            raise AssertionError("expected ValueError casting i8 to TestType")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError casting an int to TestType")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError("expected TypeError on TestType(42, 56)")


# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    """Check a Python subclass of the custom ``TestTensorValue`` class."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        i8 = IntegerType.get_signless(8)

        class Tensor(test.TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())

            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Check ``InferShapedTypeOpInterface`` with ranked and unranked operands."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            resultType = UnrankedTensorType.get(i32)
            operandTypes = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operandTypes, [resultType])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        iface = InferShapedTypeOpInterface(ranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[ranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        iface = InferShapedTypeOpInterface(unranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[unranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    """Check type-caster registration for custom type classes.

    Registering a second caster for an already-registered type id must fail
    unless ``replace=True`` is passed.
    """
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        a = test.TestType.get()
        assert a.typeid is not None

        b = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(b)
        # CHECK: TestType(!python_test.test_type)
        print(repr(b))

        c = test.TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(c)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(c))

        def type_caster(pytype):
            return test.TestIntegerRankedTensorType(pytype)

        # CHECK: Type caster is already registered
        try:
            register_type_caster(c.typeid, type_caster)
        except RuntimeError as e:
            print(e)

        register_type_caster(c.typeid, type_caster, replace=True)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(d.type))


# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    """Check that the inferred result type depends on variadic operand count."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            i64 = IntegerType.get_signless(64)
            zero = arith.ConstantOp(i64, 0)

            one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            # CHECK: i32
            print(one_operand.result.type)

            two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            # CHECK: f32
            print(two_operands.result.type)