from typing import Optional, List


# Base class for an instruction. To implement a basic instruction that doesn't
# impact the control-flow, create a new class inheriting from this.
class Instruction:
    # Contains the name of the output register, if any.
    _result: Optional[str]
    # Contains the instruction opcode.
    _opcode: str
    # Contains all the instruction operands, except result and opcode.
    _operands: List[str]

    # Parses one textual instruction. The line is split on whitespace; when
    # the second token is "=", the first token is the output register and the
    # third is the opcode. Otherwise the first token is the opcode.
    # Raises IndexError if the line has no opcode token (e.g. empty line, or
    # "%reg =" with nothing after it).
    def __init__(self, line: str):
        # The raw text is kept around: OpName needs it to recover names that
        # contain whitespace.
        self.line = line
        tokens = line.split()
        if len(tokens) > 1 and tokens[1] == "=":
            self._result = tokens[0]
            self._opcode = tokens[2]
            # Slicing past the end yields [], so no length guard is needed.
            self._operands = tokens[3:]
        else:
            self._result = None
            self._opcode = tokens[0]
            self._operands = tokens[1:]

    # Debug representation: the operands are printed as a Python list.
    # NOTE(review): the leading padding of the no-result form may have been
    # wider originally (to align with the "{result:3} = " prefix) - confirm.
    def __str__(self):
        if self._result is None:
            return f" {self._opcode} {self._operands}"
        return f"{self._result:3} = {self._opcode} {self._operands}"

    # Returns the instruction opcode.
    def opcode(self) -> str:
        return self._opcode

    # Returns the instruction operands.
    def operands(self) -> List[str]:
        return self._operands

    # Returns the instruction output register. Calling this function is
    # only allowed if has_output_register() is true.
    def output_register(self) -> str:
        assert self._result is not None
        return self._result

    # Returns true if this instruction has an output register. False otherwise.
    def has_output_register(self) -> bool:
        return self._result is not None

    # This function is used to initialize state related to this instruction
    # before module execution begins. For example, global Input variables
    # can use this to store the lane ID into the register.
    def static_execution(self, lane):
        pass

    # This function is called every time this instruction is executed by a
    # tangle. This function should not be directly overridden, instead see
    # _impl and _advance_ip.
    def runtime_execution(self, module, lane):
        self._impl(module, lane)
        self._advance_ip(module, lane)

    # This function needs to be overridden if your instruction can be executed.
    # It implements the logic of the instruction.
    # 'Static' instructions like OpConstant should not override this since
    # they are not supposed to be executed at runtime.
    def _impl(self, module, lane):
        raise RuntimeError(f"Unimplemented instruction {self}")

    # By default, IP is incremented to point to the next instruction.
    # If the instruction modifies IP (like OpBranch), this must be overridden.
    def _advance_ip(self, module, lane):
        lane.set_ip(lane.ip() + 1)


# Those are parsed, but never executed.
class OpEntryPoint(Instruction):
    pass


class OpFunction(Instruction):
    pass


class OpFunctionEnd(Instruction):
    pass


class OpLabel(Instruction):
    pass


class OpVariable(Instruction):
    pass


class OpName(Instruction):
    # Returns the name literal without its surrounding double quotes.
    # The literal is taken from the raw line (first quote to last quote)
    # because whitespace tokenization splits a name such as "my name" across
    # several operands.
    def name(self) -> str:
        first_quote = self.line.find('"')
        last_quote = self.line.rfind('"')
        if first_quote != -1 and last_quote > first_quote:
            return self.line[first_quote + 1 : last_quote]
        # No quoted literal in the line: keep the historical token-based
        # behavior (drop the first and last character of the second operand).
        return self._operands[1][1:-1]

    # Returns the register (first operand) the name is attached to.
    def decoratedRegister(self) -> str:
        return self._operands[0]


# The only decoration we use is the BuiltIn one to initialize the values.
class OpDecorate(Instruction):
    # Before execution starts, stores the lane's tid() into the decorated
    # register (first operand). LinkageAttributes decorations are ignored;
    # any decoration other than "BuiltIn SubgroupLocalInvocationId" trips the
    # assertion below.
    def static_execution(self, lane):
        if self._operands[1] == "LinkageAttributes":
            return

        assert (
            self._operands[1] == "BuiltIn"
            and self._operands[2] == "SubgroupLocalInvocationId"
        )
        lane.set_register(self._operands[0], lane.tid())


# Constants
class OpConstant(Instruction):
    # Operand 0 is skipped (presumably the result type); operand 1 is parsed
    # as an integer literal and stored in the result register.
    def static_execution(self, lane):
        lane.set_register(self._result, int(self._operands[1]))


class OpConstantTrue(OpConstant):
    def static_execution(self, lane):
        lane.set_register(self._result, True)


class OpConstantFalse(OpConstant):
    def static_execution(self, lane):
        lane.set_register(self._result, False)


class OpConstantComposite(OpConstant):
    # The composite is stored as a Python list holding the current value of
    # each constituent register (every operand after the first).
    def static_execution(self, lane):
        constituents = [lane.get_register(op) for op in self._operands[1:]]
        lane.set_register(self._result, constituents)


# Control flow instructions
class OpFunctionCall(Instruction):
    def _impl(self, module, lane):
        pass

    # Instead of stepping to the next instruction, looks up the entry of the
    # callee (operand 1) and hands it to the lane together with the register
    # that receives the call result.
    def _advance_ip(self, module, lane):
        entry = module.get_function_entry(self._operands[1])
        lane.do_call(entry, self._result)


class OpReturn(Instruction):
    def _impl(self, module, lane):
        pass

    # Returns from the current function without a value.
    def _advance_ip(self, module, lane):
        lane.do_return(None)


class OpReturnValue(Instruction):
    def _impl(self, module, lane):
        pass

    # Returns from the current function with the value held by operand 0.
    def _advance_ip(self, module, lane):
        lane.do_return(lane.get_register(self._operands[0]))


class OpBranch(Instruction):
    def _impl(self, module, lane):
        pass

    # Unconditionally jumps to the basic block named by operand 0.
    def _advance_ip(self, module, lane):
        lane.set_ip(module.get_bb_entry(self._operands[0]))


class OpBranchConditional(Instruction):
    def _impl(self, module, lane):
        pass

    # Jumps to operand 1 when the condition register (operand 0) is truthy,
    # to operand 2 otherwise.
    def _advance_ip(self, module, lane):
        condition = lane.get_register(self._operands[0])
        if condition:
            lane.set_ip(module.get_bb_entry(self._operands[1]))
        else:
            lane.set_ip(module.get_bb_entry(self._operands[2]))


class OpSwitch(Instruction):
    def _impl(self, module, lane):
        pass

    # Operands: selector register, default label, then (literal, label)
    # pairs. Jumps to the label of the first pair whose literal equals the
    # selector value, or to the default label when none matches.
    def _advance_ip(self, module, lane):
        value = lane.get_register(self._operands[0])
        default_label = self._operands[1]
        for index in range(2, len(self._operands), 2):
            imm = int(self._operands[index])
            label = self._operands[index + 1]
            if value == imm:
                lane.set_ip(module.get_bb_entry(label))
                return
        lane.set_ip(module.get_bb_entry(default_label))


class OpUnreachable(Instruction):
    def _impl(self, module, lane):
        raise RuntimeError("This instruction should never be executed.")


# Convergence instructions
class MergeInstruction(Instruction):
    # Returns the label of the merge block (operand 0).
    def merge_location(self):
        return self._operands[0]

    # Returns the label of the continue block (operand 1), or None when the
    # instruction has fewer than 3 operands.
    def continue_location(self):
        return None if len(self._operands) < 3 else self._operands[1]

    def _impl(self, module, lane):
        lane.handle_convergence_header(self)


class OpLoopMerge(MergeInstruction):
    pass


class OpSelectionMerge(MergeInstruction):
    pass


# Other instructions
class OpBitcast(Instruction):
    def _impl(self, module, lane):
        # TODO: find out the type from the defining instruction.
        # This can only work for DXC.
        if self._operands[0] == "%int":
            lane.set_register(self._result, int(lane.get_register(self._operands[1])))
        else:
            raise RuntimeError("Unsupported OpBitcast operand")


class OpAccessChain(Instruction):
    def _impl(self, module, lane):
        # Python dynamic types allows me to simplify. As long as the SPIR-V
        # is legal, this should be fine.
        # Note: SPIR-V structs are stored as tuples
        value = lane.get_register(self._operands[1])
        # Each remaining operand is a register holding the next index.
        for operand in self._operands[2:]:
            value = value[lane.get_register(operand)]
        lane.set_register(self._result, value)


class OpCompositeConstruct(Instruction):
    # Builds a Python list from the values of every operand after the first.
    def _impl(self, module, lane):
        members = [lane.get_register(op) for op in self._operands[1:]]
        lane.set_register(self._result, members)


class OpCompositeExtract(Instruction):
    # Unlike OpAccessChain, the indices here are literal integers, not
    # registers.
    def _impl(self, module, lane):
        output = lane.get_register(self._operands[1])
        for op in self._operands[2:]:
            output = output[int(op)]
        lane.set_register(self._result, output)


class OpStore(Instruction):
    # Copies the value of operand 1 into the register named by operand 0.
    def _impl(self, module, lane):
        lane.set_register(self._operands[0], lane.get_register(self._operands[1]))


class OpLoad(Instruction):
    # Copies the value of operand 1 into the result register.
    def _impl(self, module, lane):
        lane.set_register(self._result, lane.get_register(self._operands[1]))


class OpIAdd(Instruction):
    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        lane.set_register(self._result, lhs + rhs)


class OpISub(Instruction):
    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        lane.set_register(self._result, lhs - rhs)


class OpIMul(Instruction):
    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        lane.set_register(self._result, lhs * rhs)


class OpLogicalNot(Instruction):
    def _impl(self, module, lane):
        operand = lane.get_register(self._operands[1])
        lane.set_register(self._result, not operand)


# Shared implementation for the signed and unsigned "less than" opcodes.
class _LessThan(Instruction):
    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        lane.set_register(self._result, lhs < rhs)


# Shared implementation for the signed and unsigned "greater than" opcodes.
class _GreaterThan(Instruction):
    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        lane.set_register(self._result, lhs > rhs)


class OpSLessThan(_LessThan):
    pass


class OpULessThan(_LessThan):
    pass


class OpSGreaterThan(_GreaterThan):
    pass


class OpUGreaterThan(_GreaterThan):
    pass


class OpIEqual(Instruction):
    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        lane.set_register(self._result, lhs == rhs)


class OpINotEqual(Instruction):
    def _impl(self, module, lane):
        lhs = lane.get_register(self._operands[1])
        rhs = lane.get_register(self._operands[2])
        lane.set_register(self._result, lhs != rhs)


class OpPhi(Instruction):
    # Operands after the first come in (value register, predecessor label)
    # pairs. The result takes the value paired with the block the lane just
    # came from. Raises RuntimeError when no pair names that block.
    def _impl(self, module, lane):
        previous_bb_name = lane.get_previous_bb_name()
        for index in range(1, len(self._operands), 2):
            label = self._operands[index + 1]
            if label == previous_bb_name:
                lane.set_register(
                    self._result, lane.get_register(self._operands[index])
                )
                return
        raise RuntimeError("previousBB not in the OpPhi _operands")


class OpSelect(Instruction):
    # Picks operand 2 when the condition (operand 1) is truthy, operand 3
    # otherwise.
    def _impl(self, module, lane):
        condition = lane.get_register(self._operands[1])
        value = lane.get_register(self._operands[2 if condition else 3])
        lane.set_register(self._result, value)


# Wave intrinsics
class OpGroupNonUniformBroadcastFirst(Instruction):
    def _impl(self, module, lane):
        # Only an execution scope whose register holds 3 is supported
        # (presumably the SPIR-V Subgroup scope - confirm against the spec).
        assert lane.get_register(self._operands[1]) == 3
        # Only the first active lane publishes its value; the other lanes do
        # nothing here.
        if lane.is_first_active_lane():
            lane.broadcast_register(self._result, lane.get_register(self._operands[2]))


class OpGroupNonUniformElect(Instruction):
    # The result is true only for the first active lane.
    def _impl(self, module, lane):
        lane.set_register(self._result, lane.is_first_active_lane())