# RUN: %PYTHON %s | FileCheck %s
"""Python binding tests for the `python_test` dialect.

Each test function below is passed through the `run` decorator, which prints a
"TEST: <name>" header and then executes the function immediately at import
time. The printed output is matched by FileCheck against the directives
written in the comments of this file, so those comments are load-bearing.
"""

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor


def run(f):
    """Print a "TEST:" header for `f`, call it, and return `f` unchanged.

    Used as a decorator so every test runs as soon as it is defined.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise construction, generic access and typed accessors of op attributes."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build an op from plain Python values for many attribute kinds and print it."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        op = test.AttributesOp(
            x_bool=True,
            x_i16=1,
            x_i32=2,
            x_i64=3,
            x_si16=-1,
            x_si32=-2,
            x_f32=1.5,
            x_f64=2.5,
            x_str='x_str',
            x_i32_array=[1, 2, 3],
            x_i64_array=[4, 5, 6],
            x_f32_array=[1.5, -2.5, 3.5],
            x_f64_array=[4.5, 5.5, -6.5],
            x_i64_dense=[1, 2, 3, 4, 5, 6],
        )
        print(op)


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Check InferTypeOpInterface on both an op instance and an op class."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # The interface built from the op *class* must infer the same types;
        # print through `iface_static` so this path is actually exercised.
        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, ("not expected to be able to obtain an opview from a static"
                           " interface")

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check result types derived from operand types, attributes and implied types."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)

            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get()))
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            first_attr = test.FirstAttrDeriveAttrOp(
                FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that an omitted optional operand reads back as None."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):

            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {op1.input is None}")

            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {op2.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check construction of and casting to the custom TestAttr attribute class."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            assert False, "expected ValueError when casting UnitAttr to TestAttr"

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            assert False, "expected TypeError when casting an int to TestAttr"

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            assert False, "expected TypeError on a bad TestAttr constructor call"


# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check construction of and casting to the custom TestType type class."""
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        test.TestType(a)

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            assert False, "expected ValueError when casting i8 to TestType"

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            assert False, "expected TypeError when casting an int to TestType"

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            assert False, "expected TypeError on a bad TestType constructor call"


# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    """Check that a user subclass of TestTensorValue wraps a tensor-typed value."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        i8 = IntegerType.get_signless(8)

        class Tensor(test.TestTensorValue):
            # Rename the printed prefix so the subclass's __str__ is observable.
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Check InferShapedTypeOpInterface on ranked and unranked tensor operands."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            resultType = UnrankedTensorType.get(i32)
            operandTypes = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operandTypes, [resultType])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        iface = InferShapedTypeOpInterface(ranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[ranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        iface = InferShapedTypeOpInterface(unranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[unranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)