# RUN: %PYTHON %s | FileCheck %s
"""Tests for the Python bindings generated for the `python_test` dialect.

Every test function is decorated with `run`, which executes it immediately
and prints a "TEST: <name>" banner. The FileCheck directives in the comments
below match the printed output in order.
"""

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor


def run(f):
    """Print a banner naming `f`, call `f` with no arguments, return `f`.

    Used as a decorator so each test executes when the module is loaded; the
    banner gives FileCheck a label to anchor the test's output.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise op construction with attributes and attribute accessors."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        # Clearing an optional attribute / unit flag removes it from the dict.
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build an op whose attributes are given as plain Python values."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        op = test.AttributesOp(
            x_bool=True,
            x_i16=1,
            x_i32=2,
            x_i64=3,
            x_si16=-1,
            x_si32=-2,
            x_f32=1.5,
            x_f64=2.5,
            x_str="x_str",
            x_i32_array=[1, 2, 3],
            x_i64_array=[4, 5, 6],
            x_f32_array=[1.5, -2.5, 3.5],
            x_f64_array=[4.5, 5.5, -6.5],
            x_i64_dense=[1, 2, 3, 4, 5, 6],
        )
        # Make sure the printed op is actually verified against the output.
        # CHECK: python_test.attributes_op
        print(op)


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Exercise InferTypeOpInterface on an op instance and on an op class."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # The interface obtained from the op class (no instance) must infer
        # the same result types.
        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check result types derived from op traits rather than given explicitly."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)

            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that an omitted optional operand is exposed as None."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):
            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {op1.input is None}")

            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {op2.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check construction and casting of the dialect's custom attribute."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        b = test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            assert False, "expected ValueError when casting a UnitAttr to TestAttr"

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            assert False, "expected TypeError when casting a non-MLIR object"

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            assert False, "expected TypeError on wrong constructor arity"


# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check construction, casting and type ids of the dialect's custom type."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            assert False, "expected ValueError when casting an i8 type to TestType"

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            assert False, "expected TypeError when casting a non-MLIR object"

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            assert False, "expected TypeError on wrong constructor arity"


# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    """Check a user subclass of the dialect's custom tensor value class."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        i8 = IntegerType.get_signless(8)

        class Tensor(test.TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())

            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(test.TestTensorType.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert test.TestTensorType.static_typeid == t.type.typeid


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Exercise InferShapedTypeOpInterface on ranked and unranked operands."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            resultType = UnrankedTensorType.get(i32)
            operandTypes = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operandTypes, [resultType])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        iface = InferShapedTypeOpInterface(ranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[ranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        iface = InferShapedTypeOpInterface(unranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[unranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)