xref: /llvm-project/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp (revision 23326b9f1723a398681def87c80e608fa94485f2)
1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIR-V module combiner library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/SymbolTable.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringMap.h"
26 
27 using namespace mlir;
28 
29 static constexpr unsigned maxFreeID = 1 << 20;
30 
31 /// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric
32 /// suffix in `lastUsedID`.
33 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
34                                     spirv::ModuleOp module) {
35   SmallString<64> newSymName(oldSymName);
36   newSymName.push_back('_');
37 
38   while (lastUsedID < maxFreeID) {
39     std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
40 
41     if (!SymbolTable::lookupSymbolIn(module, possible)) {
42       newSymName += llvm::utostr(lastUsedID);
43       break;
44     }
45   }
46 
47   return newSymName;
48 }
49 
50 /// Checks if a symbol with the same name as `op` already exists in `source`.
51 /// If so, renames `op` and updates all its references in `target`.
52 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
53                                             spirv::ModuleOp target,
54                                             spirv::ModuleOp source,
55                                             unsigned &lastUsedID) {
56   if (!SymbolTable::lookupSymbolIn(source, op.getName()))
57     return success();
58 
59   StringRef oldSymName = op.getName();
60   SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
61 
62   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
63     return op.emitError("unable to update all symbol uses for ")
64            << oldSymName << " to " << newSymName;
65 
66   SymbolTable::setSymbolName(op, newSymName);
67   return success();
68 }
69 
70 /// Computes a hash code to represent `symbolOp` based on all its attributes
71 /// except for the symbol name.
72 ///
73 /// Note: We use the operation's name (not the symbol name) as part of the hash
74 /// computation. This prevents, for example, mistakenly considering a global
75 /// variable and a spec constant as duplicates because their descriptor set +
76 /// binding and spec_id, respectively, happen to hash to the same value.
77 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
78   auto range =
79       llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
80         return attr.first != SymbolTable::getSymbolAttrName();
81       });
82 
83   return llvm::hash_combine(
84       symbolOp->getName(),
85       llvm::hash_combine_range(range.begin(), range.end()));
86 }
87 
88 namespace mlir {
89 namespace spirv {
90 
91 OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
92                                      OpBuilder &combinedModuleBuilder,
93                                      SymbolRenameListener symRenameListener) {
94   if (inputModules.empty())
95     return nullptr;
96 
97   spirv::ModuleOp firstModule = inputModules.front();
98   auto addressingModel = firstModule.addressing_model();
99   auto memoryModel = firstModule.memory_model();
100   auto vceTriple = firstModule.vce_triple();
101 
102   // First check whether there are conflicts between addressing/memory model.
103   // Return early if so.
104   for (auto module : inputModules) {
105     if (module.addressing_model() != addressingModel ||
106         module.memory_model() != memoryModel ||
107         module.vce_triple() != vceTriple) {
108       module.emitError("input modules differ in addressing model, memory "
109                        "model, and/or VCE triple");
110       return nullptr;
111     }
112   }
113 
114   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
115       firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
116   combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
117 
118   // In some cases, a symbol in the (current state of the) combined module is
119   // renamed in order to enable the conflicting symbol in the input module
120   // being merged. For example, if the conflict is between a global variable in
121   // the current combined module and a function in the input module, the global
122   // variable is renamed. In order to notify listeners of the symbol updates in
123   // such cases, we need to keep track of the module from which the renamed
124   // symbol in the combined module originated. This map keeps such information.
125   llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
126 
127   unsigned lastUsedID = 0;
128 
129   for (auto inputModule : inputModules) {
130     spirv::ModuleOp moduleClone = inputModule.clone();
131 
132     // In the combined module, rename all symbols that conflict with symbols
133     // from the current input module. This renaming applies to all ops except
134     // for spv.funcs. This way, if the conflicting op in the input module is
135     // non-spv.func, we rename that symbol instead and maintain the spv.func in
136     // the combined module name as it is.
137     for (auto &op : *combinedModule.getBody()) {
138       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
139       if (!symbolOp)
140         continue;
141 
142       StringRef oldSymName = symbolOp.getName();
143 
144       if (!isa<FuncOp>(op) &&
145           failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
146                                         lastUsedID)))
147         return nullptr;
148 
149       StringRef newSymName = symbolOp.getName();
150 
151       if (symRenameListener && oldSymName != newSymName) {
152         spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
153 
154         if (!originalModule) {
155           inputModule.emitError(
156               "unable to find original spirv::ModuleOp for symbol ")
157               << oldSymName;
158           return nullptr;
159         }
160 
161         symRenameListener(originalModule, oldSymName, newSymName);
162 
163         // Since the symbol name is updated, there is no need to maintain the
164         // entry that associates the old symbol name with the original module.
165         symNameToModuleMap.erase(oldSymName);
166         // Instead, add a new entry to map the new symbol name to the original
167         // module in case it gets renamed again later.
168         symNameToModuleMap[newSymName] = originalModule;
169       }
170     }
171 
172     // In the current input module, rename all symbols that conflict with
173     // symbols from the combined module. This includes renaming spv.funcs.
174     for (auto &op : *moduleClone.getBody()) {
175       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
176       if (!symbolOp)
177         continue;
178 
179       StringRef oldSymName = symbolOp.getName();
180 
181       if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
182                                         lastUsedID)))
183         return nullptr;
184 
185       StringRef newSymName = symbolOp.getName();
186 
187       if (symRenameListener) {
188         if (oldSymName != newSymName)
189           symRenameListener(inputModule, oldSymName, newSymName);
190 
191         // Insert the module associated with the symbol name.
192         auto emplaceResult =
193             symNameToModuleMap.try_emplace(newSymName, inputModule);
194 
195         // If an entry with the same symbol name is already present, this must
196         // be a problem with the implementation, specially clean-up of the map
197         // while iterating over the combined module above.
198         if (!emplaceResult.second) {
199           inputModule.emitError("did not expect to find an entry for symbol ")
200               << symbolOp.getName();
201           return nullptr;
202         }
203       }
204     }
205 
206     // Clone all the module's ops to the combined module.
207     for (auto &op : *moduleClone.getBody())
208       combinedModuleBuilder.insert(op.clone());
209   }
210 
211   // Deduplicate identical global variables, spec constants, and functions.
212   DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
213   SmallVector<SymbolOpInterface, 0> eraseList;
214 
215   for (auto &op : *combinedModule.getBody()) {
216     SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
217     if (!symbolOp)
218       continue;
219 
220     // Do not support ops with operands or results.
221     // Global variables, spec constants, and functions won't have
222     // operands/results, but just for safety here.
223     if (op.getNumOperands() != 0 || op.getNumResults() != 0)
224       continue;
225 
226     // Deduplicating functions are not supported yet.
227     if (isa<FuncOp>(op))
228       continue;
229 
230     auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
231     if (result.second)
232       continue;
233 
234     SymbolOpInterface replacementSymOp = result.first->second;
235 
236     if (failed(SymbolTable::replaceAllSymbolUses(
237             symbolOp, replacementSymOp.getName(), combinedModule))) {
238       symbolOp.emitError("unable to update all symbol uses for ")
239           << symbolOp.getName() << " to " << replacementSymOp.getName();
240       return nullptr;
241     }
242 
243     eraseList.push_back(symbolOp);
244   }
245 
246   for (auto symbolOp : eraseList)
247     symbolOp.erase();
248 
249   return combinedModule;
250 }
251 
252 } // namespace spirv
253 } // namespace mlir
254