xref: /llvm-project/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (revision b719ab4eef634f24605ca7ccd4874338c34e05bd)
1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Location.h"
19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
23 #include <optional>
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "spirv-deserialization"
28 
29 //===----------------------------------------------------------------------===//
30 // Utility Functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Extracts the opcode from the given first word of a SPIR-V instruction.
34 static inline spirv::Opcode extractOpcode(uint32_t word) {
35   return static_cast<spirv::Opcode>(word & 0xffff);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Instruction
40 //===----------------------------------------------------------------------===//
41 
42 Value spirv::Deserializer::getValue(uint32_t id) {
43   if (auto constInfo = getConstant(id)) {
44     // Materialize a `spirv.Constant` op at every use site.
45     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
46                                                constInfo->first);
47   }
48   if (auto varOp = getGlobalVariable(id)) {
49     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
50         unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
51     return addressOfOp.getPointer();
52   }
53   if (auto constOp = getSpecConstant(id)) {
54     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
55         unknownLoc, constOp.getDefaultValue().getType(),
56         SymbolRefAttr::get(constOp.getOperation()));
57     return referenceOfOp.getReference();
58   }
59   if (auto constCompositeOp = getSpecConstantComposite(id)) {
60     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61         unknownLoc, constCompositeOp.getType(),
62         SymbolRefAttr::get(constCompositeOp.getOperation()));
63     return referenceOfOp.getReference();
64   }
65   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
66     return materializeSpecConstantOperation(
67         id, specConstOperationInfo->enclodesOpcode,
68         specConstOperationInfo->resultTypeID,
69         specConstOperationInfo->enclosedOpOperands);
70   }
71   if (auto undef = getUndefType(id)) {
72     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
73   }
74   return valueMap.lookup(id);
75 }
76 
77 LogicalResult spirv::Deserializer::sliceInstruction(
78     spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
79     std::optional<spirv::Opcode> expectedOpcode) {
80   auto binarySize = binary.size();
81   if (curOffset >= binarySize) {
82     return emitError(unknownLoc, "expected ")
83            << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
84                               : "more")
85            << " instruction";
86   }
87 
88   // For each instruction, get its word count from the first word to slice it
89   // from the stream properly, and then dispatch to the instruction handler.
90 
91   uint32_t wordCount = binary[curOffset] >> 16;
92 
93   if (wordCount == 0)
94     return emitError(unknownLoc, "word count cannot be zero");
95 
96   uint32_t nextOffset = curOffset + wordCount;
97   if (nextOffset > binarySize)
98     return emitError(unknownLoc, "insufficient words for the last instruction");
99 
100   opcode = extractOpcode(binary[curOffset]);
101   operands = binary.slice(curOffset + 1, wordCount - 1);
102   curOffset = nextOffset;
103   return success();
104 }
105 
106 LogicalResult spirv::Deserializer::processInstruction(
107     spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
108   LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
109                                 << spirv::stringifyOpcode(opcode) << "\n");
110 
111   // First dispatch all the instructions whose opcode does not correspond to
112   // those that have a direct mirror in the SPIR-V dialect
113   switch (opcode) {
114   case spirv::Opcode::OpCapability:
115     return processCapability(operands);
116   case spirv::Opcode::OpExtension:
117     return processExtension(operands);
118   case spirv::Opcode::OpExtInst:
119     return processExtInst(operands);
120   case spirv::Opcode::OpExtInstImport:
121     return processExtInstImport(operands);
122   case spirv::Opcode::OpMemberName:
123     return processMemberName(operands);
124   case spirv::Opcode::OpMemoryModel:
125     return processMemoryModel(operands);
126   case spirv::Opcode::OpEntryPoint:
127   case spirv::Opcode::OpExecutionMode:
128     if (deferInstructions) {
129       deferredInstructions.emplace_back(opcode, operands);
130       return success();
131     }
132     break;
133   case spirv::Opcode::OpVariable:
134     if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135       return processGlobalVariable(operands);
136     }
137     break;
138   case spirv::Opcode::OpLine:
139     return processDebugLine(operands);
140   case spirv::Opcode::OpNoLine:
141     clearDebugLine();
142     return success();
143   case spirv::Opcode::OpName:
144     return processName(operands);
145   case spirv::Opcode::OpString:
146     return processDebugString(operands);
147   case spirv::Opcode::OpModuleProcessed:
148   case spirv::Opcode::OpSource:
149   case spirv::Opcode::OpSourceContinued:
150   case spirv::Opcode::OpSourceExtension:
151     // TODO: This is debug information embedded in the binary which should be
152     // translated into the spirv.module.
153     return success();
154   case spirv::Opcode::OpTypeVoid:
155   case spirv::Opcode::OpTypeBool:
156   case spirv::Opcode::OpTypeInt:
157   case spirv::Opcode::OpTypeFloat:
158   case spirv::Opcode::OpTypeVector:
159   case spirv::Opcode::OpTypeMatrix:
160   case spirv::Opcode::OpTypeArray:
161   case spirv::Opcode::OpTypeFunction:
162   case spirv::Opcode::OpTypeImage:
163   case spirv::Opcode::OpTypeSampledImage:
164   case spirv::Opcode::OpTypeRuntimeArray:
165   case spirv::Opcode::OpTypeStruct:
166   case spirv::Opcode::OpTypePointer:
167   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168     return processType(opcode, operands);
169   case spirv::Opcode::OpTypeForwardPointer:
170     return processTypeForwardPointer(operands);
171   case spirv::Opcode::OpConstant:
172     return processConstant(operands, /*isSpec=*/false);
173   case spirv::Opcode::OpSpecConstant:
174     return processConstant(operands, /*isSpec=*/true);
175   case spirv::Opcode::OpConstantComposite:
176     return processConstantComposite(operands);
177   case spirv::Opcode::OpSpecConstantComposite:
178     return processSpecConstantComposite(operands);
179   case spirv::Opcode::OpSpecConstantOp:
180     return processSpecConstantOperation(operands);
181   case spirv::Opcode::OpConstantTrue:
182     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
183   case spirv::Opcode::OpSpecConstantTrue:
184     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
185   case spirv::Opcode::OpConstantFalse:
186     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
187   case spirv::Opcode::OpSpecConstantFalse:
188     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
189   case spirv::Opcode::OpConstantNull:
190     return processConstantNull(operands);
191   case spirv::Opcode::OpDecorate:
192     return processDecoration(operands);
193   case spirv::Opcode::OpMemberDecorate:
194     return processMemberDecoration(operands);
195   case spirv::Opcode::OpFunction:
196     return processFunction(operands);
197   case spirv::Opcode::OpLabel:
198     return processLabel(operands);
199   case spirv::Opcode::OpBranch:
200     return processBranch(operands);
201   case spirv::Opcode::OpBranchConditional:
202     return processBranchConditional(operands);
203   case spirv::Opcode::OpSelectionMerge:
204     return processSelectionMerge(operands);
205   case spirv::Opcode::OpLoopMerge:
206     return processLoopMerge(operands);
207   case spirv::Opcode::OpPhi:
208     return processPhi(operands);
209   case spirv::Opcode::OpUndef:
210     return processUndef(operands);
211   default:
212     break;
213   }
214   return dispatchToAutogenDeserialization(opcode, operands);
215 }
216 
217 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
218     ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
219     unsigned numOperands) {
220   SmallVector<Type, 1> resultTypes;
221   uint32_t valueID = 0;
222 
223   size_t wordIndex = 0;
224   if (hasResult) {
225     if (wordIndex >= words.size())
226       return emitError(unknownLoc,
227                        "expected result type <id> while deserializing for ")
228              << opName;
229 
230     // Decode the type <id>
231     auto type = getType(words[wordIndex]);
232     if (!type)
233       return emitError(unknownLoc, "unknown type result <id>: ")
234              << words[wordIndex];
235     resultTypes.push_back(type);
236     ++wordIndex;
237 
238     // Decode the result <id>
239     if (wordIndex >= words.size())
240       return emitError(unknownLoc,
241                        "expected result <id> while deserializing for ")
242              << opName;
243     valueID = words[wordIndex];
244     ++wordIndex;
245   }
246 
247   SmallVector<Value, 4> operands;
248   SmallVector<NamedAttribute, 4> attributes;
249 
250   // Decode operands
251   size_t operandIndex = 0;
252   for (; operandIndex < numOperands && wordIndex < words.size();
253        ++operandIndex, ++wordIndex) {
254     auto arg = getValue(words[wordIndex]);
255     if (!arg)
256       return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
257     operands.push_back(arg);
258   }
259   if (operandIndex != numOperands) {
260     return emitError(
261                unknownLoc,
262                "found less operands than expected when deserializing for ")
263            << opName << "; only " << operandIndex << " of " << numOperands
264            << " processed";
265   }
266   if (wordIndex != words.size()) {
267     return emitError(
268                unknownLoc,
269                "found more operands than expected when deserializing for ")
270            << opName << "; only " << wordIndex << " of " << words.size()
271            << " processed";
272   }
273 
274   // Attach attributes from decorations
275   if (decorations.count(valueID)) {
276     auto attrs = decorations[valueID].getAttrs();
277     attributes.append(attrs.begin(), attrs.end());
278   }
279 
280   // Create the op and update bookkeeping maps
281   Location loc = createFileLineColLoc(opBuilder);
282   OperationState opState(loc, opName);
283   opState.addOperands(operands);
284   if (hasResult)
285     opState.addTypes(resultTypes);
286   opState.addAttributes(attributes);
287   Operation *op = opBuilder.create(opState);
288   if (hasResult)
289     valueMap[valueID] = op->getResult(0);
290 
291   if (op->hasTrait<OpTrait::IsTerminator>())
292     clearDebugLine();
293 
294   return success();
295 }
296 
297 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
298   if (operands.size() != 2) {
299     return emitError(unknownLoc, "OpUndef instruction must have two operands");
300   }
301   auto type = getType(operands[0]);
302   if (!type) {
303     return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
304   }
305   undefMap[operands[1]] = type;
306   return success();
307 }
308 
309 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
310   if (operands.size() < 4) {
311     return emitError(unknownLoc,
312                      "OpExtInst must have at least 4 operands, result type "
313                      "<id>, result <id>, set <id> and instruction opcode");
314   }
315   if (!extendedInstSets.count(operands[2])) {
316     return emitError(unknownLoc, "undefined set <id> in OpExtInst");
317   }
318   SmallVector<uint32_t, 4> slicedOperands;
319   slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
320   slicedOperands.append(std::next(operands.begin(), 4), operands.end());
321   return dispatchToExtensionSetAutogenDeserialization(
322       extendedInstSets[operands[2]], operands[3], slicedOperands);
323 }
324 
325 namespace mlir {
326 namespace spirv {
327 
328 template <>
329 LogicalResult
330 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
331   unsigned wordIndex = 0;
332   if (wordIndex >= words.size()) {
333     return emitError(unknownLoc,
334                      "missing Execution Model specification in OpEntryPoint");
335   }
336   auto execModel = spirv::ExecutionModelAttr::get(
337       context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
338   if (wordIndex >= words.size()) {
339     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
340   }
341   // Get the function <id>
342   auto fnID = words[wordIndex++];
343   // Get the function name
344   auto fnName = decodeStringLiteral(words, wordIndex);
345   // Verify that the function <id> matches the fnName
346   auto parsedFunc = getFunction(fnID);
347   if (!parsedFunc) {
348     return emitError(unknownLoc, "no function matching <id> ") << fnID;
349   }
350   if (parsedFunc.getName() != fnName) {
351     // The deserializer uses "spirv_fn_<id>" as the function name if the input
352     // SPIR-V blob does not contain a name for it. We should use a more clear
353     // indication for such case rather than relying on naming details.
354     if (!parsedFunc.getName().starts_with("spirv_fn_"))
355       return emitError(unknownLoc,
356                        "function name mismatch between OpEntryPoint "
357                        "and OpFunction with <id> ")
358              << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
359     parsedFunc.setName(fnName);
360   }
361   SmallVector<Attribute, 4> interface;
362   while (wordIndex < words.size()) {
363     auto arg = getGlobalVariable(words[wordIndex]);
364     if (!arg) {
365       return emitError(unknownLoc, "undefined result <id> ")
366              << words[wordIndex] << " while decoding OpEntryPoint";
367     }
368     interface.push_back(SymbolRefAttr::get(arg.getOperation()));
369     wordIndex++;
370   }
371   opBuilder.create<spirv::EntryPointOp>(
372       unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
373       opBuilder.getArrayAttr(interface));
374   return success();
375 }
376 
377 template <>
378 LogicalResult
379 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
380   unsigned wordIndex = 0;
381   if (wordIndex >= words.size()) {
382     return emitError(unknownLoc,
383                      "missing function result <id> in OpExecutionMode");
384   }
385   // Get the function <id> to get the name of the function
386   auto fnID = words[wordIndex++];
387   auto fn = getFunction(fnID);
388   if (!fn) {
389     return emitError(unknownLoc, "no function matching <id> ") << fnID;
390   }
391   // Get the Execution mode
392   if (wordIndex >= words.size()) {
393     return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
394   }
395   auto execMode = spirv::ExecutionModeAttr::get(
396       context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
397 
398   // Get the values
399   SmallVector<Attribute, 4> attrListElems;
400   while (wordIndex < words.size()) {
401     attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
402   }
403   auto values = opBuilder.getArrayAttr(attrListElems);
404   opBuilder.create<spirv::ExecutionModeOp>(
405       unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
406       execMode, values);
407   return success();
408 }
409 
410 template <>
411 LogicalResult
412 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
413   if (operands.size() < 3) {
414     return emitError(unknownLoc,
415                      "OpFunctionCall must have at least 3 operands");
416   }
417 
418   Type resultType = getType(operands[0]);
419   if (!resultType) {
420     return emitError(unknownLoc, "undefined result type from <id> ")
421            << operands[0];
422   }
423 
424   // Use null type to mean no result type.
425   if (isVoidType(resultType))
426     resultType = nullptr;
427 
428   auto resultID = operands[1];
429   auto functionID = operands[2];
430 
431   auto functionName = getFunctionSymbol(functionID);
432 
433   SmallVector<Value, 4> arguments;
434   for (auto operand : llvm::drop_begin(operands, 3)) {
435     auto value = getValue(operand);
436     if (!value) {
437       return emitError(unknownLoc, "unknown <id> ")
438              << operand << " used by OpFunctionCall";
439     }
440     arguments.push_back(value);
441   }
442 
443   auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
444       unknownLoc, resultType,
445       SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
446 
447   if (resultType)
448     valueMap[resultID] = opFunctionCall.getResult(0);
449   return success();
450 }
451 
452 template <>
453 LogicalResult
454 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
455   SmallVector<Type, 1> resultTypes;
456   size_t wordIndex = 0;
457   SmallVector<Value, 4> operands;
458   SmallVector<NamedAttribute, 4> attributes;
459 
460   if (wordIndex < words.size()) {
461     auto arg = getValue(words[wordIndex]);
462 
463     if (!arg) {
464       return emitError(unknownLoc, "unknown result <id> : ")
465              << words[wordIndex];
466     }
467 
468     operands.push_back(arg);
469     wordIndex++;
470   }
471 
472   if (wordIndex < words.size()) {
473     auto arg = getValue(words[wordIndex]);
474 
475     if (!arg) {
476       return emitError(unknownLoc, "unknown result <id> : ")
477              << words[wordIndex];
478     }
479 
480     operands.push_back(arg);
481     wordIndex++;
482   }
483 
484   bool isAlignedAttr = false;
485 
486   if (wordIndex < words.size()) {
487     auto attrValue = words[wordIndex++];
488     auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
489         static_cast<spirv::MemoryAccess>(attrValue));
490     attributes.push_back(
491         opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
492     isAlignedAttr = (attrValue == 2);
493   }
494 
495   if (isAlignedAttr && wordIndex < words.size()) {
496     attributes.push_back(opBuilder.getNamedAttr(
497         "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
498   }
499 
500   if (wordIndex < words.size()) {
501     auto attrValue = words[wordIndex++];
502     auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
503         static_cast<spirv::MemoryAccess>(attrValue));
504     attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
505   }
506 
507   if (wordIndex < words.size()) {
508     attributes.push_back(opBuilder.getNamedAttr(
509         "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
510   }
511 
512   if (wordIndex != words.size()) {
513     return emitError(unknownLoc,
514                      "found more operands than expected when deserializing "
515                      "spirv::CopyMemoryOp, only ")
516            << wordIndex << " of " << words.size() << " processed";
517   }
518 
519   Location loc = createFileLineColLoc(opBuilder);
520   opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
521 
522   return success();
523 }
524 
525 template <>
526 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
527     ArrayRef<uint32_t> words) {
528   if (words.size() != 4) {
529     return emitError(unknownLoc,
530                      "expected 4 words in GenericCastToPtrExplicitOp"
531                      " but got : ")
532            << words.size();
533   }
534   SmallVector<Type, 1> resultTypes;
535   SmallVector<Value, 4> operands;
536   uint32_t valueID = 0;
537   auto type = getType(words[0]);
538 
539   if (!type)
540     return emitError(unknownLoc, "unknown type result <id> : ") << words[0];
541   resultTypes.push_back(type);
542 
543   valueID = words[1];
544 
545   auto arg = getValue(words[2]);
546   if (!arg)
547     return emitError(unknownLoc, "unknown result <id> : ") << words[2];
548   operands.push_back(arg);
549 
550   Location loc = createFileLineColLoc(opBuilder);
551   Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
552       loc, resultTypes, operands);
553   valueMap[valueID] = op->getResult(0);
554   return success();
555 }
556 
557 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
558 // various Deserializer::processOp<...>() specializations.
559 #define GET_DESERIALIZATION_FNS
560 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
561 
562 } // namespace spirv
563 } // namespace mlir
564