xref: /llvm-project/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp (revision 4f4cd963a6e820b50514706a1a3faed3a05779a2)
1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIR-V module combiner library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/SymbolTable.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringMap.h"
26 
27 using namespace mlir;
28 
29 static constexpr unsigned maxFreeID = 1 << 20;
30 
31 /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
32 /// suffix in `lastUsedID`.
renameSymbol(StringRef oldSymName,unsigned & lastUsedID,spirv::ModuleOp module)33 static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
34                                spirv::ModuleOp module) {
35   SmallString<64> newSymName(oldSymName);
36   newSymName.push_back('_');
37 
38   MLIRContext *ctx = module->getContext();
39 
40   while (lastUsedID < maxFreeID) {
41     auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
42     if (!SymbolTable::lookupSymbolIn(module, possible))
43       return possible;
44   }
45 
46   return StringAttr::get(ctx, newSymName);
47 }
48 
49 /// Checks if a symbol with the same name as `op` already exists in `source`.
50 /// If so, renames `op` and updates all its references in `target`.
updateSymbolAndAllUses(SymbolOpInterface op,spirv::ModuleOp target,spirv::ModuleOp source,unsigned & lastUsedID)51 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
52                                             spirv::ModuleOp target,
53                                             spirv::ModuleOp source,
54                                             unsigned &lastUsedID) {
55   if (!SymbolTable::lookupSymbolIn(source, op.getName()))
56     return success();
57 
58   StringRef oldSymName = op.getName();
59   StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
60 
61   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
62     return op.emitError("unable to update all symbol uses for ")
63            << oldSymName << " to " << newSymName;
64 
65   SymbolTable::setSymbolName(op, newSymName);
66   return success();
67 }
68 
69 /// Computes a hash code to represent `symbolOp` based on all its attributes
70 /// except for the symbol name.
71 ///
72 /// Note: We use the operation's name (not the symbol name) as part of the hash
73 /// computation. This prevents, for example, mistakenly considering a global
74 /// variable and a spec constant as duplicates because their descriptor set +
75 /// binding and spec_id, respectively, happen to hash to the same value.
computeHash(SymbolOpInterface symbolOp)76 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
77   auto range =
78       llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
79         return attr.getName() != SymbolTable::getSymbolAttrName();
80       });
81 
82   return llvm::hash_combine(
83       symbolOp->getName(),
84       llvm::hash_combine_range(range.begin(), range.end()));
85 }
86 
87 namespace mlir {
88 namespace spirv {
89 
combine(ArrayRef<spirv::ModuleOp> inputModules,OpBuilder & combinedModuleBuilder,SymbolRenameListener symRenameListener)90 OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
91                                      OpBuilder &combinedModuleBuilder,
92                                      SymbolRenameListener symRenameListener) {
93   if (inputModules.empty())
94     return nullptr;
95 
96   spirv::ModuleOp firstModule = inputModules.front();
97   auto addressingModel = firstModule.getAddressingModel();
98   auto memoryModel = firstModule.getMemoryModel();
99   auto vceTriple = firstModule.getVceTriple();
100 
101   // First check whether there are conflicts between addressing/memory model.
102   // Return early if so.
103   for (auto module : inputModules) {
104     if (module.getAddressingModel() != addressingModel ||
105         module.getMemoryModel() != memoryModel ||
106         module.getVceTriple() != vceTriple) {
107       module.emitError("input modules differ in addressing model, memory "
108                        "model, and/or VCE triple");
109       return nullptr;
110     }
111   }
112 
113   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
114       firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
115   combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
116 
117   // In some cases, a symbol in the (current state of the) combined module is
118   // renamed in order to enable the conflicting symbol in the input module
119   // being merged. For example, if the conflict is between a global variable in
120   // the current combined module and a function in the input module, the global
121   // variable is renamed. In order to notify listeners of the symbol updates in
122   // such cases, we need to keep track of the module from which the renamed
123   // symbol in the combined module originated. This map keeps such information.
124   llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
125 
126   unsigned lastUsedID = 0;
127 
128   for (auto inputModule : inputModules) {
129     OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
130 
131     // In the combined module, rename all symbols that conflict with symbols
132     // from the current input module. This renaming applies to all ops except
133     // for spirv.funcs. This way, if the conflicting op in the input module is
134     // non-spirv.func, we rename that symbol instead and maintain the spirv.func
135     // in the combined module name as it is.
136     for (auto &op : *combinedModule.getBody()) {
137       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
138       if (!symbolOp)
139         continue;
140 
141       StringRef oldSymName = symbolOp.getName();
142 
143       if (!isa<FuncOp>(op) &&
144           failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
145                                         lastUsedID)))
146         return nullptr;
147 
148       StringRef newSymName = symbolOp.getName();
149 
150       if (symRenameListener && oldSymName != newSymName) {
151         spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
152 
153         if (!originalModule) {
154           inputModule.emitError(
155               "unable to find original spirv::ModuleOp for symbol ")
156               << oldSymName;
157           return nullptr;
158         }
159 
160         symRenameListener(originalModule, oldSymName, newSymName);
161 
162         // Since the symbol name is updated, there is no need to maintain the
163         // entry that associates the old symbol name with the original module.
164         symNameToModuleMap.erase(oldSymName);
165         // Instead, add a new entry to map the new symbol name to the original
166         // module in case it gets renamed again later.
167         symNameToModuleMap[newSymName] = originalModule;
168       }
169     }
170 
171     // In the current input module, rename all symbols that conflict with
172     // symbols from the combined module. This includes renaming spirv.funcs.
173     for (auto &op : *moduleClone->getBody()) {
174       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
175       if (!symbolOp)
176         continue;
177 
178       StringRef oldSymName = symbolOp.getName();
179 
180       if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
181                                         lastUsedID)))
182         return nullptr;
183 
184       StringRef newSymName = symbolOp.getName();
185 
186       if (symRenameListener) {
187         if (oldSymName != newSymName)
188           symRenameListener(inputModule, oldSymName, newSymName);
189 
190         // Insert the module associated with the symbol name.
191         auto emplaceResult =
192             symNameToModuleMap.try_emplace(newSymName, inputModule);
193 
194         // If an entry with the same symbol name is already present, this must
195         // be a problem with the implementation, specially clean-up of the map
196         // while iterating over the combined module above.
197         if (!emplaceResult.second) {
198           inputModule.emitError("did not expect to find an entry for symbol ")
199               << symbolOp.getName();
200           return nullptr;
201         }
202       }
203     }
204 
205     // Clone all the module's ops to the combined module.
206     for (auto &op : *moduleClone->getBody())
207       combinedModuleBuilder.insert(op.clone());
208   }
209 
210   // Deduplicate identical global variables, spec constants, and functions.
211   DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
212   SmallVector<SymbolOpInterface, 0> eraseList;
213 
214   for (auto &op : *combinedModule.getBody()) {
215     SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
216     if (!symbolOp)
217       continue;
218 
219     // Do not support ops with operands or results.
220     // Global variables, spec constants, and functions won't have
221     // operands/results, but just for safety here.
222     if (op.getNumOperands() != 0 || op.getNumResults() != 0)
223       continue;
224 
225     // Deduplicating functions are not supported yet.
226     if (isa<FuncOp>(op))
227       continue;
228 
229     auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
230     if (result.second)
231       continue;
232 
233     SymbolOpInterface replacementSymOp = result.first->second;
234 
235     if (failed(SymbolTable::replaceAllSymbolUses(
236             symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
237       symbolOp.emitError("unable to update all symbol uses for ")
238           << symbolOp.getName() << " to " << replacementSymOp.getName();
239       return nullptr;
240     }
241 
242     eraseList.push_back(symbolOp);
243   }
244 
245   for (auto symbolOp : eraseList)
246     symbolOp.erase();
247 
248   return combinedModule;
249 }
250 
251 } // namespace spirv
252 } // namespace mlir
253