xref: /llvm-project/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (revision 4f78f8519056953d26102c7426fbb028caf13bc9)
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43   return block->isEntryBlock() &&
44          isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52                                   MLIRContext *context)
53     : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54       module(createModuleOp()), opBuilder(module->getRegion())
55 #ifndef NDEBUG
56       ,
57       logger(llvm::dbgs())
58 #endif
59 {
60 }
61 
62 LogicalResult spirv::Deserializer::deserialize() {
63   LLVM_DEBUG({
64     logger.resetIndent();
65     logger.startLine()
66         << "//+++---------- start deserialization ----------+++//\n";
67   });
68 
69   if (failed(processHeader()))
70     return failure();
71 
72   spirv::Opcode opcode = spirv::Opcode::OpNop;
73   ArrayRef<uint32_t> operands;
74   auto binarySize = binary.size();
75   while (curOffset < binarySize) {
76     // Slice the next instruction out and populate `opcode` and `operands`.
77     // Internally this also updates `curOffset`.
78     if (failed(sliceInstruction(opcode, operands)))
79       return failure();
80 
81     if (failed(processInstruction(opcode, operands)))
82       return failure();
83   }
84 
85   assert(curOffset == binarySize &&
86          "deserializer should never index beyond the binary end");
87 
88   for (auto &deferred : deferredInstructions) {
89     if (failed(processInstruction(deferred.first, deferred.second, false))) {
90       return failure();
91     }
92   }
93 
94   attachVCETriple();
95 
96   LLVM_DEBUG(logger.startLine()
97              << "//+++-------- completed deserialization --------+++//\n");
98   return success();
99 }
100 
101 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
102   return std::move(module);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
108 
109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110   OpBuilder builder(context);
111   OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112   spirv::ModuleOp::build(builder, state);
113   return cast<spirv::ModuleOp>(Operation::create(state));
114 }
115 
116 LogicalResult spirv::Deserializer::processHeader() {
117   if (binary.size() < spirv::kHeaderWordCount)
118     return emitError(unknownLoc,
119                      "SPIR-V binary module must have a 5-word header");
120 
121   if (binary[0] != spirv::kMagicNumber)
122     return emitError(unknownLoc, "incorrect magic number");
123 
124   // Version number bytes: 0 | major number | minor number | 0
125   uint32_t majorVersion = (binary[1] << 8) >> 24;
126   uint32_t minorVersion = (binary[1] << 16) >> 24;
127   if (majorVersion == 1) {
128     switch (minorVersion) {
129 #define MIN_VERSION_CASE(v)                                                    \
130   case v:                                                                      \
131     version = spirv::Version::V_1_##v;                                         \
132     break
133 
134       MIN_VERSION_CASE(0);
135       MIN_VERSION_CASE(1);
136       MIN_VERSION_CASE(2);
137       MIN_VERSION_CASE(3);
138       MIN_VERSION_CASE(4);
139       MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141     default:
142       return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143              << minorVersion;
144     }
145   } else {
146     return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147            << majorVersion;
148   }
149 
150   // TODO: generator number, bound, schema
151   curOffset = spirv::kHeaderWordCount;
152   return success();
153 }
154 
155 LogicalResult
156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157   if (operands.size() != 1)
158     return emitError(unknownLoc, "OpMemoryModel must have one parameter");
159 
160   auto cap = spirv::symbolizeCapability(operands[0]);
161   if (!cap)
162     return emitError(unknownLoc, "unknown capability: ") << operands[0];
163 
164   capabilities.insert(*cap);
165   return success();
166 }
167 
168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169   if (words.empty()) {
170     return emitError(
171         unknownLoc,
172         "OpExtension must have a literal string for the extension name");
173   }
174 
175   unsigned wordIndex = 0;
176   StringRef extName = decodeStringLiteral(words, wordIndex);
177   if (wordIndex != words.size())
178     return emitError(unknownLoc,
179                      "unexpected trailing words in OpExtension instruction");
180   auto ext = spirv::symbolizeExtension(extName);
181   if (!ext)
182     return emitError(unknownLoc, "unknown extension: ") << extName;
183 
184   extensions.insert(*ext);
185   return success();
186 }
187 
188 LogicalResult
189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190   if (words.size() < 2) {
191     return emitError(unknownLoc,
192                      "OpExtInstImport must have a result <id> and a literal "
193                      "string for the extended instruction set name");
194   }
195 
196   unsigned wordIndex = 1;
197   extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198   if (wordIndex != words.size()) {
199     return emitError(unknownLoc,
200                      "unexpected trailing words in OpExtInstImport");
201   }
202   return success();
203 }
204 
205 void spirv::Deserializer::attachVCETriple() {
206   (*module)->setAttr(
207       spirv::ModuleOp::getVCETripleAttrName(),
208       spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209                                 extensions.getArrayRef(), context));
210 }
211 
212 LogicalResult
213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214   if (operands.size() != 2)
215     return emitError(unknownLoc, "OpMemoryModel must have two operands");
216 
217   (*module)->setAttr(
218       module->getAddressingModelAttrName(),
219       opBuilder.getAttr<spirv::AddressingModelAttr>(
220           static_cast<spirv::AddressingModel>(operands.front())));
221 
222   (*module)->setAttr(module->getMemoryModelAttrName(),
223                      opBuilder.getAttr<spirv::MemoryModelAttr>(
224                          static_cast<spirv::MemoryModel>(operands.back())));
225 
226   return success();
227 }
228 
229 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
230 LogicalResult deserializeCacheControlDecoration(
231     Location loc, OpBuilder &opBuilder,
232     DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
233     StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
234   if (words.size() != 4) {
235     return emitError(loc, "OpDecoration with ")
236            << decorationName << "needs a cache control integer literal and a "
237            << cacheControlKind << " cache control literal";
238   }
239   unsigned cacheLevel = words[2];
240   auto cacheControlAttr = static_cast<EnumTy>(words[3]);
241   auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
242   SmallVector<Attribute> attrs;
243   if (auto attrList =
244           llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
245     llvm::append_range(attrs, attrList);
246   attrs.push_back(value);
247   decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
248   return success();
249 }
250 
251 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
252   // TODO: This function should also be auto-generated. For now, since only a
253   // few decorations are processed/handled in a meaningful manner, going with a
254   // manual implementation.
255   if (words.size() < 2) {
256     return emitError(
257         unknownLoc, "OpDecorate must have at least result <id> and Decoration");
258   }
259   auto decorationName =
260       stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
261   if (decorationName.empty()) {
262     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
263   }
264   auto symbol = getSymbolDecoration(decorationName);
265   switch (static_cast<spirv::Decoration>(words[1])) {
266   case spirv::Decoration::FPFastMathMode:
267     if (words.size() != 3) {
268       return emitError(unknownLoc, "OpDecorate with ")
269              << decorationName << " needs a single integer literal";
270     }
271     decorations[words[0]].set(
272         symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
273                                         static_cast<FPFastMathMode>(words[2])));
274     break;
275   case spirv::Decoration::FPRoundingMode:
276     if (words.size() != 3) {
277       return emitError(unknownLoc, "OpDecorate with ")
278              << decorationName << " needs a single integer literal";
279     }
280     decorations[words[0]].set(
281         symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
282                                         static_cast<FPRoundingMode>(words[2])));
283     break;
284   case spirv::Decoration::DescriptorSet:
285   case spirv::Decoration::Binding:
286     if (words.size() != 3) {
287       return emitError(unknownLoc, "OpDecorate with ")
288              << decorationName << " needs a single integer literal";
289     }
290     decorations[words[0]].set(
291         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
292     break;
293   case spirv::Decoration::BuiltIn:
294     if (words.size() != 3) {
295       return emitError(unknownLoc, "OpDecorate with ")
296              << decorationName << " needs a single integer literal";
297     }
298     decorations[words[0]].set(
299         symbol, opBuilder.getStringAttr(
300                     stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
301     break;
302   case spirv::Decoration::ArrayStride:
303     if (words.size() != 3) {
304       return emitError(unknownLoc, "OpDecorate with ")
305              << decorationName << " needs a single integer literal";
306     }
307     typeDecorations[words[0]] = words[2];
308     break;
309   case spirv::Decoration::LinkageAttributes: {
310     if (words.size() < 4) {
311       return emitError(unknownLoc, "OpDecorate with ")
312              << decorationName
313              << " needs at least 1 string and 1 integer literal";
314     }
315     // LinkageAttributes has two parameters ["linkageName", linkageType]
316     // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
317     // "linkageName" is a stringliteral encoded as uint32_t,
318     // hence the size of name is variable length which results in words.size()
319     // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
320     // 3 + ceildiv(strlen(name), 4).
321     unsigned wordIndex = 2;
322     auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
323     auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
324         static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
325     auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
326         StringAttr::get(context, linkageName), linkageTypeAttr);
327     decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
328     break;
329   }
330   case spirv::Decoration::Aliased:
331   case spirv::Decoration::AliasedPointer:
332   case spirv::Decoration::Block:
333   case spirv::Decoration::BufferBlock:
334   case spirv::Decoration::Flat:
335   case spirv::Decoration::NonReadable:
336   case spirv::Decoration::NonWritable:
337   case spirv::Decoration::NoPerspective:
338   case spirv::Decoration::NoSignedWrap:
339   case spirv::Decoration::NoUnsignedWrap:
340   case spirv::Decoration::RelaxedPrecision:
341   case spirv::Decoration::Restrict:
342   case spirv::Decoration::RestrictPointer:
343   case spirv::Decoration::NoContraction:
344   case spirv::Decoration::Constant:
345     if (words.size() != 2) {
346       return emitError(unknownLoc, "OpDecoration with ")
347              << decorationName << "needs a single target <id>";
348     }
349     // Block decoration does not affect spirv.struct type, but is still stored
350     // for verification.
351     // TODO: Update StructType to contain this information since
352     // it is needed for many validation rules.
353     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
354     break;
355   case spirv::Decoration::Location:
356   case spirv::Decoration::SpecId:
357     if (words.size() != 3) {
358       return emitError(unknownLoc, "OpDecoration with ")
359              << decorationName << "needs a single integer literal";
360     }
361     decorations[words[0]].set(
362         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
363     break;
364   case spirv::Decoration::CacheControlLoadINTEL: {
365     LogicalResult res = deserializeCacheControlDecoration<
366         CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367         unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368         "load");
369     if (failed(res))
370       return res;
371     break;
372   }
373   case spirv::Decoration::CacheControlStoreINTEL: {
374     LogicalResult res = deserializeCacheControlDecoration<
375         CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376         unknownLoc, opBuilder, decorations, words, symbol, decorationName,
377         "store");
378     if (failed(res))
379       return res;
380     break;
381   }
382   default:
383     return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
384   }
385   return success();
386 }
387 
388 LogicalResult
389 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
390   // The binary layout of OpMemberDecorate is different comparing to OpDecorate
391   if (words.size() < 3) {
392     return emitError(unknownLoc,
393                      "OpMemberDecorate must have at least 3 operands");
394   }
395 
396   auto decoration = static_cast<spirv::Decoration>(words[2]);
397   if (decoration == spirv::Decoration::Offset && words.size() != 4) {
398     return emitError(unknownLoc,
399                      " missing offset specification in OpMemberDecorate with "
400                      "Offset decoration");
401   }
402   ArrayRef<uint32_t> decorationOperands;
403   if (words.size() > 3) {
404     decorationOperands = words.slice(3);
405   }
406   memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
407   return success();
408 }
409 
410 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
411   if (words.size() < 3) {
412     return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
413   }
414   unsigned wordIndex = 2;
415   auto name = decodeStringLiteral(words, wordIndex);
416   if (wordIndex != words.size()) {
417     return emitError(unknownLoc,
418                      "unexpected trailing words in OpMemberName instruction");
419   }
420   memberNameMap[words[0]][words[1]] = name;
421   return success();
422 }
423 
424 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
425     uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
426   if (!decorations.contains(argID)) {
427     argAttrs[argIndex] = DictionaryAttr::get(context, {});
428     return success();
429   }
430 
431   spirv::DecorationAttr foundDecorationAttr;
432   for (NamedAttribute decAttr : decorations[argID]) {
433     for (auto decoration :
434          {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435           spirv::Decoration::AliasedPointer,
436           spirv::Decoration::RestrictPointer}) {
437 
438       if (decAttr.getName() !=
439           getSymbolDecoration(stringifyDecoration(decoration)))
440         continue;
441 
442       if (foundDecorationAttr)
443         return emitError(unknownLoc,
444                          "more than one Aliased/Restrict decorations for "
445                          "function argument with result <id> ")
446                << argID;
447 
448       foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
449       break;
450     }
451   }
452 
453   if (!foundDecorationAttr)
454     return emitError(unknownLoc, "unimplemented decoration support for "
455                                  "function argument with result <id> ")
456            << argID;
457 
458   NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
459                       foundDecorationAttr);
460   argAttrs[argIndex] = DictionaryAttr::get(context, attr);
461   return success();
462 }
463 
464 LogicalResult
465 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
466   if (curFunction) {
467     return emitError(unknownLoc, "found function inside function");
468   }
469 
470   // Get the result type
471   if (operands.size() != 4) {
472     return emitError(unknownLoc, "OpFunction must have 4 parameters");
473   }
474   Type resultType = getType(operands[0]);
475   if (!resultType) {
476     return emitError(unknownLoc, "undefined result type from <id> ")
477            << operands[0];
478   }
479 
480   uint32_t fnID = operands[1];
481   if (funcMap.count(fnID)) {
482     return emitError(unknownLoc, "duplicate function definition/declaration");
483   }
484 
485   auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
486   if (!fnControl) {
487     return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
488   }
489 
490   Type fnType = getType(operands[3]);
491   if (!fnType || !isa<FunctionType>(fnType)) {
492     return emitError(unknownLoc, "unknown function type from <id> ")
493            << operands[3];
494   }
495   auto functionType = cast<FunctionType>(fnType);
496 
497   if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
498       (functionType.getNumResults() == 1 &&
499        functionType.getResult(0) != resultType)) {
500     return emitError(unknownLoc, "mismatch in function type ")
501            << functionType << " and return type " << resultType << " specified";
502   }
503 
504   std::string fnName = getFunctionSymbol(fnID);
505   auto funcOp = opBuilder.create<spirv::FuncOp>(
506       unknownLoc, fnName, functionType, fnControl.value());
507   // Processing other function attributes.
508   if (decorations.count(fnID)) {
509     for (auto attr : decorations[fnID].getAttrs()) {
510       funcOp->setAttr(attr.getName(), attr.getValue());
511     }
512   }
513   curFunction = funcMap[fnID] = funcOp;
514   auto *entryBlock = funcOp.addEntryBlock();
515   LLVM_DEBUG({
516     logger.startLine()
517         << "//===-------------------------------------------===//\n";
518     logger.startLine() << "[fn] name: " << fnName << "\n";
519     logger.startLine() << "[fn] type: " << fnType << "\n";
520     logger.startLine() << "[fn] ID: " << fnID << "\n";
521     logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
522     logger.indent();
523   });
524 
525   SmallVector<Attribute> argAttrs;
526   argAttrs.resize(functionType.getNumInputs());
527 
528   // Parse the op argument instructions
529   if (functionType.getNumInputs()) {
530     for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
531       auto argType = functionType.getInput(i);
532       spirv::Opcode opcode = spirv::Opcode::OpNop;
533       ArrayRef<uint32_t> operands;
534       if (failed(sliceInstruction(opcode, operands,
535                                   spirv::Opcode::OpFunctionParameter))) {
536         return failure();
537       }
538       if (opcode != spirv::Opcode::OpFunctionParameter) {
539         return emitError(
540                    unknownLoc,
541                    "missing OpFunctionParameter instruction for argument ")
542                << i;
543       }
544       if (operands.size() != 2) {
545         return emitError(
546             unknownLoc,
547             "expected result type and result <id> for OpFunctionParameter");
548       }
549       auto argDefinedType = getType(operands[0]);
550       if (!argDefinedType || argDefinedType != argType) {
551         return emitError(unknownLoc,
552                          "mismatch in argument type between function type "
553                          "definition ")
554                << functionType << " and argument type definition "
555                << argDefinedType << " at argument " << i;
556       }
557       if (getValue(operands[1])) {
558         return emitError(unknownLoc, "duplicate definition of result <id> ")
559                << operands[1];
560       }
561       if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
562         return failure();
563       }
564 
565       auto argValue = funcOp.getArgument(i);
566       valueMap[operands[1]] = argValue;
567     }
568   }
569 
570   if (llvm::any_of(argAttrs, [](Attribute attr) {
571         auto argAttr = cast<DictionaryAttr>(attr);
572         return !argAttr.empty();
573       }))
574     funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
575 
576   // entryBlock is needed to access the arguments, Once that is done, we can
577   // erase the block for functions with 'Import' LinkageAttributes, since these
578   // are essentially function declarations, so they have no body.
579   auto linkageAttr = funcOp.getLinkageAttributes();
580   auto hasImportLinkage =
581       linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
582                       spirv::LinkageType::Import);
583   if (hasImportLinkage)
584     funcOp.eraseBody();
585 
586   // RAII guard to reset the insertion point to the module's region after
587   // deserializing the body of this function.
588   OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
589 
590   spirv::Opcode opcode = spirv::Opcode::OpNop;
591   ArrayRef<uint32_t> instOperands;
592 
593   // Special handling for the entry block. We need to make sure it starts with
594   // an OpLabel instruction. The entry block takes the same parameters as the
595   // function. All other blocks do not take any parameter. We have already
596   // created the entry block, here we need to register it to the correct label
597   // <id>.
598   if (failed(sliceInstruction(opcode, instOperands,
599                               spirv::Opcode::OpFunctionEnd))) {
600     return failure();
601   }
602   if (opcode == spirv::Opcode::OpFunctionEnd) {
603     return processFunctionEnd(instOperands);
604   }
605   if (opcode != spirv::Opcode::OpLabel) {
606     return emitError(unknownLoc, "a basic block must start with OpLabel");
607   }
608   if (instOperands.size() != 1) {
609     return emitError(unknownLoc, "OpLabel should only have result <id>");
610   }
611   blockMap[instOperands[0]] = entryBlock;
612   if (failed(processLabel(instOperands))) {
613     return failure();
614   }
615 
616   // Then process all the other instructions in the function until we hit
617   // OpFunctionEnd.
618   while (succeeded(sliceInstruction(opcode, instOperands,
619                                     spirv::Opcode::OpFunctionEnd)) &&
620          opcode != spirv::Opcode::OpFunctionEnd) {
621     if (failed(processInstruction(opcode, instOperands))) {
622       return failure();
623     }
624   }
625   if (opcode != spirv::Opcode::OpFunctionEnd) {
626     return failure();
627   }
628 
629   return processFunctionEnd(instOperands);
630 }
631 
632 LogicalResult
633 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
634   // Process OpFunctionEnd.
635   if (!operands.empty()) {
636     return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
637   }
638 
639   // Wire up block arguments from OpPhi instructions.
640   // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
641   // ops.
642   if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
643     return failure();
644   }
645 
646   curBlock = nullptr;
647   curFunction = std::nullopt;
648 
649   LLVM_DEBUG({
650     logger.unindent();
651     logger.startLine()
652         << "//===-------------------------------------------===//\n";
653   });
654   return success();
655 }
656 
657 std::optional<std::pair<Attribute, Type>>
658 spirv::Deserializer::getConstant(uint32_t id) {
659   auto constIt = constantMap.find(id);
660   if (constIt == constantMap.end())
661     return std::nullopt;
662   return constIt->getSecond();
663 }
664 
665 std::optional<spirv::SpecConstOperationMaterializationInfo>
666 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
667   auto constIt = specConstOperationMap.find(id);
668   if (constIt == specConstOperationMap.end())
669     return std::nullopt;
670   return constIt->getSecond();
671 }
672 
673 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
674   auto funcName = nameMap.lookup(id).str();
675   if (funcName.empty()) {
676     funcName = "spirv_fn_" + std::to_string(id);
677   }
678   return funcName;
679 }
680 
681 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
682   auto constName = nameMap.lookup(id).str();
683   if (constName.empty()) {
684     constName = "spirv_spec_const_" + std::to_string(id);
685   }
686   return constName;
687 }
688 
689 spirv::SpecConstantOp
690 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
691                                         TypedAttr defaultValue) {
692   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
693   auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
694                                                     defaultValue);
695   if (decorations.count(resultID)) {
696     for (auto attr : decorations[resultID].getAttrs())
697       op->setAttr(attr.getName(), attr.getValue());
698   }
699   specConstMap[resultID] = op;
700   return op;
701 }
702 
703 LogicalResult
704 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
705   unsigned wordIndex = 0;
706   if (operands.size() < 3) {
707     return emitError(
708         unknownLoc,
709         "OpVariable needs at least 3 operands, type, <id> and storage class");
710   }
711 
712   // Result Type.
713   auto type = getType(operands[wordIndex]);
714   if (!type) {
715     return emitError(unknownLoc, "unknown result type <id> : ")
716            << operands[wordIndex];
717   }
718   auto ptrType = dyn_cast<spirv::PointerType>(type);
719   if (!ptrType) {
720     return emitError(unknownLoc,
721                      "expected a result type <id> to be a spirv.ptr, found : ")
722            << type;
723   }
724   wordIndex++;
725 
726   // Result <id>.
727   auto variableID = operands[wordIndex];
728   auto variableName = nameMap.lookup(variableID).str();
729   if (variableName.empty()) {
730     variableName = "spirv_var_" + std::to_string(variableID);
731   }
732   wordIndex++;
733 
734   // Storage class.
735   auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
736   if (ptrType.getStorageClass() != storageClass) {
737     return emitError(unknownLoc, "mismatch in storage class of pointer type ")
738            << type << " and that specified in OpVariable instruction  : "
739            << stringifyStorageClass(storageClass);
740   }
741   wordIndex++;
742 
743   // Initializer.
744   FlatSymbolRefAttr initializer = nullptr;
745 
746   if (wordIndex < operands.size()) {
747     Operation *op = nullptr;
748 
749     if (auto initOp = getGlobalVariable(operands[wordIndex]))
750       op = initOp;
751     else if (auto initOp = getSpecConstant(operands[wordIndex]))
752       op = initOp;
753     else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
754       op = initOp;
755     else
756       return emitError(unknownLoc, "unknown <id> ")
757              << operands[wordIndex] << "used as initializer";
758 
759     initializer = SymbolRefAttr::get(op);
760     wordIndex++;
761   }
762   if (wordIndex != operands.size()) {
763     return emitError(unknownLoc,
764                      "found more operands than expected when deserializing "
765                      "OpVariable instruction, only ")
766            << wordIndex << " of " << operands.size() << " processed";
767   }
768   auto loc = createFileLineColLoc(opBuilder);
769   auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
770       loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
771       initializer);
772 
773   // Decorations.
774   if (decorations.count(variableID)) {
775     for (auto attr : decorations[variableID].getAttrs())
776       varOp->setAttr(attr.getName(), attr.getValue());
777   }
778   globalVariableMap[variableID] = varOp;
779   return success();
780 }
781 
782 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
783   auto constInfo = getConstant(id);
784   if (!constInfo) {
785     return nullptr;
786   }
787   return dyn_cast<IntegerAttr>(constInfo->first);
788 }
789 
790 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
791   if (operands.size() < 2) {
792     return emitError(unknownLoc, "OpName needs at least 2 operands");
793   }
794   if (!nameMap.lookup(operands[0]).empty()) {
795     return emitError(unknownLoc, "duplicate name found for result <id> ")
796            << operands[0];
797   }
798   unsigned wordIndex = 1;
799   StringRef name = decodeStringLiteral(operands, wordIndex);
800   if (wordIndex != operands.size()) {
801     return emitError(unknownLoc,
802                      "unexpected trailing words in OpName instruction");
803   }
804   nameMap[operands[0]] = name;
805   return success();
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // Type
810 //===----------------------------------------------------------------------===//
811 
812 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
813                                                ArrayRef<uint32_t> operands) {
814   if (operands.empty()) {
815     return emitError(unknownLoc, "type instruction with opcode ")
816            << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
817   }
818 
819   /// TODO: Types might be forward declared in some instructions and need to be
820   /// handled appropriately.
821   if (typeMap.count(operands[0])) {
822     return emitError(unknownLoc, "duplicate definition for result <id> ")
823            << operands[0];
824   }
825 
826   switch (opcode) {
827   case spirv::Opcode::OpTypeVoid:
828     if (operands.size() != 1)
829       return emitError(unknownLoc, "OpTypeVoid must have no parameters");
830     typeMap[operands[0]] = opBuilder.getNoneType();
831     break;
832   case spirv::Opcode::OpTypeBool:
833     if (operands.size() != 1)
834       return emitError(unknownLoc, "OpTypeBool must have no parameters");
835     typeMap[operands[0]] = opBuilder.getI1Type();
836     break;
837   case spirv::Opcode::OpTypeInt: {
838     if (operands.size() != 3)
839       return emitError(
840           unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
841 
842     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
843     // to preserve or validate.
844     // 0 indicates unsigned, or no signedness semantics
845     // 1 indicates signed semantics."
846     //
847     // So we cannot differentiate signless and unsigned integers; always use
848     // signless semantics for such cases.
849     auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
850                                  : IntegerType::SignednessSemantics::Signless;
851     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
852   } break;
853   case spirv::Opcode::OpTypeFloat: {
854     if (operands.size() != 2)
855       return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
856 
857     Type floatTy;
858     switch (operands[1]) {
859     case 16:
860       floatTy = opBuilder.getF16Type();
861       break;
862     case 32:
863       floatTy = opBuilder.getF32Type();
864       break;
865     case 64:
866       floatTy = opBuilder.getF64Type();
867       break;
868     default:
869       return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
870              << operands[1];
871     }
872     typeMap[operands[0]] = floatTy;
873   } break;
874   case spirv::Opcode::OpTypeVector: {
875     if (operands.size() != 3) {
876       return emitError(
877           unknownLoc,
878           "OpTypeVector must have element type and count parameters");
879     }
880     Type elementTy = getType(operands[1]);
881     if (!elementTy) {
882       return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
883              << operands[1];
884     }
885     typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
886   } break;
887   case spirv::Opcode::OpTypePointer: {
888     return processOpTypePointer(operands);
889   } break;
890   case spirv::Opcode::OpTypeArray:
891     return processArrayType(operands);
892   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
893     return processCooperativeMatrixTypeKHR(operands);
894   case spirv::Opcode::OpTypeFunction:
895     return processFunctionType(operands);
896   case spirv::Opcode::OpTypeImage:
897     return processImageType(operands);
898   case spirv::Opcode::OpTypeSampledImage:
899     return processSampledImageType(operands);
900   case spirv::Opcode::OpTypeRuntimeArray:
901     return processRuntimeArrayType(operands);
902   case spirv::Opcode::OpTypeStruct:
903     return processStructType(operands);
904   case spirv::Opcode::OpTypeMatrix:
905     return processMatrixType(operands);
906   default:
907     return emitError(unknownLoc, "unhandled type instruction");
908   }
909   return success();
910 }
911 
/// Handles OpTypePointer: <result id> <storage class> <pointee type id>.
/// Registers the new pointer type, then uses it to resolve any struct members
/// that were deferred because they referenced this <id> before it was defined
/// (see processStructType / OpTypeForwardPointer). A deferred struct whose
/// members are now all resolved gets its body set and is dropped from
/// deferredStructTypesInfos. Fails if setting such a struct body fails.
LogicalResult
spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
  if (operands.size() != 3)
    return emitError(unknownLoc, "OpTypePointer must have two parameters");

  auto pointeeType = getType(operands[2]);
  if (!pointeeType)
    return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
           << operands[2];

  uint32_t typePointerID = operands[0];
  // NOTE(review): the storage-class word is cast without checking that it is
  // a known spirv::StorageClass value.
  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);

  // Both loops erase while iterating: erase() returns the iterator to the next
  // element, so the iterator is advanced manually only when nothing is erased.
  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
       deferredStructIt != std::end(deferredStructTypesInfos);) {
    for (auto *unresolvedMemberIt =
             std::begin(deferredStructIt->unresolvedMemberTypes);
         unresolvedMemberIt !=
         std::end(deferredStructIt->unresolvedMemberTypes);) {
      if (unresolvedMemberIt->first == typePointerID) {
        // The newly constructed pointer type can resolve one of the
        // deferred struct type members; update the memberTypes list and
        // clean the unresolvedMemberTypes list accordingly.
        deferredStructIt->memberTypes[unresolvedMemberIt->second] =
            typeMap[typePointerID];
        unresolvedMemberIt =
            deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
      } else {
        ++unresolvedMemberIt;
      }
    }

    if (deferredStructIt->unresolvedMemberTypes.empty()) {
      // All deferred struct type members are now resolved, set the struct body.
      auto structType = deferredStructIt->deferredStructType;

      assert(structType && "expected a spirv::StructType");
      assert(structType.isIdentified() && "expected an indentified struct");

      if (failed(structType.trySetBody(
              deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
              deferredStructIt->memberDecorationsInfo)))
        return failure();

      deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
    } else {
      ++deferredStructIt;
    }
  }

  return success();
}
965 
966 LogicalResult
967 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
968   if (operands.size() != 3) {
969     return emitError(unknownLoc,
970                      "OpTypeArray must have element type and count parameters");
971   }
972 
973   Type elementTy = getType(operands[1]);
974   if (!elementTy) {
975     return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
976            << operands[1];
977   }
978 
979   unsigned count = 0;
980   // TODO: The count can also come frome a specialization constant.
981   auto countInfo = getConstant(operands[2]);
982   if (!countInfo) {
983     return emitError(unknownLoc, "OpTypeArray count <id> ")
984            << operands[2] << "can only come from normal constant right now";
985   }
986 
987   if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
988     count = intVal.getValue().getZExtValue();
989   } else {
990     return emitError(unknownLoc, "OpTypeArray count must come from a "
991                                  "scalar integer constant instruction");
992   }
993 
994   typeMap[operands[0]] = spirv::ArrayType::get(
995       elementTy, count, typeDecorations.lookup(operands[0]));
996   return success();
997 }
998 
999 LogicalResult
1000 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1001   assert(!operands.empty() && "No operands for processing function type");
1002   if (operands.size() == 1) {
1003     return emitError(unknownLoc, "missing return type for OpTypeFunction");
1004   }
1005   auto returnType = getType(operands[1]);
1006   if (!returnType) {
1007     return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1008   }
1009   SmallVector<Type, 1> argTypes;
1010   for (size_t i = 2, e = operands.size(); i < e; ++i) {
1011     auto ty = getType(operands[i]);
1012     if (!ty) {
1013       return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1014     }
1015     argTypes.push_back(ty);
1016   }
1017   ArrayRef<Type> returnTypes;
1018   if (!isVoidType(returnType)) {
1019     returnTypes = llvm::ArrayRef(returnType);
1020   }
1021   typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1022   return success();
1023 }
1024 
1025 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1026     ArrayRef<uint32_t> operands) {
1027   if (operands.size() != 6) {
1028     return emitError(unknownLoc,
1029                      "OpTypeCooperativeMatrixKHR must have element type, "
1030                      "scope, row and column parameters, and use");
1031   }
1032 
1033   Type elementTy = getType(operands[1]);
1034   if (!elementTy) {
1035     return emitError(unknownLoc,
1036                      "OpTypeCooperativeMatrixKHR references undefined <id> ")
1037            << operands[1];
1038   }
1039 
1040   std::optional<spirv::Scope> scope =
1041       spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1042   if (!scope) {
1043     return emitError(
1044                unknownLoc,
1045                "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1046            << operands[2];
1047   }
1048 
1049   unsigned rows = getConstantInt(operands[3]).getInt();
1050   unsigned columns = getConstantInt(operands[4]).getInt();
1051 
1052   std::optional<spirv::CooperativeMatrixUseKHR> use =
1053       spirv::symbolizeCooperativeMatrixUseKHR(
1054           getConstantInt(operands[5]).getInt());
1055   if (!use) {
1056     return emitError(
1057                unknownLoc,
1058                "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1059            << operands[5];
1060   }
1061 
1062   typeMap[operands[0]] =
1063       spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1064   return success();
1065 }
1066 
1067 LogicalResult
1068 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1069   if (operands.size() != 2) {
1070     return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1071   }
1072   Type memberType = getType(operands[1]);
1073   if (!memberType) {
1074     return emitError(unknownLoc,
1075                      "OpTypeRuntimeArray references undefined <id> ")
1076            << operands[1];
1077   }
1078   typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1079       memberType, typeDecorations.lookup(operands[0]));
1080   return success();
1081 }
1082 
1083 LogicalResult
1084 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1085   // TODO: Find a way to handle identified structs when debug info is stripped.
1086 
1087   if (operands.empty()) {
1088     return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1089   }
1090 
1091   if (operands.size() == 1) {
1092     // Handle empty struct.
1093     typeMap[operands[0]] =
1094         spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1095     return success();
1096   }
1097 
1098   // First element is operand ID, second element is member index in the struct.
1099   SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1100   SmallVector<Type, 4> memberTypes;
1101 
1102   for (auto op : llvm::drop_begin(operands, 1)) {
1103     Type memberType = getType(op);
1104     bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1105 
1106     if (!memberType && !typeForwardPtr)
1107       return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1108              << op;
1109 
1110     if (!memberType)
1111       unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1112 
1113     memberTypes.push_back(memberType);
1114   }
1115 
1116   SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
1117   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
1118   if (memberDecorationMap.count(operands[0])) {
1119     auto &allMemberDecorations = memberDecorationMap[operands[0]];
1120     for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1121       if (allMemberDecorations.count(memberIndex)) {
1122         for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1123           // Check for offset.
1124           if (memberDecoration.first == spirv::Decoration::Offset) {
1125             // If offset info is empty, resize to the number of members;
1126             if (offsetInfo.empty()) {
1127               offsetInfo.resize(memberTypes.size());
1128             }
1129             offsetInfo[memberIndex] = memberDecoration.second[0];
1130           } else {
1131             if (!memberDecoration.second.empty()) {
1132               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1133                                                  memberDecoration.first,
1134                                                  memberDecoration.second[0]);
1135             } else {
1136               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1137                                                  memberDecoration.first, 0);
1138             }
1139           }
1140         }
1141       }
1142     }
1143   }
1144 
1145   uint32_t structID = operands[0];
1146   std::string structIdentifier = nameMap.lookup(structID).str();
1147 
1148   if (structIdentifier.empty()) {
1149     assert(unresolvedMemberTypes.empty() &&
1150            "didn't expect unresolved member types");
1151     typeMap[structID] =
1152         spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1153   } else {
1154     auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1155     typeMap[structID] = structTy;
1156 
1157     if (!unresolvedMemberTypes.empty())
1158       deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1159                                           memberTypes, offsetInfo,
1160                                           memberDecorationsInfo});
1161     else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1162                                         memberDecorationsInfo)))
1163       return failure();
1164   }
1165 
1166   // TODO: Update StructType to have member name as attribute as
1167   // well.
1168   return success();
1169 }
1170 
1171 LogicalResult
1172 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1173   if (operands.size() != 3) {
1174     // Three operands are needed: result_id, column_type, and column_count
1175     return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1176                                  " (result_id, column_type, and column_count)");
1177   }
1178   // Matrix columns must be of vector type
1179   Type elementTy = getType(operands[1]);
1180   if (!elementTy) {
1181     return emitError(unknownLoc,
1182                      "OpTypeMatrix references undefined column type.")
1183            << operands[1];
1184   }
1185 
1186   uint32_t colsCount = operands[2];
1187   typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1188   return success();
1189 }
1190 
1191 LogicalResult
1192 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1193   if (operands.size() != 2)
1194     return emitError(unknownLoc,
1195                      "OpTypeForwardPointer instruction must have two operands");
1196 
1197   typeForwardPointerIDs.insert(operands[0]);
1198   // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1199   // instruction that defines the actual type.
1200 
1201   return success();
1202 }
1203 
1204 LogicalResult
1205 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1206   // TODO: Add support for Access Qualifier.
1207   if (operands.size() != 8)
1208     return emitError(
1209         unknownLoc,
1210         "OpTypeImage with non-eight operands are not supported yet");
1211 
1212   Type elementTy = getType(operands[1]);
1213   if (!elementTy)
1214     return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1215            << operands[1];
1216 
1217   auto dim = spirv::symbolizeDim(operands[2]);
1218   if (!dim)
1219     return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1220            << operands[2];
1221 
1222   auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1223   if (!depthInfo)
1224     return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1225            << operands[3];
1226 
1227   auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1228   if (!arrayedInfo)
1229     return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1230            << operands[4];
1231 
1232   auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1233   if (!samplingInfo)
1234     return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1235 
1236   auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1237   if (!samplerUseInfo)
1238     return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1239            << operands[6];
1240 
1241   auto format = spirv::symbolizeImageFormat(operands[7]);
1242   if (!format)
1243     return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1244            << operands[7];
1245 
1246   typeMap[operands[0]] = spirv::ImageType::get(
1247       elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1248       samplingInfo.value(), samplerUseInfo.value(), format.value());
1249   return success();
1250 }
1251 
1252 LogicalResult
1253 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1254   if (operands.size() != 2)
1255     return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1256 
1257   Type elementTy = getType(operands[1]);
1258   if (!elementTy)
1259     return emitError(unknownLoc,
1260                      "OpTypeSampledImage references undefined <id>: ")
1261            << operands[1];
1262 
1263   typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1264   return success();
1265 }
1266 
1267 //===----------------------------------------------------------------------===//
1268 // Constant
1269 //===----------------------------------------------------------------------===//
1270 
1271 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1272                                                    bool isSpec) {
1273   StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1274 
1275   if (operands.size() < 2) {
1276     return emitError(unknownLoc)
1277            << opname << " must have type <id> and result <id>";
1278   }
1279   if (operands.size() < 3) {
1280     return emitError(unknownLoc)
1281            << opname << " must have at least 1 more parameter";
1282   }
1283 
1284   Type resultType = getType(operands[0]);
1285   if (!resultType) {
1286     return emitError(unknownLoc, "undefined result type from <id> ")
1287            << operands[0];
1288   }
1289 
1290   auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1291     if (bitwidth == 64) {
1292       if (operands.size() == 4) {
1293         return success();
1294       }
1295       return emitError(unknownLoc)
1296              << opname << " should have 2 parameters for 64-bit values";
1297     }
1298     if (bitwidth <= 32) {
1299       if (operands.size() == 3) {
1300         return success();
1301       }
1302 
1303       return emitError(unknownLoc)
1304              << opname
1305              << " should have 1 parameter for values with no more than 32 bits";
1306     }
1307     return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1308            << bitwidth;
1309   };
1310 
1311   auto resultID = operands[1];
1312 
1313   if (auto intType = dyn_cast<IntegerType>(resultType)) {
1314     auto bitwidth = intType.getWidth();
1315     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1316       return failure();
1317     }
1318 
1319     APInt value;
1320     if (bitwidth == 64) {
1321       // 64-bit integers are represented with two SPIR-V words. According to
1322       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1323       // literal’s low-order words appear first."
1324       struct DoubleWord {
1325         uint32_t word1;
1326         uint32_t word2;
1327       } words = {operands[2], operands[3]};
1328       value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1329     } else if (bitwidth <= 32) {
1330       value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1331                     /*implicitTrunc=*/true);
1332     }
1333 
1334     auto attr = opBuilder.getIntegerAttr(intType, value);
1335 
1336     if (isSpec) {
1337       createSpecConstant(unknownLoc, resultID, attr);
1338     } else {
1339       // For normal constants, we just record the attribute (and its type) for
1340       // later materialization at use sites.
1341       constantMap.try_emplace(resultID, attr, intType);
1342     }
1343 
1344     return success();
1345   }
1346 
1347   if (auto floatType = dyn_cast<FloatType>(resultType)) {
1348     auto bitwidth = floatType.getWidth();
1349     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1350       return failure();
1351     }
1352 
1353     APFloat value(0.f);
1354     if (floatType.isF64()) {
1355       // Double values are represented with two SPIR-V words. According to
1356       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1357       // literal’s low-order words appear first."
1358       struct DoubleWord {
1359         uint32_t word1;
1360         uint32_t word2;
1361       } words = {operands[2], operands[3]};
1362       value = APFloat(llvm::bit_cast<double>(words));
1363     } else if (floatType.isF32()) {
1364       value = APFloat(llvm::bit_cast<float>(operands[2]));
1365     } else if (floatType.isF16()) {
1366       APInt data(16, operands[2]);
1367       value = APFloat(APFloat::IEEEhalf(), data);
1368     }
1369 
1370     auto attr = opBuilder.getFloatAttr(floatType, value);
1371     if (isSpec) {
1372       createSpecConstant(unknownLoc, resultID, attr);
1373     } else {
1374       // For normal constants, we just record the attribute (and its type) for
1375       // later materialization at use sites.
1376       constantMap.try_emplace(resultID, attr, floatType);
1377     }
1378 
1379     return success();
1380   }
1381 
1382   return emitError(unknownLoc, "OpConstant can only generate values of "
1383                                "scalar integer or floating-point type");
1384 }
1385 
1386 LogicalResult spirv::Deserializer::processConstantBool(
1387     bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1388   if (operands.size() != 2) {
1389     return emitError(unknownLoc, "Op")
1390            << (isSpec ? "Spec" : "") << "Constant"
1391            << (isTrue ? "True" : "False")
1392            << " must have type <id> and result <id>";
1393   }
1394 
1395   auto attr = opBuilder.getBoolAttr(isTrue);
1396   auto resultID = operands[1];
1397   if (isSpec) {
1398     createSpecConstant(unknownLoc, resultID, attr);
1399   } else {
1400     // For normal constants, we just record the attribute (and its type) for
1401     // later materialization at use sites.
1402     constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1403   }
1404 
1405   return success();
1406 }
1407 
1408 LogicalResult
1409 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1410   if (operands.size() < 2) {
1411     return emitError(unknownLoc,
1412                      "OpConstantComposite must have type <id> and result <id>");
1413   }
1414   if (operands.size() < 3) {
1415     return emitError(unknownLoc,
1416                      "OpConstantComposite must have at least 1 parameter");
1417   }
1418 
1419   Type resultType = getType(operands[0]);
1420   if (!resultType) {
1421     return emitError(unknownLoc, "undefined result type from <id> ")
1422            << operands[0];
1423   }
1424 
1425   SmallVector<Attribute, 4> elements;
1426   elements.reserve(operands.size() - 2);
1427   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1428     auto elementInfo = getConstant(operands[i]);
1429     if (!elementInfo) {
1430       return emitError(unknownLoc, "OpConstantComposite component <id> ")
1431              << operands[i] << " must come from a normal constant";
1432     }
1433     elements.push_back(elementInfo->first);
1434   }
1435 
1436   auto resultID = operands[1];
1437   if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1438     auto attr = DenseElementsAttr::get(vectorType, elements);
1439     // For normal constants, we just record the attribute (and its type) for
1440     // later materialization at use sites.
1441     constantMap.try_emplace(resultID, attr, resultType);
1442   } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1443     auto attr = opBuilder.getArrayAttr(elements);
1444     constantMap.try_emplace(resultID, attr, resultType);
1445   } else {
1446     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1447            << resultType;
1448   }
1449 
1450   return success();
1451 }
1452 
1453 LogicalResult
1454 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1455   if (operands.size() < 2) {
1456     return emitError(unknownLoc,
1457                      "OpConstantComposite must have type <id> and result <id>");
1458   }
1459   if (operands.size() < 3) {
1460     return emitError(unknownLoc,
1461                      "OpConstantComposite must have at least 1 parameter");
1462   }
1463 
1464   Type resultType = getType(operands[0]);
1465   if (!resultType) {
1466     return emitError(unknownLoc, "undefined result type from <id> ")
1467            << operands[0];
1468   }
1469 
1470   auto resultID = operands[1];
1471   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1472 
1473   SmallVector<Attribute, 4> elements;
1474   elements.reserve(operands.size() - 2);
1475   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1476     auto elementInfo = getSpecConstant(operands[i]);
1477     elements.push_back(SymbolRefAttr::get(elementInfo));
1478   }
1479 
1480   auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1481       unknownLoc, TypeAttr::get(resultType), symName,
1482       opBuilder.getArrayAttr(elements));
1483   specConstCompositeMap[resultID] = op;
1484 
1485   return success();
1486 }
1487 
1488 LogicalResult
1489 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1490   if (operands.size() < 3)
1491     return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1492                                  "result <id>, and operand opcode");
1493 
1494   uint32_t resultTypeID = operands[0];
1495 
1496   if (!getType(resultTypeID))
1497     return emitError(unknownLoc, "undefined result type from <id> ")
1498            << resultTypeID;
1499 
1500   uint32_t resultID = operands[1];
1501   spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1502   auto emplaceResult = specConstOperationMap.try_emplace(
1503       resultID,
1504       SpecConstOperationMaterializationInfo{
1505           enclosedOpcode, resultTypeID,
1506           SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1507 
1508   if (!emplaceResult.second)
1509     return emitError(unknownLoc, "value with <id>: ")
1510            << resultID << " is probably defined before.";
1511 
1512   return success();
1513 }
1514 
/// Materializes a recorded OpSpecConstantOp as a spirv::SpecConstantOperationOp
/// whose single-block body holds the enclosed instruction followed by a
/// spirv::YieldOp of that instruction's first result. Returns the new op's
/// result, or a null Value if deserializing the enclosed instruction fails.
Value spirv::Deserializer::materializeSpecConstantOperation(
    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
    ArrayRef<uint32_t> enclosedOpOperands) {

  Type resultType = getType(resultTypeID);

  // Instructions wrapped by OpSpecConstantOp need an ID for their
  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
  // dialect wrapped op. For that purpose, a new value map is created and "fake"
  // ID in that map is assigned to the result of the enclosed instruction. Note
  // that there is no need to update this fake ID since we only need to
  // reference the created Value for the enclosed op from the spirv::YieldOp
  // created later in this method (both of which are the only values in their
  // region: the SpecConstantOperation's region). If we encounter another
  // SpecConstantOperation in the module, we simply re-use the fake ID since the
  // previous Value assigned to it isn't visible in the current scope anyway.
  // valueMapGuard restores the original valueMap when this function returns.
  DenseMap<uint32_t, Value> newValueMap;
  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
  // NOTE(review): presumably chosen to be an <id> unlikely to collide with a
  // real one in the module; confirm.
  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

  // Rebuild a regular instruction operand list: result type, result <id>,
  // then the enclosed instruction's own operands.
  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
  enclosedOpResultTypeAndOperands.push_back(fakeID);
  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
                                         enclosedOpOperands.end());

  // Process enclosed instruction before creating the enclosing
  // specConstantOperation (and its region). This way, references to constants,
  // global variables, and spec constants will be materialized outside the new
  // op's region. For more info, see Deserializer::getValue's implementation.
  if (failed(
          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
    return Value();

  // Since the enclosed op is emitted in the current block, split it in a
  // separate new block.
  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());

  auto loc = createFileLineColLoc(opBuilder);
  auto specConstOperationOp =
      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);

  Region &body = specConstOperationOp.getBody();
  // Move the new block into SpecConstantOperation's body.
  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
                          Region::iterator(enclosedBlock));
  Block &block = body.back();

  // RAII guard to reset the insertion point to the module's region after
  // deserializing the body of the specConstantOperation.
  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
  opBuilder.setInsertionPointToEnd(&block);

  // NOTE(review): assumes the enclosed op produced at least one result;
  // getResult(0) is not guarded here.
  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
  return specConstOperationOp.getResult();
}
1571 
1572 LogicalResult
1573 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1574   if (operands.size() != 2) {
1575     return emitError(unknownLoc,
1576                      "OpConstantNull must have type <id> and result <id>");
1577   }
1578 
1579   Type resultType = getType(operands[0]);
1580   if (!resultType) {
1581     return emitError(unknownLoc, "undefined result type from <id> ")
1582            << operands[0];
1583   }
1584 
1585   auto resultID = operands[1];
1586   if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1587     auto attr = opBuilder.getZeroAttr(resultType);
1588     // For normal constants, we just record the attribute (and its type) for
1589     // later materialization at use sites.
1590     constantMap.try_emplace(resultID, attr, resultType);
1591     return success();
1592   }
1593 
1594   return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1595          << resultType;
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // Control flow
1600 //===----------------------------------------------------------------------===//
1601 
1602 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1603   if (auto *block = getBlock(id)) {
1604     LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1605                                   << " @ " << block << "\n");
1606     return block;
1607   }
1608 
1609   // We don't know where this block will be placed finally (in a
1610   // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1611   // function for now and sort out the proper place later.
1612   auto *block = curFunction->addBlock();
1613   LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1614                                 << " @ " << block << "\n");
1615   return blockMap[id] = block;
1616 }
1617 
1618 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1619   if (!curBlock) {
1620     return emitError(unknownLoc, "OpBranch must appear inside a block");
1621   }
1622 
1623   if (operands.size() != 1) {
1624     return emitError(unknownLoc, "OpBranch must take exactly one target label");
1625   }
1626 
1627   auto *target = getOrCreateBlock(operands[0]);
1628   auto loc = createFileLineColLoc(opBuilder);
1629   // The preceding instruction for the OpBranch instruction could be an
1630   // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1631   // the same OpLine information.
1632   opBuilder.create<spirv::BranchOp>(loc, target);
1633 
1634   clearDebugLine();
1635   return success();
1636 }
1637 
1638 LogicalResult
1639 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1640   if (!curBlock) {
1641     return emitError(unknownLoc,
1642                      "OpBranchConditional must appear inside a block");
1643   }
1644 
1645   if (operands.size() != 3 && operands.size() != 5) {
1646     return emitError(unknownLoc,
1647                      "OpBranchConditional must have condition, true label, "
1648                      "false label, and optionally two branch weights");
1649   }
1650 
1651   auto condition = getValue(operands[0]);
1652   auto *trueBlock = getOrCreateBlock(operands[1]);
1653   auto *falseBlock = getOrCreateBlock(operands[2]);
1654 
1655   std::optional<std::pair<uint32_t, uint32_t>> weights;
1656   if (operands.size() == 5) {
1657     weights = std::make_pair(operands[3], operands[4]);
1658   }
1659   // The preceding instruction for the OpBranchConditional instruction could be
1660   // an OpSelectionMerge instruction, in this case they will have the same
1661   // OpLine information.
1662   auto loc = createFileLineColLoc(opBuilder);
1663   opBuilder.create<spirv::BranchConditionalOp>(
1664       loc, condition, trueBlock,
1665       /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1666       /*falseArguments=*/ArrayRef<Value>(), weights);
1667 
1668   clearDebugLine();
1669   return success();
1670 }
1671 
1672 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1673   if (!curFunction) {
1674     return emitError(unknownLoc, "OpLabel must appear inside a function");
1675   }
1676 
1677   if (operands.size() != 1) {
1678     return emitError(unknownLoc, "OpLabel should only have result <id>");
1679   }
1680 
1681   auto labelID = operands[0];
1682   // We may have forward declared this block.
1683   auto *block = getOrCreateBlock(labelID);
1684   LLVM_DEBUG(logger.startLine()
1685              << "[block] populating block " << block << "\n");
1686   // If we have seen this block, make sure it was just a forward declaration.
1687   assert(block->empty() && "re-deserialize the same block!");
1688 
1689   opBuilder.setInsertionPointToStart(block);
1690   blockMap[labelID] = curBlock = block;
1691 
1692   return success();
1693 }
1694 
1695 LogicalResult
1696 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1697   if (!curBlock) {
1698     return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1699   }
1700 
1701   if (operands.size() < 2) {
1702     return emitError(
1703         unknownLoc,
1704         "OpSelectionMerge must specify merge target and selection control");
1705   }
1706 
1707   auto *mergeBlock = getOrCreateBlock(operands[0]);
1708   auto loc = createFileLineColLoc(opBuilder);
1709   auto selectionControl = operands[1];
1710 
1711   if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1712            .second) {
1713     return emitError(
1714         unknownLoc,
1715         "a block cannot have more than one OpSelectionMerge instruction");
1716   }
1717 
1718   return success();
1719 }
1720 
1721 LogicalResult
1722 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1723   if (!curBlock) {
1724     return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1725   }
1726 
1727   if (operands.size() < 3) {
1728     return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1729                                  "continue target and loop control");
1730   }
1731 
1732   auto *mergeBlock = getOrCreateBlock(operands[0]);
1733   auto *continueBlock = getOrCreateBlock(operands[1]);
1734   auto loc = createFileLineColLoc(opBuilder);
1735   uint32_t loopControl = operands[2];
1736 
1737   if (!blockMergeInfo
1738            .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1739            .second) {
1740     return emitError(
1741         unknownLoc,
1742         "a block cannot have more than one OpLoopMerge instruction");
1743   }
1744 
1745   return success();
1746 }
1747 
1748 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1749   if (!curBlock) {
1750     return emitError(unknownLoc, "OpPhi must appear in a block");
1751   }
1752 
1753   if (operands.size() < 4) {
1754     return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1755                                  "and variable-parent pairs");
1756   }
1757 
1758   // Create a block argument for this OpPhi instruction.
1759   Type blockArgType = getType(operands[0]);
1760   BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1761   valueMap[operands[1]] = blockArg;
1762   LLVM_DEBUG(logger.startLine()
1763              << "[phi] created block argument " << blockArg
1764              << " id = " << operands[1] << " of type " << blockArgType << "\n");
1765 
1766   // For each (value, predecessor) pair, insert the value to the predecessor's
1767   // blockPhiInfo entry so later we can fix the block argument there.
1768   for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1769     uint32_t value = operands[i];
1770     Block *predecessor = getOrCreateBlock(operands[i + 1]);
1771     std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1772     blockPhiInfo[predecessorTargetPair].push_back(value);
1773     LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1774                                   << " with arg id = " << value << "\n");
1775   }
1776 
1777   return success();
1778 }
1779 
namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spirv.mlir.selection/spirv.mlir.loop op.
///
/// One instance handles one construct, described by its header, merge and
/// (for loops only) continue block. `blockMergeInfo` is held by reference
/// because structurize() re-keys entries of nested constructs to the cloned
/// blocks it creates.
class ControlFlowStructurizer {
public:
  // The logger parameter only exists in builds with assertions enabled, since
  // it is used solely by LLVM_DEBUG output.
#ifndef NDEBUG
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont,
                          llvm::ScopedPrinter &logger)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont),
        logger(logger) {}
#else
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
#endif

  /// Structurizes the construct at the given `headerBlock`.
  ///
  /// This method creates a spirv.mlir.loop op (when `continueBlock` is
  /// non-null) or a spirv.mlir.selection op at the start of `mergeBlock`, and
  /// moves all blocks in the construct into that op's region. All branches to
  /// the `headerBlock` are redirected to the `mergeBlock`. It also updates
  /// `blockMergeInfo` by remapping nested constructs' blocks to the newly
  /// cloned ones inside the structured control flow op's region.
  LogicalResult structurize();

private:
  /// Creates a new spirv.mlir.selection op at the beginning of the
  /// `mergeBlock`.
  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);

  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
  spirv::LoopOp createLoopOp(uint32_t loopControl);

  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
  void collectBlocksInConstruct();

  /// Location given to the created selection/loop op.
  Location location;
  /// Raw selection-control or loop-control word; it is cast to the enum that
  /// matches the construct kind when the op is created.
  uint32_t control;

  /// Shared map of constructs still waiting to be structurized.
  spirv::BlockMergeInfoMap &blockMergeInfo;

  Block *headerBlock;
  Block *mergeBlock;
  Block *continueBlock; // nullptr for spirv.mlir.selection

  /// Blocks belonging to the construct, in discovery order from the header.
  SetVector<Block *> constructBlocks;

#ifndef NDEBUG
  /// A logger used to emit information during the deserialization process.
  llvm::ScopedPrinter &logger;
#endif
};
} // namespace
1838 
1839 spirv::SelectionOp
1840 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1841   // Create a builder and set the insertion point to the beginning of the
1842   // merge block so that the newly created SelectionOp will be inserted there.
1843   OpBuilder builder(&mergeBlock->front());
1844 
1845   auto control = static_cast<spirv::SelectionControl>(selectionControl);
1846   auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1847   selectionOp.addMergeBlock(builder);
1848 
1849   return selectionOp;
1850 }
1851 
1852 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1853   // Create a builder and set the insertion point to the beginning of the
1854   // merge block so that the newly created LoopOp will be inserted there.
1855   OpBuilder builder(&mergeBlock->front());
1856 
1857   auto control = static_cast<spirv::LoopControl>(loopControl);
1858   auto loopOp = builder.create<spirv::LoopOp>(location, control);
1859   loopOp.addEntryAndMergeBlock(builder);
1860 
1861   return loopOp;
1862 }
1863 
1864 void ControlFlowStructurizer::collectBlocksInConstruct() {
1865   assert(constructBlocks.empty() && "expected empty constructBlocks");
1866 
1867   // Put the header block in the work list first.
1868   constructBlocks.insert(headerBlock);
1869 
1870   // For each item in the work list, add its successors excluding the merge
1871   // block.
1872   for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1873     for (auto *successor : constructBlocks[i]->getSuccessors())
1874       if (successor != mergeBlock)
1875         constructBlocks.insert(successor);
1876   }
1877 }
1878 
1879 LogicalResult ControlFlowStructurizer::structurize() {
1880   Operation *op = nullptr;
1881   bool isLoop = continueBlock != nullptr;
1882   if (isLoop) {
1883     if (auto loopOp = createLoopOp(control))
1884       op = loopOp.getOperation();
1885   } else {
1886     if (auto selectionOp = createSelectionOp(control))
1887       op = selectionOp.getOperation();
1888   }
1889   if (!op)
1890     return failure();
1891   Region &body = op->getRegion(0);
1892 
1893   IRMapping mapper;
1894   // All references to the old merge block should be directed to the
1895   // selection/loop merge block in the SelectionOp/LoopOp's region.
1896   mapper.map(mergeBlock, &body.back());
1897 
1898   collectBlocksInConstruct();
1899 
1900   // We've identified all blocks belonging to the selection/loop's region. Now
1901   // need to "move" them into the selection/loop. Instead of really moving the
1902   // blocks, in the following we copy them and remap all values and branches.
1903   // This is because:
1904   // * Inserting a block into a region requires the block not in any region
1905   //   before. But selections/loops can nest so we can create selection/loop ops
1906   //   in a nested manner, which means some blocks may already be in a
1907   //   selection/loop region when to be moved again.
1908   // * It's much trickier to fix up the branches into and out of the loop's
1909   //   region: we need to treat not-moved blocks and moved blocks differently:
1910   //   Not-moved blocks jumping to the loop header block need to jump to the
1911   //   merge point containing the new loop op but not the loop continue block's
1912   //   back edge. Moved blocks jumping out of the loop need to jump to the
1913   //   merge block inside the loop region but not other not-moved blocks.
1914   //   We cannot use replaceAllUsesWith clearly and it's harder to follow the
1915   //   logic.
1916 
1917   // Create a corresponding block in the SelectionOp/LoopOp's region for each
1918   // block in this loop construct.
1919   OpBuilder builder(body);
1920   for (auto *block : constructBlocks) {
1921     // Create a block and insert it before the selection/loop merge block in the
1922     // SelectionOp/LoopOp's region.
1923     auto *newBlock = builder.createBlock(&body.back());
1924     mapper.map(block, newBlock);
1925     LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1926                                   << " from block " << block << "\n");
1927     if (!isFnEntryBlock(block)) {
1928       for (BlockArgument blockArg : block->getArguments()) {
1929         auto newArg =
1930             newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1931         mapper.map(blockArg, newArg);
1932         LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1933                                       << blockArg << " to " << newArg << "\n");
1934       }
1935     } else {
1936       LLVM_DEBUG(logger.startLine()
1937                  << "[cf] block " << block << " is a function entry block\n");
1938     }
1939 
1940     for (auto &op : *block)
1941       newBlock->push_back(op.clone(mapper));
1942   }
1943 
1944   // Go through all ops and remap the operands.
1945   auto remapOperands = [&](Operation *op) {
1946     for (auto &operand : op->getOpOperands())
1947       if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1948         operand.set(mappedOp);
1949     for (auto &succOp : op->getBlockOperands())
1950       if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1951         succOp.set(mappedOp);
1952   };
1953   for (auto &block : body)
1954     block.walk(remapOperands);
1955 
1956   // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1957   // the selection/loop construct into its region. Next we need to fix the
1958   // connections between this new SelectionOp/LoopOp with existing blocks.
1959 
1960   // All existing incoming branches should go to the merge block, where the
1961   // SelectionOp/LoopOp resides right now.
1962   headerBlock->replaceAllUsesWith(mergeBlock);
1963 
1964   LLVM_DEBUG({
1965     logger.startLine() << "[cf] after cloning and fixing references:\n";
1966     headerBlock->getParentOp()->print(logger.getOStream());
1967     logger.startLine() << "\n";
1968   });
1969 
1970   if (isLoop) {
1971     if (!mergeBlock->args_empty()) {
1972       return mergeBlock->getParentOp()->emitError(
1973           "OpPhi in loop merge block unsupported");
1974     }
1975 
1976     // The loop header block may have block arguments. Since now we place the
1977     // loop op inside the old merge block, we need to make sure the old merge
1978     // block has the same block argument list.
1979     for (BlockArgument blockArg : headerBlock->getArguments())
1980       mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1981 
1982     // If the loop header block has block arguments, make sure the spirv.Branch
1983     // op matches.
1984     SmallVector<Value, 4> blockArgs;
1985     if (!headerBlock->args_empty())
1986       blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1987 
1988     // The loop entry block should have a unconditional branch jumping to the
1989     // loop header block.
1990     builder.setInsertionPointToEnd(&body.front());
1991     builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1992                                     ArrayRef<Value>(blockArgs));
1993   }
1994 
1995   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1996   // cleaned up.
1997   LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1998   // First we need to drop all operands' references inside all blocks. This is
1999   // needed because we can have blocks referencing SSA values from one another.
2000   for (auto *block : constructBlocks)
2001     block->dropAllReferences();
2002 
2003   // Check that whether some op in the to-be-erased blocks still has uses. Those
2004   // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2005   // region. We cannot handle such cases given that once a value is sinked into
2006   // the SelectionOp/LoopOp's region, there is no escape for it:
2007   // SelectionOp/LooOp does not support yield values right now.
2008   for (auto *block : constructBlocks) {
2009     for (Operation &op : *block)
2010       if (!op.use_empty())
2011         return op.emitOpError(
2012             "failed control flow structurization: it has uses outside of the "
2013             "enclosing selection/loop construct");
2014   }
2015 
2016   // Then erase all old blocks.
2017   for (auto *block : constructBlocks) {
2018     // We've cloned all blocks belonging to this construct into the structured
2019     // control flow op's region. Among these blocks, some may compose another
2020     // selection/loop. If so, they will be recorded within blockMergeInfo.
2021     // We need to update the pointers there to the newly remapped ones so we can
2022     // continue structurizing them later.
2023     // TODO: The asserts in the following assumes input SPIR-V blob forms
2024     // correctly nested selection/loop constructs. We should relax this and
2025     // support error cases better.
2026     auto it = blockMergeInfo.find(block);
2027     if (it != blockMergeInfo.end()) {
2028       // Use the original location for nested selection/loop ops.
2029       Location loc = it->second.loc;
2030 
2031       Block *newHeader = mapper.lookupOrNull(block);
2032       if (!newHeader)
2033         return emitError(loc, "failed control flow structurization: nested "
2034                               "loop header block should be remapped!");
2035 
2036       Block *newContinue = it->second.continueBlock;
2037       if (newContinue) {
2038         newContinue = mapper.lookupOrNull(newContinue);
2039         if (!newContinue)
2040           return emitError(loc, "failed control flow structurization: nested "
2041                                 "loop continue block should be remapped!");
2042       }
2043 
2044       Block *newMerge = it->second.mergeBlock;
2045       if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2046         newMerge = mappedTo;
2047 
2048       // The iterator should be erased before adding a new entry into
2049       // blockMergeInfo to avoid iterator invalidation.
2050       blockMergeInfo.erase(it);
2051       blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2052                                  newContinue);
2053     }
2054 
2055     // The structured selection/loop's entry block does not have arguments.
2056     // If the function's header block is also part of the structured control
2057     // flow, we cannot just simply erase it because it may contain arguments
2058     // matching the function signature and used by the cloned blocks.
2059     if (isFnEntryBlock(block)) {
2060       LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2061                                     << " to only contain a spirv.Branch op\n");
2062       // Still keep the function entry block for the potential block arguments,
2063       // but replace all ops inside with a branch to the merge block.
2064       block->clear();
2065       builder.setInsertionPointToEnd(block);
2066       builder.create<spirv::BranchOp>(location, mergeBlock);
2067     } else {
2068       LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2069       block->erase();
2070     }
2071   }
2072 
2073   LLVM_DEBUG(logger.startLine()
2074              << "[cf] after structurizing construct with header block "
2075              << headerBlock << ":\n"
2076              << *op << "\n");
2077 
2078   return success();
2079 }
2080 
2081 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2082   LLVM_DEBUG({
2083     logger.startLine()
2084         << "//----- [phi] start wiring up block arguments -----//\n";
2085     logger.indent();
2086   });
2087 
2088   OpBuilder::InsertionGuard guard(opBuilder);
2089 
2090   for (const auto &info : blockPhiInfo) {
2091     Block *block = info.first.first;
2092     Block *target = info.first.second;
2093     const BlockPhiInfo &phiInfo = info.second;
2094     LLVM_DEBUG({
2095       logger.startLine() << "[phi] block " << block << "\n";
2096       logger.startLine() << "[phi] before creating block argument:\n";
2097       block->getParentOp()->print(logger.getOStream());
2098       logger.startLine() << "\n";
2099     });
2100 
2101     // Set insertion point to before this block's terminator early because we
2102     // may materialize ops via getValue() call.
2103     auto *op = block->getTerminator();
2104     opBuilder.setInsertionPoint(op);
2105 
2106     SmallVector<Value, 4> blockArgs;
2107     blockArgs.reserve(phiInfo.size());
2108     for (uint32_t valueId : phiInfo) {
2109       if (Value value = getValue(valueId)) {
2110         blockArgs.push_back(value);
2111         LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2112                                       << " id = " << valueId << "\n");
2113       } else {
2114         return emitError(unknownLoc, "OpPhi references undefined value!");
2115       }
2116     }
2117 
2118     if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2119       // Replace the previous branch op with a new one with block arguments.
2120       opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2121                                         blockArgs);
2122       branchOp.erase();
2123     } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2124       assert((branchCondOp.getTrueBlock() == target ||
2125               branchCondOp.getFalseBlock() == target) &&
2126              "expected target to be either the true or false target");
2127       if (target == branchCondOp.getTrueTarget())
2128         opBuilder.create<spirv::BranchConditionalOp>(
2129             branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2130             branchCondOp.getFalseBlockArguments(),
2131             branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2132             branchCondOp.getFalseTarget());
2133       else
2134         opBuilder.create<spirv::BranchConditionalOp>(
2135             branchCondOp.getLoc(), branchCondOp.getCondition(),
2136             branchCondOp.getTrueBlockArguments(), blockArgs,
2137             branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2138             branchCondOp.getFalseBlock());
2139 
2140       branchCondOp.erase();
2141     } else {
2142       return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2143     }
2144 
2145     LLVM_DEBUG({
2146       logger.startLine() << "[phi] after creating block argument:\n";
2147       block->getParentOp()->print(logger.getOStream());
2148       logger.startLine() << "\n";
2149     });
2150   }
2151   blockPhiInfo.clear();
2152 
2153   LLVM_DEBUG({
2154     logger.unindent();
2155     logger.startLine()
2156         << "//--- [phi] completed wiring up block arguments ---//\n";
2157   });
2158   return success();
2159 }
2160 
2161 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2162   LLVM_DEBUG({
2163     logger.startLine()
2164         << "//----- [cf] start structurizing control flow -----//\n";
2165     logger.indent();
2166   });
2167 
2168   while (!blockMergeInfo.empty()) {
2169     Block *headerBlock = blockMergeInfo.begin()->first;
2170     BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2171 
2172     LLVM_DEBUG({
2173       logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2174       headerBlock->print(logger.getOStream());
2175       logger.startLine() << "\n";
2176     });
2177 
2178     auto *mergeBlock = mergeInfo.mergeBlock;
2179     assert(mergeBlock && "merge block cannot be nullptr");
2180     if (!mergeBlock->args_empty())
2181       return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2182     LLVM_DEBUG({
2183       logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2184       mergeBlock->print(logger.getOStream());
2185       logger.startLine() << "\n";
2186     });
2187 
2188     auto *continueBlock = mergeInfo.continueBlock;
2189     LLVM_DEBUG(if (continueBlock) {
2190       logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2191       continueBlock->print(logger.getOStream());
2192       logger.startLine() << "\n";
2193     });
2194     // Erase this case before calling into structurizer, who will update
2195     // blockMergeInfo.
2196     blockMergeInfo.erase(blockMergeInfo.begin());
2197     ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2198                                          blockMergeInfo, headerBlock,
2199                                          mergeBlock, continueBlock
2200 #ifndef NDEBUG
2201                                          ,
2202                                          logger
2203 #endif
2204     );
2205     if (failed(structurizer.structurize()))
2206       return failure();
2207   }
2208 
2209   LLVM_DEBUG({
2210     logger.unindent();
2211     logger.startLine()
2212         << "//--- [cf] completed structurizing control flow ---//\n";
2213   });
2214   return success();
2215 }
2216 
2217 //===----------------------------------------------------------------------===//
2218 // Debug
2219 //===----------------------------------------------------------------------===//
2220 
2221 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2222   if (!debugLine)
2223     return unknownLoc;
2224 
2225   auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2226   if (fileName.empty())
2227     fileName = "<unknown>";
2228   return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2229                              debugLine->column);
2230 }
2231 
2232 LogicalResult
2233 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2234   // According to SPIR-V spec:
2235   // "This location information applies to the instructions physically
2236   // following this instruction, up to the first occurrence of any of the
2237   // following: the next end of block, the next OpLine instruction, or the next
2238   // OpNoLine instruction."
2239   if (operands.size() != 3)
2240     return emitError(unknownLoc, "OpLine must have 3 operands");
2241   debugLine = DebugLine{operands[0], operands[1], operands[2]};
2242   return success();
2243 }
2244 
2245 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2246 
2247 LogicalResult
2248 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2249   if (operands.size() < 2)
2250     return emitError(unknownLoc, "OpString needs at least 2 operands");
2251 
2252   if (!debugInfoMap.lookup(operands[0]).empty())
2253     return emitError(unknownLoc,
2254                      "duplicate debug string found for result <id> ")
2255            << operands[0];
2256 
2257   unsigned wordIndex = 1;
2258   StringRef debugString = decodeStringLiteral(operands, wordIndex);
2259   if (wordIndex != operands.size())
2260     return emitError(unknownLoc,
2261                      "unexpected trailing words in OpString instruction");
2262 
2263   debugInfoMap[operands[0]] = debugString;
2264   return success();
2265 }
2266