xref: /llvm-project/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp (revision 341f3c1120dfa8879e5f714a07fc8b16c8887a7f)
1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIR-V module combiner library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/ModuleCombiner.h"
14 
15 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/Hashing.h"
21 #include "llvm/ADT/StringExtras.h"
22 
23 using namespace mlir;
24 
25 static constexpr unsigned maxFreeID = 1 << 20;
26 
27 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
28                                     spirv::ModuleOp combinedModule) {
29   SmallString<64> newSymName(oldSymName);
30   newSymName.push_back('_');
31 
32   while (lastUsedID < maxFreeID) {
33     std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
34 
35     if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
36       newSymName += llvm::utostr(lastUsedID);
37       break;
38     }
39   }
40 
41   return newSymName;
42 }
43 
44 /// Check if a symbol with the same name as op already exists in source. If so,
45 /// rename op and update all its references in target.
46 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
47                                             spirv::ModuleOp target,
48                                             spirv::ModuleOp source,
49                                             unsigned &lastUsedID) {
50   if (!SymbolTable::lookupSymbolIn(source, op.getName()))
51     return success();
52 
53   StringRef oldSymName = op.getName();
54   SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
55 
56   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
57     return op.emitError("unable to update all symbol uses for ")
58            << oldSymName << " to " << newSymName;
59 
60   SymbolTable::setSymbolName(op, newSymName);
61   return success();
62 }
63 
64 template <typename KeyTy, typename SymbolOpTy>
65 static SymbolOpTy
66 emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
67                               DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
68   auto result = deduplicationMap.try_emplace(key, symbolOp);
69 
70   if (result.second)
71     return SymbolOpTy();
72 
73   return result.first->second;
74 }
75 
76 /// Computes a hash code to represent the argument SymbolOpInterface based on
77 /// all the Op's attributes except for the symbol name.
78 ///
79 /// \return the hash code computed from the Op's attributes as described above.
80 ///
81 /// Note: We use the operation's name (not the symbol name) as part of the hash
82 /// computation. This prevents, for example, mistakenly considering a global
83 /// variable and a spec constant as duplicates because their descriptor set +
84 /// binding and spec_id, repectively, happen to hash to the same value.
85 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
86   llvm::hash_code hashCode(0);
87   hashCode = llvm::hash_combine(symbolOp.getOperation()->getName());
88 
89   for (auto attr : symbolOp.getOperation()->getAttrs()) {
90     if (attr.first == SymbolTable::getSymbolAttrName())
91       continue;
92     hashCode = llvm::hash_combine(hashCode, attr);
93   }
94 
95   return hashCode;
96 }
97 
98 /// Computes a hash code from the argument Block.
99 llvm::hash_code computeHash(Block *block) {
100   // TODO: Consider extracting BlockEquivalenceData into a common header and
101   // re-using it here.
102   llvm::hash_code hash(0);
103 
104   for (Operation &op : *block) {
105     // TODO: Properly handle operations with regions.
106     if (op.getNumRegions() > 0)
107       return 0;
108 
109     hash = llvm::hash_combine(
110         hash, OperationEquivalence::computeHash(
111                   &op, OperationEquivalence::Flags::IgnoreOperands));
112   }
113 
114   return hash;
115 }
116 
117 namespace mlir {
118 namespace spirv {
119 
120 // TODO Properly test symbol rename listener mechanism.
121 
122 OwningSPIRVModuleRef
123 combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
124         OpBuilder &combinedModuleBuilder,
125         llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
126             symRenameListener) {
127   unsigned lastUsedID = 0;
128 
129   if (modules.empty())
130     return nullptr;
131 
132   auto addressingModel = modules[0].addressing_model();
133   auto memoryModel = modules[0].memory_model();
134 
135   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
136       modules[0].getLoc(), addressingModel, memoryModel);
137   combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
138 
139   // In some cases, a symbol in the (current state of the) combined module is
140   // renamed in order to maintain the conflicting symbol in the input module
141   // being merged. For example, if the conflict is between a global variable in
142   // the current combined module and a function in the input module, the global
143   // varaible is renamed. In order to notify listeners of the symbol updates in
144   // such cases, we need to keep track of the module from which the renamed
145   // symbol in the combined module originated. This map keeps such information.
146   DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
147 
148   for (auto module : modules) {
149     if (module.addressing_model() != addressingModel ||
150         module.memory_model() != memoryModel) {
151       module.emitError(
152           "input modules differ in addressing model and/or memory model");
153       return nullptr;
154     }
155 
156     spirv::ModuleOp moduleClone = module.clone();
157 
158     // In the combined module, rename all symbols that conflict with symbols
159     // from the current input module. This renmaing applies to all ops except
160     // for spv.funcs. This way, if the conflicting op in the input module is
161     // non-spv.func, we rename that symbol instead and maintain the spv.func in
162     // the combined module name as it is.
163     for (auto &op : combinedModule.getBlock().without_terminator()) {
164       if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
165         StringRef oldSymName = symbolOp.getName();
166 
167         if (!isa<FuncOp>(op) &&
168             failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
169                                           lastUsedID)))
170           return nullptr;
171 
172         StringRef newSymName = symbolOp.getName();
173 
174         if (symRenameListener && oldSymName != newSymName) {
175           spirv::ModuleOp originalModule =
176               symNameToModuleMap.lookup(oldSymName);
177 
178           if (!originalModule) {
179             module.emitError("unable to find original ModuleOp for symbol ")
180                 << oldSymName;
181             return nullptr;
182           }
183 
184           symRenameListener(originalModule, oldSymName, newSymName);
185 
186           // Since the symbol name is updated, there is no need to maintain the
187           // entry that assocaites the old symbol name with the original module.
188           symNameToModuleMap.erase(oldSymName);
189           // Instead, add a new entry to map the new symbol name to the original
190           // module in case it gets renamed again later.
191           symNameToModuleMap[newSymName] = originalModule;
192         }
193       }
194     }
195 
196     // In the current input module, rename all symbols that conflict with
197     // symbols from the combined module. This includes renaming spv.funcs.
198     for (auto &op : moduleClone.getBlock().without_terminator()) {
199       if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
200         StringRef oldSymName = symbolOp.getName();
201 
202         if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
203                                           lastUsedID)))
204           return nullptr;
205 
206         StringRef newSymName = symbolOp.getName();
207 
208         if (symRenameListener && oldSymName != newSymName) {
209           symRenameListener(module, oldSymName, newSymName);
210 
211           // Insert the module associated with the symbol name.
212           auto emplaceResult =
213               symNameToModuleMap.try_emplace(symbolOp.getName(), module);
214 
215           // If an entry with the same symbol name is already present, this must
216           // be a problem with the implementation, specially clean-up of the map
217           // while iterating over the combined module above.
218           if (!emplaceResult.second) {
219             module.emitError("did not expect to find an entry for symbol ")
220                 << symbolOp.getName();
221             return nullptr;
222           }
223         }
224       }
225     }
226 
227     // Clone all the module's ops to the combined module.
228     for (auto &op : moduleClone.getBlock().without_terminator())
229       combinedModuleBuilder.insert(op.clone());
230   }
231 
232   // Deduplicate identical global variables, spec constants, and functions.
233   DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
234   SmallVector<SymbolOpInterface, 0> eraseList;
235 
236   for (auto &op : combinedModule.getBlock().without_terminator()) {
237     llvm::hash_code hashCode(0);
238     SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
239 
240     if (!symbolOp)
241       continue;
242 
243     hashCode = computeHash(symbolOp);
244 
245     // A 0 hash code means the op is not suitable for deduplication and should
246     // be skipped. An example of this is when a function has ops with regions
247     // which are not properly supported yet.
248     if (!hashCode)
249       continue;
250 
251     if (auto funcOp = dyn_cast<FuncOp>(op))
252       for (auto &blk : funcOp)
253         hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
254 
255     SymbolOpInterface replacementSymOp =
256         emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
257 
258     if (!replacementSymOp)
259       continue;
260 
261     if (failed(SymbolTable::replaceAllSymbolUses(
262             symbolOp, replacementSymOp.getName(), combinedModule))) {
263       symbolOp.emitError("unable to update all symbol uses for ")
264           << symbolOp.getName() << " to " << replacementSymOp.getName();
265       return nullptr;
266     }
267 
268     eraseList.push_back(symbolOp);
269   }
270 
271   for (auto symbolOp : eraseList)
272     symbolOp.erase();
273 
274   return combinedModule;
275 }
276 
277 } // namespace spirv
278 } // namespace mlir
279