xref: /llvm-project/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (revision ead0a9777f8ccb5c26d50d96bade6cd5b47f496b)
1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/IR/RegionGraphTraits.h"
18 #include "mlir/Support/LogicalResult.h"
19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "spirv-serialization"
25 
26 using namespace mlir;
27 
28 /// A pre-order depth-first visitor function for processing basic blocks.
29 ///
30 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
31 /// depth-first manner and calls `blockHandler` on each block. Skips handling
32 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
33 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
34 /// successors.
35 ///
36 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
37 /// of blocks in a function must satisfy the rule that blocks appear before
38 /// all blocks they dominate." This can be achieved by a pre-order CFG
39 /// traversal algorithm. To make the serialization output more logical and
40 /// readable to human, we perform depth-first CFG traversal and delay the
41 /// serialization of the merge block and the continue block, if exists, until
42 /// after all other blocks have been processed.
43 static LogicalResult
44 visitInPrettyBlockOrder(Block *headerBlock,
45                         function_ref<LogicalResult(Block *)> blockHandler,
46                         bool skipHeader = false, BlockRange skipBlocks = {}) {
47   llvm::df_iterator_default_set<Block *, 4> doneBlocks;
48   doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
49 
50   for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
51     if (skipHeader && block == headerBlock)
52       continue;
53     if (failed(blockHandler(block)))
54       return failure();
55   }
56   return success();
57 }
58 
59 namespace mlir {
60 namespace spirv {
61 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
62   if (auto resultID =
63           prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
64     valueIDMap[op.getResult()] = resultID;
65     return success();
66   }
67   return failure();
68 }
69 
70 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
71   if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
72                                             /*isSpec=*/true)) {
73     // Emit the OpDecorate instruction for SpecId.
74     if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
75       auto val = static_cast<uint32_t>(specID.getInt());
76       if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
77         return failure();
78     }
79 
80     specConstIDMap[op.getSymName()] = resultID;
81     return processName(resultID, op.getSymName());
82   }
83   return failure();
84 }
85 
86 LogicalResult
87 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
88   uint32_t typeID = 0;
89   if (failed(processType(op.getLoc(), op.getType(), typeID))) {
90     return failure();
91   }
92 
93   auto resultID = getNextID();
94 
95   SmallVector<uint32_t, 8> operands;
96   operands.push_back(typeID);
97   operands.push_back(resultID);
98 
99   auto constituents = op.getConstituents();
100 
101   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
102     auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
103 
104     auto constituentName = constituent.getValue();
105     auto constituentID = getSpecConstID(constituentName);
106 
107     if (!constituentID) {
108       return op.emitError("unknown result <id> for specialization constant ")
109              << constituentName;
110     }
111 
112     operands.push_back(constituentID);
113   }
114 
115   encodeInstructionInto(typesGlobalValues,
116                         spirv::Opcode::OpSpecConstantComposite, operands);
117   specConstIDMap[op.getSymName()] = resultID;
118 
119   return processName(resultID, op.getSymName());
120 }
121 
122 LogicalResult
123 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
124   uint32_t typeID = 0;
125   if (failed(processType(op.getLoc(), op.getType(), typeID))) {
126     return failure();
127   }
128 
129   auto resultID = getNextID();
130 
131   SmallVector<uint32_t, 8> operands;
132   operands.push_back(typeID);
133   operands.push_back(resultID);
134 
135   Block &block = op.getRegion().getBlocks().front();
136   Operation &enclosedOp = block.getOperations().front();
137 
138   std::string enclosedOpName;
139   llvm::raw_string_ostream rss(enclosedOpName);
140   rss << "Op" << enclosedOp.getName().stripDialect();
141   auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
142 
143   if (!enclosedOpcode) {
144     op.emitError("Couldn't find op code for op ")
145         << enclosedOp.getName().getStringRef();
146     return failure();
147   }
148 
149   operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
150 
151   // Append operands to the enclosed op to the list of operands.
152   for (Value operand : enclosedOp.getOperands()) {
153     uint32_t id = getValueID(operand);
154     assert(id && "use before def!");
155     operands.push_back(id);
156   }
157 
158   encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
159                         operands);
160   valueIDMap[op.getResult()] = resultID;
161 
162   return success();
163 }
164 
165 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
166   auto undefType = op.getType();
167   auto &id = undefValIDMap[undefType];
168   if (!id) {
169     id = getNextID();
170     uint32_t typeID = 0;
171     if (failed(processType(op.getLoc(), undefType, typeID)))
172       return failure();
173     encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
174                           {typeID, id});
175   }
176   valueIDMap[op.getResult()] = id;
177   return success();
178 }
179 
180 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
181   for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
182     uint32_t argTypeID = 0;
183     if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
184       return failure();
185     }
186     auto argValueID = getNextID();
187 
188     // Process decoration attributes of arguments.
189     auto funcOp = cast<FunctionOpInterface>(*op);
190     for (auto argAttr : funcOp.getArgAttrs(idx)) {
191       if (argAttr.getName() != DecorationAttr::name)
192         continue;
193 
194       if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
195         if (failed(processDecorationAttr(op->getLoc(), argValueID,
196                                          decAttr.getValue(), decAttr)))
197           return failure();
198       }
199     }
200 
201     valueIDMap[arg] = argValueID;
202     encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
203                           {argTypeID, argValueID});
204   }
205   return success();
206 }
207 
208 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
209   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
210   assert(functionHeader.empty() && functionBody.empty());
211 
212   uint32_t fnTypeID = 0;
213   // Generate type of the function.
214   if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
215     return failure();
216 
217   // Add the function definition.
218   SmallVector<uint32_t, 4> operands;
219   uint32_t resTypeID = 0;
220   auto resultTypes = op.getFunctionType().getResults();
221   if (resultTypes.size() > 1) {
222     return op.emitError("cannot serialize function with multiple return types");
223   }
224   if (failed(processType(op.getLoc(),
225                          (resultTypes.empty() ? getVoidType() : resultTypes[0]),
226                          resTypeID))) {
227     return failure();
228   }
229   operands.push_back(resTypeID);
230   auto funcID = getOrCreateFunctionID(op.getName());
231   operands.push_back(funcID);
232   operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
233   operands.push_back(fnTypeID);
234   encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
235 
236   // Add function name.
237   if (failed(processName(funcID, op.getName()))) {
238     return failure();
239   }
240   // Handle external functions with linkage_attributes(LinkageAttributes)
241   // differently.
242   auto linkageAttr = op.getLinkageAttributes();
243   auto hasImportLinkage =
244       linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
245                       spirv::LinkageType::Import);
246   if (op.isExternal() && !hasImportLinkage) {
247     return op.emitError(
248         "'spirv.module' cannot contain external functions "
249         "without 'Import' linkage_attributes (LinkageAttributes)");
250   }
251   if (op.isExternal() && hasImportLinkage) {
252     // Add an entry block to set up the block arguments
253     // to match the signature of the function.
254     // This is to generate OpFunctionParameter for functions with
255     // LinkageAttributes.
256     // WARNING: This operation has side-effect, it essentially adds a body
257     // to the func. Hence, making it not external anymore (isExternal()
258     // is going to return false for this function from now on)
259     // Hence, we'll remove the body once we are done with the serialization.
260     op.addEntryBlock();
261     if (failed(processFuncParameter(op)))
262       return failure();
263     // Don't need to process the added block, there is nothing to process,
264     // the fake body was added just to get the arguments, remove the body,
265     // since it's use is done.
266     op.eraseBody();
267   } else {
268     if (failed(processFuncParameter(op)))
269       return failure();
270 
271     // Some instructions (e.g., OpVariable) in a function must be in the first
272     // block in the function. These instructions will be put in
273     // functionHeader. Thus, we put the label in functionHeader first, and
274     // omit it from the first block. OpLabel only needs to be added for
275     // functions with body (including empty body). Since, we added a fake body
276     // for functions with 'Import' Linkage attributes, these functions are
277     // essentially function delcaration, so they should not have OpLabel and a
278     // terminating instruction. That's why we skipped it for those functions.
279     encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
280                           {getOrCreateBlockID(&op.front())});
281     if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
282       return failure();
283     if (failed(visitInPrettyBlockOrder(
284             &op.front(), [&](Block *block) { return processBlock(block); },
285             /*skipHeader=*/true))) {
286       return failure();
287     }
288 
289     // There might be OpPhi instructions who have value references needing to
290     // fix.
291     for (const auto &deferredValue : deferredPhiValues) {
292       Value value = deferredValue.first;
293       uint32_t id = getValueID(value);
294       LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
295                               << " to id = " << id << '\n');
296       assert(id && "OpPhi references undefined value!");
297       for (size_t offset : deferredValue.second)
298         functionBody[offset] = id;
299     }
300     deferredPhiValues.clear();
301   }
302   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
303                           << "' --\n");
304   // Insert Decorations based on Function Attributes.
305   // Only attributes we should be considering for decoration are the
306   // ::mlir::spirv::Decoration attributes.
307 
308   for (auto attr : op->getAttrs()) {
309     // Only generate OpDecorate op for spirv::Decoration attributes.
310     auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
311         llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
312                                           /*capitalizeFirst=*/true));
313     if (isValidDecoration != std::nullopt) {
314       if (failed(processDecoration(op.getLoc(), funcID, attr))) {
315         return failure();
316       }
317     }
318   }
319   // Insert OpFunctionEnd.
320   encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
321 
322   functions.append(functionHeader.begin(), functionHeader.end());
323   functions.append(functionBody.begin(), functionBody.end());
324   functionHeader.clear();
325   functionBody.clear();
326 
327   return success();
328 }
329 
330 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
331   SmallVector<uint32_t, 4> operands;
332   SmallVector<StringRef, 2> elidedAttrs;
333   uint32_t resultID = 0;
334   uint32_t resultTypeID = 0;
335   if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
336     return failure();
337   }
338   operands.push_back(resultTypeID);
339   resultID = getNextID();
340   valueIDMap[op.getResult()] = resultID;
341   operands.push_back(resultID);
342   auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
343   if (attr) {
344     operands.push_back(
345         static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
346   }
347   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
348   for (auto arg : op.getODSOperands(0)) {
349     auto argID = getValueID(arg);
350     if (!argID) {
351       return emitError(op.getLoc(), "operand 0 has a use before def");
352     }
353     operands.push_back(argID);
354   }
355   if (failed(emitDebugLine(functionHeader, op.getLoc())))
356     return failure();
357   encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
358   for (auto attr : op->getAttrs()) {
359     if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
360           return attr.getName() == elided;
361         })) {
362       continue;
363     }
364     if (failed(processDecoration(op.getLoc(), resultID, attr))) {
365       return failure();
366     }
367   }
368   return success();
369 }
370 
371 LogicalResult
372 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
373   // Get TypeID.
374   uint32_t resultTypeID = 0;
375   SmallVector<StringRef, 4> elidedAttrs;
376   if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
377     return failure();
378   }
379 
380   elidedAttrs.push_back("type");
381   SmallVector<uint32_t, 4> operands;
382   operands.push_back(resultTypeID);
383   auto resultID = getNextID();
384 
385   // Encode the name.
386   auto varName = varOp.getSymName();
387   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
388   if (failed(processName(resultID, varName))) {
389     return failure();
390   }
391   globalVarIDMap[varName] = resultID;
392   operands.push_back(resultID);
393 
394   // Encode StorageClass.
395   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
396 
397   // Encode initialization.
398   StringRef initAttrName = varOp.getInitializerAttrName().getValue();
399   if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
400     uint32_t initializerID = 0;
401     auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
402     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
403         varOp->getParentOp(), initRef.getAttr());
404 
405     // Check if initializer is GlobalVariable or SpecConstant* cases.
406     if (isa<spirv::GlobalVariableOp>(initOp))
407       initializerID = getVariableID(*initSymbolName);
408     else
409       initializerID = getSpecConstID(*initSymbolName);
410 
411     if (!initializerID)
412       return emitError(varOp.getLoc(),
413                        "invalid usage of undefined variable as initializer");
414 
415     operands.push_back(initializerID);
416     elidedAttrs.push_back(initAttrName);
417   }
418 
419   if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
420     return failure();
421   encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
422   elidedAttrs.push_back(initAttrName);
423 
424   // Encode decorations.
425   for (auto attr : varOp->getAttrs()) {
426     if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
427           return attr.getName() == elided;
428         })) {
429       continue;
430     }
431     if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
432       return failure();
433     }
434   }
435   return success();
436 }
437 
438 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
439   // Assign <id>s to all blocks so that branches inside the SelectionOp can
440   // resolve properly.
441   auto &body = selectionOp.getBody();
442   for (Block &block : body)
443     getOrCreateBlockID(&block);
444 
445   auto *headerBlock = selectionOp.getHeaderBlock();
446   auto *mergeBlock = selectionOp.getMergeBlock();
447   auto headerID = getBlockID(headerBlock);
448   auto mergeID = getBlockID(mergeBlock);
449   auto loc = selectionOp.getLoc();
450 
451   // This SelectionOp is in some MLIR block with preceding and following ops. In
452   // the binary format, it should reside in separate SPIR-V blocks from its
453   // preceding and following ops. So we need to emit unconditional branches to
454   // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
455   // flow afterwards.
456   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
457 
458   // Emit the selection header block, which dominates all other blocks, first.
459   // We need to emit an OpSelectionMerge instruction before the selection header
460   // block's terminator.
461   auto emitSelectionMerge = [&]() {
462     if (failed(emitDebugLine(functionBody, loc)))
463       return failure();
464     lastProcessedWasMergeInst = true;
465     encodeInstructionInto(
466         functionBody, spirv::Opcode::OpSelectionMerge,
467         {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
468     return success();
469   };
470   if (failed(
471           processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
472     return failure();
473 
474   // Process all blocks with a depth-first visitor starting from the header
475   // block. The selection header block and merge block are skipped by this
476   // visitor.
477   if (failed(visitInPrettyBlockOrder(
478           headerBlock, [&](Block *block) { return processBlock(block); },
479           /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
480     return failure();
481 
482   // There is nothing to do for the merge block in the selection, which just
483   // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
484   // instruction to start a new SPIR-V block for ops following this SelectionOp.
485   // The block should use the <id> for the merge block.
486   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
487   LLVM_DEBUG(llvm::dbgs() << "done merge ");
488   LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
489   LLVM_DEBUG(llvm::dbgs() << "\n");
490   return success();
491 }
492 
493 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
494   // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
495   // properly. We don't need to assign for the entry block, which is just for
496   // satisfying MLIR region's structural requirement.
497   auto &body = loopOp.getBody();
498   for (Block &block : llvm::drop_begin(body))
499     getOrCreateBlockID(&block);
500 
501   auto *headerBlock = loopOp.getHeaderBlock();
502   auto *continueBlock = loopOp.getContinueBlock();
503   auto *mergeBlock = loopOp.getMergeBlock();
504   auto headerID = getBlockID(headerBlock);
505   auto continueID = getBlockID(continueBlock);
506   auto mergeID = getBlockID(mergeBlock);
507   auto loc = loopOp.getLoc();
508 
509   // This LoopOp is in some MLIR block with preceding and following ops. In the
510   // binary format, it should reside in separate SPIR-V blocks from its
511   // preceding and following ops. So we need to emit unconditional branches to
512   // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
513   // afterwards.
514   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
515 
516   // LoopOp's entry block is just there for satisfying MLIR's structural
517   // requirements so we omit it and start serialization from the loop header
518   // block.
519 
520   // Emit the loop header block, which dominates all other blocks, first. We
521   // need to emit an OpLoopMerge instruction before the loop header block's
522   // terminator.
523   auto emitLoopMerge = [&]() {
524     if (failed(emitDebugLine(functionBody, loc)))
525       return failure();
526     lastProcessedWasMergeInst = true;
527     encodeInstructionInto(
528         functionBody, spirv::Opcode::OpLoopMerge,
529         {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
530     return success();
531   };
532   if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
533     return failure();
534 
535   // Process all blocks with a depth-first visitor starting from the header
536   // block. The loop header block, loop continue block, and loop merge block are
537   // skipped by this visitor and handled later in this function.
538   if (failed(visitInPrettyBlockOrder(
539           headerBlock, [&](Block *block) { return processBlock(block); },
540           /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
541     return failure();
542 
543   // We have handled all other blocks. Now get to the loop continue block.
544   if (failed(processBlock(continueBlock)))
545     return failure();
546 
547   // There is nothing to do for the merge block in the loop, which just contains
548   // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
549   // to start a new SPIR-V block for ops following this LoopOp. The block should
550   // use the <id> for the merge block.
551   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
552   LLVM_DEBUG(llvm::dbgs() << "done merge ");
553   LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
554   LLVM_DEBUG(llvm::dbgs() << "\n");
555   return success();
556 }
557 
558 LogicalResult Serializer::processBranchConditionalOp(
559     spirv::BranchConditionalOp condBranchOp) {
560   auto conditionID = getValueID(condBranchOp.getCondition());
561   auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
562   auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
563   SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
564 
565   if (auto weights = condBranchOp.getBranchWeights()) {
566     for (auto val : weights->getValue())
567       arguments.push_back(cast<IntegerAttr>(val).getInt());
568   }
569 
570   if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
571     return failure();
572   encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
573                         arguments);
574   return success();
575 }
576 
577 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
578   if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
579     return failure();
580   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
581                         {getOrCreateBlockID(branchOp.getTarget())});
582   return success();
583 }
584 
585 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
586   auto varName = addressOfOp.getVariable();
587   auto variableID = getVariableID(varName);
588   if (!variableID) {
589     return addressOfOp.emitError("unknown result <id> for variable ")
590            << varName;
591   }
592   valueIDMap[addressOfOp.getPointer()] = variableID;
593   return success();
594 }
595 
596 LogicalResult
597 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
598   auto constName = referenceOfOp.getSpecConst();
599   auto constID = getSpecConstID(constName);
600   if (!constID) {
601     return referenceOfOp.emitError(
602                "unknown result <id> for specialization constant ")
603            << constName;
604   }
605   valueIDMap[referenceOfOp.getReference()] = constID;
606   return success();
607 }
608 
609 template <>
610 LogicalResult
611 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
612   SmallVector<uint32_t, 4> operands;
613   // Add the ExecutionModel.
614   operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
615   // Add the function <id>.
616   auto funcID = getFunctionID(op.getFn());
617   if (!funcID) {
618     return op.emitError("missing <id> for function ")
619            << op.getFn()
620            << "; function needs to be defined before spirv.EntryPoint is "
621               "serialized";
622   }
623   operands.push_back(funcID);
624   // Add the name of the function.
625   spirv::encodeStringLiteralInto(operands, op.getFn());
626 
627   // Add the interface values.
628   if (auto interface = op.getInterface()) {
629     for (auto var : interface.getValue()) {
630       auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
631       if (!id) {
632         return op.emitError(
633             "referencing undefined global variable."
634             "spirv.EntryPoint is at the end of spirv.module. All "
635             "referenced variables should already be defined");
636       }
637       operands.push_back(id);
638     }
639   }
640   encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
641   return success();
642 }
643 
644 template <>
645 LogicalResult
646 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
647   SmallVector<uint32_t, 4> operands;
648   // Add the function <id>.
649   auto funcID = getFunctionID(op.getFn());
650   if (!funcID) {
651     return op.emitError("missing <id> for function ")
652            << op.getFn()
653            << "; function needs to be serialized before ExecutionModeOp is "
654               "serialized";
655   }
656   operands.push_back(funcID);
657   // Add the ExecutionMode.
658   operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
659 
660   // Serialize values if any.
661   auto values = op.getValues();
662   if (values) {
663     for (auto &intVal : values.getValue()) {
664       operands.push_back(static_cast<uint32_t>(
665           llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
666     }
667   }
668   encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
669                         operands);
670   return success();
671 }
672 
673 template <>
674 LogicalResult
675 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
676   auto funcName = op.getCallee();
677   uint32_t resTypeID = 0;
678 
679   Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
680   if (failed(processType(op.getLoc(), resultTy, resTypeID)))
681     return failure();
682 
683   auto funcID = getOrCreateFunctionID(funcName);
684   auto funcCallID = getNextID();
685   SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
686 
687   for (auto value : op.getArguments()) {
688     auto valueID = getValueID(value);
689     assert(valueID && "cannot find a value for spirv.FunctionCall");
690     operands.push_back(valueID);
691   }
692 
693   if (!isa<NoneType>(resultTy))
694     valueIDMap[op.getResult(0)] = funcCallID;
695 
696   encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
697   return success();
698 }
699 
700 template <>
701 LogicalResult
702 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
703   SmallVector<uint32_t, 4> operands;
704   SmallVector<StringRef, 2> elidedAttrs;
705 
706   for (Value operand : op->getOperands()) {
707     auto id = getValueID(operand);
708     assert(id && "use before def!");
709     operands.push_back(id);
710   }
711 
712   StringAttr memoryAccess = op.getMemoryAccessAttrName();
713   if (auto attr = op->getAttr(memoryAccess)) {
714     operands.push_back(
715         static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
716   }
717 
718   elidedAttrs.push_back(memoryAccess.strref());
719 
720   StringAttr alignment = op.getAlignmentAttrName();
721   if (auto attr = op->getAttr(alignment)) {
722     operands.push_back(static_cast<uint32_t>(
723         cast<IntegerAttr>(attr).getValue().getZExtValue()));
724   }
725 
726   elidedAttrs.push_back(alignment.strref());
727 
728   StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
729   if (auto attr = op->getAttr(sourceMemoryAccess)) {
730     operands.push_back(
731         static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
732   }
733 
734   elidedAttrs.push_back(sourceMemoryAccess.strref());
735 
736   StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
737   if (auto attr = op->getAttr(sourceAlignment)) {
738     operands.push_back(static_cast<uint32_t>(
739         cast<IntegerAttr>(attr).getValue().getZExtValue()));
740   }
741 
742   elidedAttrs.push_back(sourceAlignment.strref());
743   if (failed(emitDebugLine(functionBody, op.getLoc())))
744     return failure();
745   encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
746 
747   return success();
748 }
749 template <>
750 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
751     spirv::GenericCastToPtrExplicitOp op) {
752   SmallVector<uint32_t, 4> operands;
753   Type resultTy;
754   Location loc = op->getLoc();
755   uint32_t resultTypeID = 0;
756   uint32_t resultID = 0;
757   resultTy = op->getResult(0).getType();
758   if (failed(processType(loc, resultTy, resultTypeID)))
759     return failure();
760   operands.push_back(resultTypeID);
761 
762   resultID = getNextID();
763   operands.push_back(resultID);
764   valueIDMap[op->getResult(0)] = resultID;
765 
766   for (Value operand : op->getOperands())
767     operands.push_back(getValueID(operand));
768   spirv::StorageClass resultStorage =
769       cast<spirv::PointerType>(resultTy).getStorageClass();
770   operands.push_back(static_cast<uint32_t>(resultStorage));
771   encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
772                         operands);
773   return success();
774 }
775 
776 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
777 // various Serializer::processOp<...>() specializations.
778 #define GET_SERIALIZATION_FNS
779 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
780 
781 } // namespace spirv
782 } // namespace mlir
783