# RUN: %PYTHON %s | FileCheck %s

import gc

from mlir.ir import *
from mlir.dialects._ods_common import equally_sized_accessor


def run(f):
    """Run a single test function and verify it leaked no MLIR contexts.

    Prints a header naming the test, calls `f`, forces a garbage collection,
    and then asserts that `Context._get_live_count()` is back to zero
    (presumably the number of Context objects still alive).
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


def add_dummy_value():
    """Create an unregistered `custom.value` op and return its single result.

    The result has signless i32 type. It only serves as a source of operand
    Values for the builders exercised below.
    """
    return Operation.create(
        "custom.value", results=[IntegerType.get_signless(32)]
    ).result


def testOdsBuildDefaultImplicitRegions():
    """Exercise region-count handling in `OpView.build_generic`.

    `_ODS_REGIONS` is a `(count, flag)` pair. As the expected errors below
    show, a True flag means exactly `count` regions are accepted, while a
    False flag means `count` is a minimum that is also the default.
    """

    class TestFixedRegionsOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (2, True)

    class TestVariadicRegionsOp(OpView):
        OPERATION_NAME = "custom.test_any_regions_op"
        _ODS_REGIONS = (2, False)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Including a regions= that matches should be fine.
            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
            # Must be the very next line: a substring match could otherwise
            # land on the DEFAULT_NUM_REGIONS line further down.
            # CHECK-NEXT: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Reject greater than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=3
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
                print(f"ERROR:{e}")
            # Reject less than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")

            # If no regions specified for a variadic region op, build the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
            # CHECK: DEFAULT_NUM_REGIONS: 2
            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
            # Should also accept an explicit regions= that matches the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: EQ_NUM_REGIONS: 2
            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
            # And accept an explicit regions= greater than the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
            # CHECK: GT_NUM_REGIONS: 3
            print(f"GT_NUM_REGIONS: {len(op.regions)}")
            # Should reject less than minimum.
            try:
                op = TestVariadicRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)


def testOdsBuildDefaultNonVariadic():
    """Build an op without segment metadata.

    Operands and results are passed positionally, and no segment-size
    attributes may appear on the printed op.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
            # CHECK-NOT: operandSegmentSizes
            # CHECK-NOT: resultSegmentSizes
            # CHECK-SAME: : (i32, i32) -> (i8, i16)
            print(m)


run(testOdsBuildDefaultNonVariadic)


def testOdsBuildDefaultSizedVariadic():
    """Build an op whose operands and results are split into segments.

    Checks the emitted segment-size attributes, and the ValueError messages
    raised for None in required slots and for None inside variadic lists.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"
        # Segment specs, as exercised below: 1 = exactly one (required),
        # 0 = optional (None allowed), -1 = variadic (a list; None is empty).
        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
        _ODS_RESULT_SEGMENTS = [-1, 0, 1]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            v2 = add_dummy_value()
            v3 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            t2 = IntegerType.get_signless(32)
            t3 = IntegerType.get_signless(64)
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: %[[V2:.+]] = "custom.value"
            # CHECK: %[[V3:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 2, 1>
            # CHECK-SAME: resultSegmentSizes = array<i32: 2, 1, 1>
            # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
            op = TestOp.build_generic(
                results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3]
            )

            # Now test with optional omitted.
            # CHECK: "custom.test_op"(%[[V0]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 0, 0>
            # CHECK-SAME: resultSegmentSizes = array<i32: 0, 0, 1>
            # CHECK-SAME: (i32) -> i64
            op = TestOp.build_generic(
                results=[None, None, t3], operands=[v0, None, None]
            )
            print(m)

            # And verify that errors are raised for None in a required operand.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[None, None, None]
                )
            except ValueError as e:
                # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
                print(f"OPERAND_CAST_ERROR:{e}")

            # And verify that errors are raised for None in a required result.
            try:
                op = TestOp.build_generic(
                    results=[None, None, None], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
                print(f"RESULT_CAST_ERROR:{e}")

            # Variadic lists with None elements should reject.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[v0, [None], None]
                )
            except ValueError as e:
                # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
                print(f"OPERAND_LIST_CAST_ERROR:{e}")
            try:
                op = TestOp.build_generic(
                    results=[[None], None, t3], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
                print(f"RESULT_LIST_CAST_ERROR:{e}")


run(testOdsBuildDefaultSizedVariadic)


def testOdsBuildDefaultCastError():
    """Check that None is rejected in a non-segmented op's operand/result list."""

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            try:
                op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
            except ValueError as e:
                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
                print(f"ERROR: {e}")
            try:
                op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
            except ValueError as e:
                # CHECK: ERROR: Result 1 of operation "custom.test_op" must be a Type
                print(f"ERROR: {e}")


run(testOdsBuildDefaultCastError)


def testOdsEquallySizedAccessor():
    """Locate single-element variadic result groups via `equally_sized_accessor`.

    The op has 4 results. The calls treat one as a simple (non-variadic) group
    and three as equally sized variadic groups, so each group holds 1 result.
    """

    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v = add_dummy_value()
            # ts[k] is i(8*k); ts[0] is the zero-width i0.
            ts = [IntegerType.get_signless(i * 8) for i in range(4)]

            op = TestOpMultiResultSegments.build_generic(
                results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
            )
            # NOTE(review): positional args after `op.results` appear to be
            # (n_simple, n_variadic, n_preceding_simple, n_preceding_variadic);
            # confirm against mlir.dialects._ods_common.
            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 0)
            # CHECK: start: 1, elements_per_group: 1
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: i8
            print(op.results[start].type)

            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 1)
            # CHECK: start: 2, elements_per_group: 1
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: i16
            print(op.results[start].type)


run(testOdsEquallySizedAccessor)


def testOdsEquallySizedAccessorMultipleSegments():
    """Locate multi-element variadic result groups via `equally_sized_accessor`.

    The op has 7 results: one simple result followed by two variadic groups of
    three, so each group spans 3 consecutive results.
    """

    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)
        _ODS_RESULT_SEGMENTS = [0, -1, -1]

    def types(lst):
        """Return the type of each value in `lst`, in order."""
        return [e.type for e in lst]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v = add_dummy_value()
            # ts[k] is i(8*k); ts[0] is the zero-width i0.
            ts = [IntegerType.get_signless(i * 8) for i in range(7)]

            op = TestOpMultiResultSegments.build_generic(
                results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
                operands=[v],
            )
            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 0)
            # CHECK: start: 1, elements_per_group: 3
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
            print(types(op.results[start : start + elements_per_group]))

            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 1)
            # CHECK: start: 4, elements_per_group: 3
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
            print(types(op.results[start : start + elements_per_group]))


run(testOdsEquallySizedAccessorMultipleSegments)