xref: /llvm-project/llvm/utils/spirv-sim/instructions.py (revision 5914566474de29309b0b4815ecd406805793de1f)
1from typing import Optional, List
2
3
# Base class for every parsed SPIR-V instruction. An instruction that leaves
# control flow alone only needs to subclass this and override _impl.
class Instruction:
    # Output register name (the token before "="), or None when absent.
    _result: Optional[str]
    # Opcode token, e.g. "OpIAdd".
    _opcode: str
    # Tokens following the opcode; the result and opcode are not included.
    _operands: List[str]

    def __init__(self, line: str):
        # Raw source text of the instruction, kept as-is.
        self.line = line
        words = line.split()
        # Two accepted shapes: "%res = OpCode ops..." and "OpCode ops...".
        if len(words) > 1 and words[1] == "=":
            result, opcode, rest = words[0], words[2], words[3:]
        else:
            result, opcode, rest = None, words[0], words[1:]
        self._result = result
        self._opcode = opcode
        self._operands = rest

    def __str__(self):
        # Debug rendering; operands are shown as a Python list.
        if self._result is not None:
            return f"{self._result:3} = {self._opcode} {self._operands}"
        return f"      {self._opcode} {self._operands}"

    # Opcode accessor.
    def opcode(self) -> str:
        return self._opcode

    # Operand-list accessor (result and opcode excluded).
    def operands(self) -> List[str]:
        return self._operands

    # Name of the register written by this instruction. Only valid when
    # has_output_register() returns True.
    def output_register(self) -> str:
        assert self._result is not None
        return self._result

    # True when the instruction was written in the "%res = ..." form.
    def has_output_register(self) -> bool:
        return self._result is not None

    # Hook run before module execution starts, e.g. to seed a register with
    # the lane ID. The default does nothing.
    def static_execution(self, lane):
        return None

    # Called each time a tangle executes this instruction. Subclasses should
    # customize _impl and/or _advance_ip rather than overriding this.
    def runtime_execution(self, module, lane):
        self._impl(module, lane)
        self._advance_ip(module, lane)

    # The instruction's semantics. Executable instructions must override this.
    # Static-only instructions (such as OpConstant) are never executed at
    # runtime, so they keep this failing default.
    def _impl(self, module, lane):
        raise RuntimeError(f"Unimplemented instruction {self}")

    # Default control flow: fall through to the next instruction. Instructions
    # that redirect IP (e.g. OpBranch) override this.
    def _advance_ip(self, module, lane):
        next_ip = lane.ip() + 1
        lane.set_ip(next_ip)
73
74
75# Those are parsed, but never executed.
class OpEntryPoint(Instruction):
    """Parsed, but never executed at runtime."""
78
79
class OpFunction(Instruction):
    """Parsed, but never executed at runtime."""
82
83
class OpFunctionEnd(Instruction):
    """Parsed, but never executed at runtime."""
86
87
class OpLabel(Instruction):
    """Parsed, but never executed at runtime."""
90
91
class OpVariable(Instruction):
    """Parsed, but never executed at runtime."""
94
95
class OpName(Instruction):
    # Debug instruction attaching a name to a register:
    #   OpName %target "name"

    # Returns the name string without its surrounding quotes.
    # The string is taken from the raw source line rather than from the
    # whitespace-split tokens. Token splitting breaks a quoted name at its
    # first space: '"my var"' would become the tokens '"my' and 'var"'.
    def name(self) -> str:
        start = self.line.find('"')
        end = self.line.rfind('"')
        if start == -1 or end <= start:
            # No quoted string on the line: keep the token-based behavior.
            return self._operands[1][1:-1]
        return self.line[start + 1 : end]

    # Returns the register the name is attached to (the first operand).
    def decoratedRegister(self) -> str:
        return self._operands[0]
102
103
# The only decoration we use is BuiltIn (to initialize register values);
# LinkageAttributes is skipped.
class OpDecorate(Instruction):
    # Applies a decoration during static (pre-execution) setup.
    # Supported forms:
    #   OpDecorate %target LinkageAttributes ...            -> ignored
    #   OpDecorate %target BuiltIn SubgroupLocalInvocationId
    #       -> %target is initialized with lane.tid()
    # Any other decoration raises RuntimeError. An explicit raise is used
    # instead of `assert`: under `python -O` the assert would be stripped, and
    # the unsupported decoration would silently overwrite %target with tid().
    def static_execution(self, lane):
        decoration = self._operands[1]
        if decoration == "LinkageAttributes":
            return

        if (
            decoration != "BuiltIn"
            or self._operands[2] != "SubgroupLocalInvocationId"
        ):
            raise RuntimeError(f"Unsupported decoration {self}")
        lane.set_register(self._operands[0], lane.tid())
115
116
117# Constants
class OpConstant(Instruction):
    # %res = OpConstant %type <literal>
    # Seeds the result register with the literal before execution starts.
    # Integer literals become Python ints, as before. A literal that is not a
    # valid integer (e.g. a float constant such as "1.5") falls back to a
    # Python float instead of aborting static execution. Literals that are
    # neither (e.g. hex-float spellings) still raise ValueError.
    def static_execution(self, lane):
        literal = self._operands[1]
        try:
            value = int(literal)
        except ValueError:
            value = float(literal)
        lane.set_register(self._result, value)
121
122
class OpConstantTrue(OpConstant):
    # %res = OpConstantTrue %bool: the register is seeded with Python True.
    def static_execution(self, lane):
        constant = True
        lane.set_register(self._result, constant)
126
127
class OpConstantFalse(OpConstant):
    # %res = OpConstantFalse %bool: the register is seeded with Python False.
    def static_execution(self, lane):
        constant = False
        lane.set_register(self._result, constant)
131
132
class OpConstantComposite(OpConstant):
    # %res = OpConstantComposite %type %c0 %c1 ...
    # The register is seeded with a Python list of the constituent registers'
    # current values, read in operand order.
    def static_execution(self, lane):
        members = [lane.get_register(name) for name in self._operands[1:]]
        lane.set_register(self._result, members)
139
140
141# Control flow instructions
class OpFunctionCall(Instruction):
    # %res = OpFunctionCall %type %function [%args...]
    # Nothing is computed locally; the work happens when IP is redirected.
    def _impl(self, module, lane):
        return

    # Looks up the callee's entry and hands it, plus the result register,
    # to lane.do_call.
    # NOTE(review): call arguments (operands[2:]) are not forwarded to
    # do_call; presumably functions with parameters are unsupported. Confirm.
    def _advance_ip(self, module, lane):
        callee = self._operands[1]
        callee_entry = module.get_function_entry(callee)
        lane.do_call(callee_entry, self._result)
149
150
class OpReturn(Instruction):
    # Returns from the current function without a value.
    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        no_value = None
        lane.do_return(no_value)
157
158
class OpReturnValue(Instruction):
    # OpReturnValue %value: returns the value held in the %value register.
    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        returned = lane.get_register(self._operands[0])
        lane.do_return(returned)
165
166
class OpBranch(Instruction):
    # OpBranch %label: unconditional jump to the start of %label.
    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        target_ip = module.get_bb_entry(self._operands[0])
        lane.set_ip(target_ip)
174
175
class OpBranchConditional(Instruction):
    # OpBranchConditional %cond %true_label %false_label [weights...]
    # Jumps to %true_label when the %cond register is truthy, else to
    # %false_label.
    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        cond_name, true_label, false_label = (
            self._operands[0],
            self._operands[1],
            self._operands[2],
        )
        taken = true_label if lane.get_register(cond_name) else false_label
        lane.set_ip(module.get_bb_entry(taken))
186
187
class OpSwitch(Instruction):
    # OpSwitch %selector %default [literal label]...
    # Jumps to the label of the first literal equal to the selector's value,
    # or to %default when none matches.
    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        selector = lane.get_register(self._operands[0])
        target = self._operands[1]
        for pos in range(2, len(self._operands), 2):
            case_value = int(self._operands[pos])
            case_label = self._operands[pos + 1]
            if selector == case_value:
                target = case_label
                break
        lane.set_ip(module.get_bb_entry(target))
204
205
class OpUnreachable(Instruction):
    # Reaching this instruction at runtime is an error.
    def _impl(self, module, lane):
        message = "This instruction should never be executed."
        raise RuntimeError(message)
209
210
211# Convergence instructions
class MergeInstruction(Instruction):
    # Shared base for the structured-control-flow header instructions
    # (OpLoopMerge, OpSelectionMerge).

    # Label of the merge block (first operand).
    def merge_location(self):
        merge_label = self._operands[0]
        return merge_label

    # Label of the continue target, or None when there is none.
    # OpLoopMerge carries (merge, continue, loop-control, ...), i.e. at least
    # three operands. OpSelectionMerge carries only (merge, selection-control).
    def continue_location(self):
        if len(self._operands) >= 3:
            return self._operands[1]
        return None

    # Delegates to the lane's convergence-header handling.
    def _impl(self, module, lane):
        lane.handle_convergence_header(self)
221
222
class OpLoopMerge(MergeInstruction):
    """Loop header: merge block, continue target, loop control."""
225
226
class OpSelectionMerge(MergeInstruction):
    """Selection header: merge block, selection control (no continue target)."""
229
230
231# Other instructions
class OpBitcast(Instruction):
    # %res = OpBitcast %type %value
    # Only the "%int" result type is handled, converting via int().
    # TODO: find out the type from the defining instruction; matching the
    # literal name "%int" only works for DXC output.
    def _impl(self, module, lane):
        result_type = self._operands[0]
        if result_type != "%int":
            raise RuntimeError("Unsupported OpBitcast operand")
        source_value = lane.get_register(self._operands[1])
        lane.set_register(self._result, int(source_value))
240
241
class OpAccessChain(Instruction):
    # %res = OpAccessChain %ptr_type %base %idx0 %idx1 ...
    # Walks into the composite stored in %base, using each index register's
    # value as a subscript, and stores the element reached in %res. Any
    # Python sequence works as a composite, as long as the SPIR-V is legal.
    # NOTE(review): this stores the element value, not a reference into the
    # composite, so writes through %res do not update %base. Confirm intended.
    def _impl(self, module, lane):
        current = lane.get_register(self._operands[1])
        for index_register in self._operands[2:]:
            subscript = lane.get_register(index_register)
            current = current[subscript]
        lane.set_register(self._result, current)
251
252
class OpCompositeConstruct(Instruction):
    # %res = OpCompositeConstruct %type %c0 %c1 ...
    # Builds a Python list of the constituent registers' values, in order.
    def _impl(self, module, lane):
        constituents = [lane.get_register(name) for name in self._operands[1:]]
        lane.set_register(self._result, constituents)
259
260
class OpCompositeExtract(Instruction):
    # %res = OpCompositeExtract %type %composite <lit0> <lit1> ...
    # Indexes into %composite using the literal (integer) indices, in order.
    def _impl(self, module, lane):
        element = lane.get_register(self._operands[1])
        for literal_index in map(int, self._operands[2:]):
            element = element[literal_index]
        lane.set_register(self._result, element)
268
269
class OpStore(Instruction):
    # OpStore %pointer %object: pointers are modeled as registers, so the
    # %object value is copied into the %pointer register.
    def _impl(self, module, lane):
        destination = self._operands[0]
        stored_value = lane.get_register(self._operands[1])
        lane.set_register(destination, stored_value)
273
274
class OpLoad(Instruction):
    # %res = OpLoad %type %pointer: copies the %pointer register into %res.
    def _impl(self, module, lane):
        loaded = lane.get_register(self._operands[1])
        lane.set_register(self._result, loaded)
278
279
class OpIAdd(Instruction):
    # %res = OpIAdd %type %a %b -> a + b (plain Python addition, no wrapping).
    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        total = a + b
        lane.set_register(self._result, total)
285
286
class OpISub(Instruction):
    # %res = OpISub %type %a %b -> a - b (plain Python subtraction, no wrapping).
    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        difference = a - b
        lane.set_register(self._result, difference)
292
293
class OpIMul(Instruction):
    # %res = OpIMul %type %a %b -> a * b (plain Python product, no wrapping).
    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        product = a * b
        lane.set_register(self._result, product)
299
300
class OpLogicalNot(Instruction):
    # %res = OpLogicalNot %type %value -> Python `not value`.
    def _impl(self, module, lane):
        operand_value = lane.get_register(self._operands[1])
        negated = not operand_value
        lane.set_register(self._result, negated)
305
306
class _LessThan(Instruction):
    # Shared implementation of %res = Op?LessThan %type %a %b -> a < b.
    # Signed and unsigned variants use the same Python comparison; no sign
    # reinterpretation of register values happens here.
    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        is_less = a < b
        lane.set_register(self._result, is_less)
312
313
class _GreaterThan(Instruction):
    # Shared implementation of %res = Op?GreaterThan %type %a %b -> a > b.
    # Signed and unsigned variants use the same Python comparison; no sign
    # reinterpretation of register values happens here.
    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        is_greater = a > b
        lane.set_register(self._result, is_greater)
319
320
class OpSLessThan(_LessThan):
    """Signed less-than; behavior comes from _LessThan."""
323
324
class OpULessThan(_LessThan):
    """Unsigned less-than; behavior comes from _LessThan."""
327
328
class OpSGreaterThan(_GreaterThan):
    """Signed greater-than; behavior comes from _GreaterThan."""
331
332
class OpUGreaterThan(_GreaterThan):
    """Unsigned greater-than; behavior comes from _GreaterThan."""
335
336
class OpIEqual(Instruction):
    # %res = OpIEqual %type %a %b -> a == b.
    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        same = a == b
        lane.set_register(self._result, same)
342
343
class OpINotEqual(Instruction):
    # %res = OpINotEqual %type %a %b -> a != b.
    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        differ = a != b
        lane.set_register(self._result, differ)
349
350
class OpPhi(Instruction):
    # %res = OpPhi %type %v0 %bb0 %v1 %bb1 ...
    # Picks the value paired with the basic block the lane arrived from.
    def _impl(self, module, lane):
        predecessor = lane.get_previous_bb_name()
        for pos in range(1, len(self._operands), 2):
            if self._operands[pos + 1] == predecessor:
                incoming = lane.get_register(self._operands[pos])
                lane.set_register(self._result, incoming)
                return
        raise RuntimeError("previousBB not in the OpPhi _operands")
362
363
class OpSelect(Instruction):
    # %res = OpSelect %type %cond %if_true %if_false
    # Copies %if_true when %cond is truthy, otherwise %if_false.
    def _impl(self, module, lane):
        if lane.get_register(self._operands[1]):
            chosen = self._operands[2]
        else:
            chosen = self._operands[3]
        lane.set_register(self._result, lane.get_register(chosen))
369
370
371# Wave intrinsics
class OpGroupNonUniformBroadcastFirst(Instruction):
    # %res = OpGroupNonUniformBroadcastFirst %type %scope %value
    # Only the Subgroup scope (value 3 in the SPIR-V Scope enum) is supported.
    _SUBGROUP_SCOPE = 3

    # The first active lane calls lane.broadcast_register with its own %value.
    # Other lanes make no call here; presumably broadcast_register writes %res
    # for the whole tangle. Confirm against the lane implementation.
    # An explicit raise (not `assert`) keeps the scope check under `python -O`.
    # Otherwise any scope would silently be treated as Subgroup.
    def _impl(self, module, lane):
        scope = lane.get_register(self._operands[1])
        if scope != self._SUBGROUP_SCOPE:
            raise RuntimeError(f"Unsupported scope {scope} in {self}")
        if lane.is_first_active_lane():
            lane.broadcast_register(self._result, lane.get_register(self._operands[2]))
377
378
class OpGroupNonUniformElect(Instruction):
    # %res = OpGroupNonUniformElect %type %scope
    # %res is True only in the first active lane, as reported by the lane.
    def _impl(self, module, lane):
        elected = lane.is_first_active_lane()
        lane.set_register(self._result, elected)
382