# RUN: %PYTHON %s | FileCheck %s

import gc

from mlir.ir import *
from mlir.dialects._ods_common import equally_sized_accessor


def run(f):
    """Print a TEST header, call ``f``, then assert no Context instances survive.

    ``gc.collect()`` runs first so that only Contexts still referenced after
    the test body returns are counted by ``Context._get_live_count()``.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


def add_dummy_value():
    """Create an unregistered ``custom.value`` op with one i32 result.

    Returns the op's single result Value. When called under an active
    ``InsertionPoint`` the op is placed there (it shows up in the printed
    module in the tests below).
    """
    return Operation.create(
        "custom.value", results=[IntegerType.get_signless(32)]
    ).result


def testOdsBuildDefaultImplicitRegions():
    """Check how ``build_generic`` defaults and validates ``regions=``.

    Covers an op with exactly two regions, ``_ODS_REGIONS = (2, True)``, and an
    op with at least two regions, ``_ODS_REGIONS = (2, False)``.
    """

    class TestFixedRegionsOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (2, True)

    class TestVariadicRegionsOp(OpView):
        OPERATION_NAME = "custom.test_any_regions_op"
        _ODS_REGIONS = (2, False)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Including a regions= that matches should be fine.
            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Reject greater than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=3
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
                print(f"ERROR:{e}")
            # Reject less than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")

            # If no regions specified for a variadic region op, build the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
            # CHECK: DEFAULT_NUM_REGIONS: 2
            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
            # Should also accept an explicit regions= that matches the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: EQ_NUM_REGIONS: 2
            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
            # And accept greater than minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
            # CHECK: GT_NUM_REGIONS: 3
            print(f"GT_NUM_REGIONS: {len(op.regions)}")
            # Should reject less than minimum.
            try:
                op = TestVariadicRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)


def testOdsBuildDefaultNonVariadic():
    """Build an op without segment declarations from plain operand/result lists.

    No segment-size attributes should be attached to the built op.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
            # CHECK-NOT: operandSegmentSizes
            # CHECK-NOT: resultSegmentSizes
            # CHECK-SAME: : (i32, i32) -> (i8, i16)
            print(m)


run(testOdsBuildDefaultNonVariadic)


def testOdsBuildDefaultSizedVariadic():
    """Build an op whose operands/results are declared as segments.

    In the segment lists below, the sizes emitted for each entry are
    1 -> exactly one, -1 -> a list, and 0 -> zero or one (None allowed).
    This mapping is inferred from the expected segment-size attributes and
    error messages checked in this test.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
        _ODS_RESULT_SEGMENTS = [-1, 0, 1]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            v2 = add_dummy_value()
            v3 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            t2 = IntegerType.get_signless(32)
            t3 = IntegerType.get_signless(64)
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: %[[V2:.+]] = "custom.value"
            # CHECK: %[[V3:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 2, 1>
            # CHECK-SAME: resultSegmentSizes = array<i32: 2, 1, 1>
            # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
            op = TestOp.build_generic(
                results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3]
            )

            # Now test with optional omitted.
            # CHECK: "custom.test_op"(%[[V0]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 0, 0>
            # CHECK-SAME: resultSegmentSizes = array<i32: 0, 0, 1>
            # CHECK-SAME: (i32) -> i64
            op = TestOp.build_generic(
                results=[None, None, t3], operands=[v0, None, None]
            )
            print(m)

            # And verify that errors are raised for None in a required operand.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[None, None, None]
                )
            except ValueError as e:
                # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
                print(f"OPERAND_CAST_ERROR:{e}")

            # And verify that errors are raised for None in a required result.
            try:
                op = TestOp.build_generic(
                    results=[None, None, None], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
                print(f"RESULT_CAST_ERROR:{e}")

            # Variadic lists with None elements should reject.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[v0, [None], None]
                )
            except ValueError as e:
                # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
                print(f"OPERAND_LIST_CAST_ERROR:{e}")
            try:
                op = TestOp.build_generic(
                    results=[[None], None, t3], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
                print(f"RESULT_LIST_CAST_ERROR:{e}")


run(testOdsBuildDefaultSizedVariadic)


def testOdsBuildDefaultCastError():
    """Passing None for a non-optional operand or result raises ValueError."""

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            try:
                op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
            except ValueError as e:
                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
                print(f"ERROR: {e}")
            try:
                op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
            except ValueError as e:
                # CHECK: ERROR: Result 1 of operation "custom.test_op" must be a Type
                print(f"ERROR: {e}")


run(testOdsBuildDefaultCastError)


def testOdsEquallySizedAccessor():
    """Use equally_sized_accessor on an op built with four plain results."""

    class TestOpEquallySizedResults(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v = add_dummy_value()
            # ts is [i0, i8, i16, i24]; ts[0] is the zero-width integer type.
            ts = [IntegerType.get_signless(i * 8) for i in range(4)]

            op = TestOpEquallySizedResults.build_generic(
                results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
            )
            # NOTE(review): the positional arguments presumably are
            # (elements, n_simple, n_variadic, n_preceding_simple,
            # n_preceding_variadic); confirm against
            # mlir/dialects/_ods_common.py.
            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 0)
            # CHECK: start: 1, elements_per_group: 1
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: i8
            print(op.results[start].type)

            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 1)
            # CHECK: start: 2, elements_per_group: 1
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: i16
            print(op.results[start].type)


run(testOdsEquallySizedAccessor)


def testOdsEquallySizedAccessorMultipleSegments():
    """Use equally_sized_accessor on an op with one single and two variadic result segments."""

    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)
        _ODS_RESULT_SEGMENTS = [0, -1, -1]

    def types(lst):
        # Map a sequence of results to their types for compact printing.
        return [e.type for e in lst]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v = add_dummy_value()
            # ts is [i0, i8, ..., i48]; ts[0] is the zero-width integer type.
            ts = [IntegerType.get_signless(i * 8) for i in range(7)]

            op = TestOpMultiResultSegments.build_generic(
                results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
                operands=[v],
            )
            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 0)
            # CHECK: start: 1, elements_per_group: 3
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
            print(types(op.results[start : start + elements_per_group]))

            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 1)
            # CHECK: start: 4, elements_per_group: 3
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
            print(types(op.results[start : start + elements_per_group]))


run(testOdsEquallySizedAccessorMultipleSegments)