xref: /llvm-project/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp (revision 4f4cd963a6e820b50514706a1a3faed3a05779a2)
190a8260cSergawy //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
290a8260cSergawy //
390a8260cSergawy // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
490a8260cSergawy // See https://llvm.org/LICENSE.txt for license information.
590a8260cSergawy // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
690a8260cSergawy //
790a8260cSergawy //===----------------------------------------------------------------------===//
890a8260cSergawy //
9f88fab50SKazuaki Ishizaki // This file implements the SPIR-V module combiner library.
1090a8260cSergawy //
1190a8260cSergawy //===----------------------------------------------------------------------===//
1290a8260cSergawy 
1301178654SLei Zhang #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
1490a8260cSergawy 
1523326b9fSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1701178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1823326b9fSLei Zhang #include "mlir/IR/Attributes.h"
1990a8260cSergawy #include "mlir/IR/Builders.h"
2090a8260cSergawy #include "mlir/IR/SymbolTable.h"
2190a8260cSergawy #include "llvm/ADT/ArrayRef.h"
22341f3c11Sergawy #include "llvm/ADT/Hashing.h"
2323326b9fSLei Zhang #include "llvm/ADT/STLExtras.h"
2490a8260cSergawy #include "llvm/ADT/StringExtras.h"
2523326b9fSLei Zhang #include "llvm/ADT/StringMap.h"
2690a8260cSergawy 
2790a8260cSergawy using namespace mlir;
2890a8260cSergawy 
/// Largest numeric suffix `renameSymbol` will try (inclusive) when searching
/// for an unused symbol name; this caps the search so it always terminates.
static constexpr unsigned maxFreeID = 1U << 20;
3090a8260cSergawy 
31*4f4cd963SJakub Kuderski /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
3223326b9fSLei Zhang /// suffix in `lastUsedID`.
renameSymbol(StringRef oldSymName,unsigned & lastUsedID,spirv::ModuleOp module)3341d4aa7dSChris Lattner static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
3423326b9fSLei Zhang                                spirv::ModuleOp module) {
3590a8260cSergawy   SmallString<64> newSymName(oldSymName);
3690a8260cSergawy   newSymName.push_back('_');
3790a8260cSergawy 
3841d4aa7dSChris Lattner   MLIRContext *ctx = module->getContext();
3941d4aa7dSChris Lattner 
4090a8260cSergawy   while (lastUsedID < maxFreeID) {
4141d4aa7dSChris Lattner     auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
4241d4aa7dSChris Lattner     if (!SymbolTable::lookupSymbolIn(module, possible))
4341d4aa7dSChris Lattner       return possible;
4490a8260cSergawy   }
4590a8260cSergawy 
4641d4aa7dSChris Lattner   return StringAttr::get(ctx, newSymName);
4790a8260cSergawy }
4890a8260cSergawy 
4923326b9fSLei Zhang /// Checks if a symbol with the same name as `op` already exists in `source`.
5023326b9fSLei Zhang /// If so, renames `op` and updates all its references in `target`.
updateSymbolAndAllUses(SymbolOpInterface op,spirv::ModuleOp target,spirv::ModuleOp source,unsigned & lastUsedID)5190a8260cSergawy static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
5290a8260cSergawy                                             spirv::ModuleOp target,
5390a8260cSergawy                                             spirv::ModuleOp source,
5490a8260cSergawy                                             unsigned &lastUsedID) {
5590a8260cSergawy   if (!SymbolTable::lookupSymbolIn(source, op.getName()))
5690a8260cSergawy     return success();
5790a8260cSergawy 
5890a8260cSergawy   StringRef oldSymName = op.getName();
5941d4aa7dSChris Lattner   StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
6090a8260cSergawy 
6190a8260cSergawy   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
6290a8260cSergawy     return op.emitError("unable to update all symbol uses for ")
6390a8260cSergawy            << oldSymName << " to " << newSymName;
6490a8260cSergawy 
6590a8260cSergawy   SymbolTable::setSymbolName(op, newSymName);
6690a8260cSergawy   return success();
6790a8260cSergawy }
6890a8260cSergawy 
6923326b9fSLei Zhang /// Computes a hash code to represent `symbolOp` based on all its attributes
7023326b9fSLei Zhang /// except for the symbol name.
71341f3c11Sergawy ///
72341f3c11Sergawy /// Note: We use the operation's name (not the symbol name) as part of the hash
73341f3c11Sergawy /// computation. This prevents, for example, mistakenly considering a global
74341f3c11Sergawy /// variable and a spec constant as duplicates because their descriptor set +
75f88fab50SKazuaki Ishizaki /// binding and spec_id, respectively, happen to hash to the same value.
computeHash(SymbolOpInterface symbolOp)76341f3c11Sergawy static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
7723326b9fSLei Zhang   auto range =
7823326b9fSLei Zhang       llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
790c7890c8SRiver Riddle         return attr.getName() != SymbolTable::getSymbolAttrName();
8023326b9fSLei Zhang       });
81341f3c11Sergawy 
8223326b9fSLei Zhang   return llvm::hash_combine(
8323326b9fSLei Zhang       symbolOp->getName(),
8423326b9fSLei Zhang       llvm::hash_combine_range(range.begin(), range.end()));
85341f3c11Sergawy }
86341f3c11Sergawy 
8790a8260cSergawy namespace mlir {
8890a8260cSergawy namespace spirv {
8990a8260cSergawy 
combine(ArrayRef<spirv::ModuleOp> inputModules,OpBuilder & combinedModuleBuilder,SymbolRenameListener symRenameListener)9023326b9fSLei Zhang OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
9190a8260cSergawy                                      OpBuilder &combinedModuleBuilder,
9223326b9fSLei Zhang                                      SymbolRenameListener symRenameListener) {
9323326b9fSLei Zhang   if (inputModules.empty())
9490a8260cSergawy     return nullptr;
9590a8260cSergawy 
9623326b9fSLei Zhang   spirv::ModuleOp firstModule = inputModules.front();
9790a1632dSJakub Kuderski   auto addressingModel = firstModule.getAddressingModel();
9890a1632dSJakub Kuderski   auto memoryModel = firstModule.getMemoryModel();
9990a1632dSJakub Kuderski   auto vceTriple = firstModule.getVceTriple();
10023326b9fSLei Zhang 
10123326b9fSLei Zhang   // First check whether there are conflicts between addressing/memory model.
10223326b9fSLei Zhang   // Return early if so.
10323326b9fSLei Zhang   for (auto module : inputModules) {
10490a1632dSJakub Kuderski     if (module.getAddressingModel() != addressingModel ||
10590a1632dSJakub Kuderski         module.getMemoryModel() != memoryModel ||
10690a1632dSJakub Kuderski         module.getVceTriple() != vceTriple) {
10723326b9fSLei Zhang       module.emitError("input modules differ in addressing model, memory "
10823326b9fSLei Zhang                        "model, and/or VCE triple");
10923326b9fSLei Zhang       return nullptr;
11023326b9fSLei Zhang     }
11123326b9fSLei Zhang   }
11290a8260cSergawy 
11390a8260cSergawy   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
11423326b9fSLei Zhang       firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
11556f60a1cSLei Zhang   combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
11690a8260cSergawy 
11790a8260cSergawy   // In some cases, a symbol in the (current state of the) combined module is
11823326b9fSLei Zhang   // renamed in order to enable the conflicting symbol in the input module
11990a8260cSergawy   // being merged. For example, if the conflict is between a global variable in
12090a8260cSergawy   // the current combined module and a function in the input module, the global
121f88fab50SKazuaki Ishizaki   // variable is renamed. In order to notify listeners of the symbol updates in
12290a8260cSergawy   // such cases, we need to keep track of the module from which the renamed
12390a8260cSergawy   // symbol in the combined module originated. This map keeps such information.
12423326b9fSLei Zhang   llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
12590a8260cSergawy 
12623326b9fSLei Zhang   unsigned lastUsedID = 0;
12790a8260cSergawy 
12823326b9fSLei Zhang   for (auto inputModule : inputModules) {
1292da3facdSMehdi Amini     OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
13090a8260cSergawy 
13190a8260cSergawy     // In the combined module, rename all symbols that conflict with symbols
132f88fab50SKazuaki Ishizaki     // from the current input module. This renaming applies to all ops except
1335ab6ef75SJakub Kuderski     // for spirv.funcs. This way, if the conflicting op in the input module is
1345ab6ef75SJakub Kuderski     // non-spirv.func, we rename that symbol instead and maintain the spirv.func
1355ab6ef75SJakub Kuderski     // in the combined module name as it is.
13656f60a1cSLei Zhang     for (auto &op : *combinedModule.getBody()) {
13723326b9fSLei Zhang       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
13823326b9fSLei Zhang       if (!symbolOp)
13923326b9fSLei Zhang         continue;
14023326b9fSLei Zhang 
14190a8260cSergawy       StringRef oldSymName = symbolOp.getName();
14290a8260cSergawy 
14390a8260cSergawy       if (!isa<FuncOp>(op) &&
1442da3facdSMehdi Amini           failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
14590a8260cSergawy                                         lastUsedID)))
14690a8260cSergawy         return nullptr;
14790a8260cSergawy 
14890a8260cSergawy       StringRef newSymName = symbolOp.getName();
14990a8260cSergawy 
15090a8260cSergawy       if (symRenameListener && oldSymName != newSymName) {
15123326b9fSLei Zhang         spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
15290a8260cSergawy 
15390a8260cSergawy         if (!originalModule) {
15423326b9fSLei Zhang           inputModule.emitError(
15523326b9fSLei Zhang               "unable to find original spirv::ModuleOp for symbol ")
15690a8260cSergawy               << oldSymName;
15790a8260cSergawy           return nullptr;
15890a8260cSergawy         }
15990a8260cSergawy 
16090a8260cSergawy         symRenameListener(originalModule, oldSymName, newSymName);
16190a8260cSergawy 
16290a8260cSergawy         // Since the symbol name is updated, there is no need to maintain the
163f88fab50SKazuaki Ishizaki         // entry that associates the old symbol name with the original module.
16490a8260cSergawy         symNameToModuleMap.erase(oldSymName);
16590a8260cSergawy         // Instead, add a new entry to map the new symbol name to the original
16690a8260cSergawy         // module in case it gets renamed again later.
16790a8260cSergawy         symNameToModuleMap[newSymName] = originalModule;
16890a8260cSergawy       }
16990a8260cSergawy     }
17090a8260cSergawy 
17190a8260cSergawy     // In the current input module, rename all symbols that conflict with
1725ab6ef75SJakub Kuderski     // symbols from the combined module. This includes renaming spirv.funcs.
1732da3facdSMehdi Amini     for (auto &op : *moduleClone->getBody()) {
17423326b9fSLei Zhang       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
17523326b9fSLei Zhang       if (!symbolOp)
17623326b9fSLei Zhang         continue;
17723326b9fSLei Zhang 
17890a8260cSergawy       StringRef oldSymName = symbolOp.getName();
17990a8260cSergawy 
1802da3facdSMehdi Amini       if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
18190a8260cSergawy                                         lastUsedID)))
18290a8260cSergawy         return nullptr;
18390a8260cSergawy 
18490a8260cSergawy       StringRef newSymName = symbolOp.getName();
18590a8260cSergawy 
18623326b9fSLei Zhang       if (symRenameListener) {
18723326b9fSLei Zhang         if (oldSymName != newSymName)
18823326b9fSLei Zhang           symRenameListener(inputModule, oldSymName, newSymName);
18990a8260cSergawy 
19090a8260cSergawy         // Insert the module associated with the symbol name.
19190a8260cSergawy         auto emplaceResult =
19223326b9fSLei Zhang             symNameToModuleMap.try_emplace(newSymName, inputModule);
19390a8260cSergawy 
19490a8260cSergawy         // If an entry with the same symbol name is already present, this must
19590a8260cSergawy         // be a problem with the implementation, specially clean-up of the map
19690a8260cSergawy         // while iterating over the combined module above.
19790a8260cSergawy         if (!emplaceResult.second) {
19823326b9fSLei Zhang           inputModule.emitError("did not expect to find an entry for symbol ")
19990a8260cSergawy               << symbolOp.getName();
20090a8260cSergawy           return nullptr;
20190a8260cSergawy         }
20290a8260cSergawy       }
20390a8260cSergawy     }
20490a8260cSergawy 
20590a8260cSergawy     // Clone all the module's ops to the combined module.
2062da3facdSMehdi Amini     for (auto &op : *moduleClone->getBody())
20790a8260cSergawy       combinedModuleBuilder.insert(op.clone());
20890a8260cSergawy   }
20990a8260cSergawy 
210341f3c11Sergawy   // Deduplicate identical global variables, spec constants, and functions.
211341f3c11Sergawy   DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
212341f3c11Sergawy   SmallVector<SymbolOpInterface, 0> eraseList;
213341f3c11Sergawy 
21456f60a1cSLei Zhang   for (auto &op : *combinedModule.getBody()) {
215341f3c11Sergawy     SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
216341f3c11Sergawy     if (!symbolOp)
217341f3c11Sergawy       continue;
218341f3c11Sergawy 
21923326b9fSLei Zhang     // Do not support ops with operands or results.
22023326b9fSLei Zhang     // Global variables, spec constants, and functions won't have
22123326b9fSLei Zhang     // operands/results, but just for safety here.
22223326b9fSLei Zhang     if (op.getNumOperands() != 0 || op.getNumResults() != 0)
223341f3c11Sergawy       continue;
224341f3c11Sergawy 
22523326b9fSLei Zhang     // Deduplicating functions are not supported yet.
22623326b9fSLei Zhang     if (isa<FuncOp>(op))
227341f3c11Sergawy       continue;
228341f3c11Sergawy 
22923326b9fSLei Zhang     auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
23023326b9fSLei Zhang     if (result.second)
23123326b9fSLei Zhang       continue;
23223326b9fSLei Zhang 
23323326b9fSLei Zhang     SymbolOpInterface replacementSymOp = result.first->second;
23423326b9fSLei Zhang 
235341f3c11Sergawy     if (failed(SymbolTable::replaceAllSymbolUses(
23641d4aa7dSChris Lattner             symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
237341f3c11Sergawy       symbolOp.emitError("unable to update all symbol uses for ")
238341f3c11Sergawy           << symbolOp.getName() << " to " << replacementSymOp.getName();
239341f3c11Sergawy       return nullptr;
240341f3c11Sergawy     }
241341f3c11Sergawy 
242341f3c11Sergawy     eraseList.push_back(symbolOp);
243341f3c11Sergawy   }
244341f3c11Sergawy 
245341f3c11Sergawy   for (auto symbolOp : eraseList)
246341f3c11Sergawy     symbolOp.erase();
247341f3c11Sergawy 
24890a8260cSergawy   return combinedModule;
24990a8260cSergawy }
25090a8260cSergawy 
25190a8260cSergawy } // namespace spirv
25290a8260cSergawy } // namespace mlir
253