1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SPIR-V module combiner library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/SymbolTable.h" 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/Hashing.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/StringMap.h" 26 27 using namespace mlir; 28 29 static constexpr unsigned maxFreeID = 1 << 20; 30 31 /// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric 32 /// suffix in `lastUsedID`. 33 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, 34 spirv::ModuleOp module) { 35 SmallString<64> newSymName(oldSymName); 36 newSymName.push_back('_'); 37 38 while (lastUsedID < maxFreeID) { 39 std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); 40 41 if (!SymbolTable::lookupSymbolIn(module, possible)) { 42 newSymName += llvm::utostr(lastUsedID); 43 break; 44 } 45 } 46 47 return newSymName; 48 } 49 50 /// Checks if a symbol with the same name as `op` already exists in `source`. 51 /// If so, renames `op` and updates all its references in `target`. 52 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, 53 spirv::ModuleOp target, 54 spirv::ModuleOp source, 55 unsigned &lastUsedID) { 56 if (!SymbolTable::lookupSymbolIn(source, op.getName())) 57 return success(); 58 59 StringRef oldSymName = op.getName(); 60 SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); 61 62 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) 63 return op.emitError("unable to update all symbol uses for ") 64 << oldSymName << " to " << newSymName; 65 66 SymbolTable::setSymbolName(op, newSymName); 67 return success(); 68 } 69 70 /// Computes a hash code to represent `symbolOp` based on all its attributes 71 /// except for the symbol name. 72 /// 73 /// Note: We use the operation's name (not the symbol name) as part of the hash 74 /// computation. This prevents, for example, mistakenly considering a global 75 /// variable and a spec constant as duplicates because their descriptor set + 76 /// binding and spec_id, respectively, happen to hash to the same value. 77 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { 78 auto range = 79 llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) { 80 return attr.first != SymbolTable::getSymbolAttrName(); 81 }); 82 83 return llvm::hash_combine( 84 symbolOp->getName(), 85 llvm::hash_combine_range(range.begin(), range.end())); 86 } 87 88 namespace mlir { 89 namespace spirv { 90 91 OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules, 92 OpBuilder &combinedModuleBuilder, 93 SymbolRenameListener symRenameListener) { 94 if (inputModules.empty()) 95 return nullptr; 96 97 spirv::ModuleOp firstModule = inputModules.front(); 98 auto addressingModel = firstModule.addressing_model(); 99 auto memoryModel = firstModule.memory_model(); 100 auto vceTriple = firstModule.vce_triple(); 101 102 // First check whether there are conflicts between addressing/memory model. 103 // Return early if so. 104 for (auto module : inputModules) { 105 if (module.addressing_model() != addressingModel || 106 module.memory_model() != memoryModel || 107 module.vce_triple() != vceTriple) { 108 module.emitError("input modules differ in addressing model, memory " 109 "model, and/or VCE triple"); 110 return nullptr; 111 } 112 } 113 114 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( 115 firstModule.getLoc(), addressingModel, memoryModel, vceTriple); 116 combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody()); 117 118 // In some cases, a symbol in the (current state of the) combined module is 119 // renamed in order to enable the conflicting symbol in the input module 120 // being merged. For example, if the conflict is between a global variable in 121 // the current combined module and a function in the input module, the global 122 // variable is renamed. In order to notify listeners of the symbol updates in 123 // such cases, we need to keep track of the module from which the renamed 124 // symbol in the combined module originated. This map keeps such information. 125 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap; 126 127 unsigned lastUsedID = 0; 128 129 for (auto inputModule : inputModules) { 130 spirv::ModuleOp moduleClone = inputModule.clone(); 131 132 // In the combined module, rename all symbols that conflict with symbols 133 // from the current input module. This renaming applies to all ops except 134 // for spv.funcs. This way, if the conflicting op in the input module is 135 // non-spv.func, we rename that symbol instead and maintain the spv.func in 136 // the combined module name as it is. 137 for (auto &op : *combinedModule.getBody()) { 138 auto symbolOp = dyn_cast<SymbolOpInterface>(op); 139 if (!symbolOp) 140 continue; 141 142 StringRef oldSymName = symbolOp.getName(); 143 144 if (!isa<FuncOp>(op) && 145 failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, 146 lastUsedID))) 147 return nullptr; 148 149 StringRef newSymName = symbolOp.getName(); 150 151 if (symRenameListener && oldSymName != newSymName) { 152 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName); 153 154 if (!originalModule) { 155 inputModule.emitError( 156 "unable to find original spirv::ModuleOp for symbol ") 157 << oldSymName; 158 return nullptr; 159 } 160 161 symRenameListener(originalModule, oldSymName, newSymName); 162 163 // Since the symbol name is updated, there is no need to maintain the 164 // entry that associates the old symbol name with the original module. 165 symNameToModuleMap.erase(oldSymName); 166 // Instead, add a new entry to map the new symbol name to the original 167 // module in case it gets renamed again later. 168 symNameToModuleMap[newSymName] = originalModule; 169 } 170 } 171 172 // In the current input module, rename all symbols that conflict with 173 // symbols from the combined module. This includes renaming spv.funcs. 174 for (auto &op : *moduleClone.getBody()) { 175 auto symbolOp = dyn_cast<SymbolOpInterface>(op); 176 if (!symbolOp) 177 continue; 178 179 StringRef oldSymName = symbolOp.getName(); 180 181 if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, 182 lastUsedID))) 183 return nullptr; 184 185 StringRef newSymName = symbolOp.getName(); 186 187 if (symRenameListener) { 188 if (oldSymName != newSymName) 189 symRenameListener(inputModule, oldSymName, newSymName); 190 191 // Insert the module associated with the symbol name. 192 auto emplaceResult = 193 symNameToModuleMap.try_emplace(newSymName, inputModule); 194 195 // If an entry with the same symbol name is already present, this must 196 // be a problem with the implementation, specially clean-up of the map 197 // while iterating over the combined module above. 198 if (!emplaceResult.second) { 199 inputModule.emitError("did not expect to find an entry for symbol ") 200 << symbolOp.getName(); 201 return nullptr; 202 } 203 } 204 } 205 206 // Clone all the module's ops to the combined module. 207 for (auto &op : *moduleClone.getBody()) 208 combinedModuleBuilder.insert(op.clone()); 209 } 210 211 // Deduplicate identical global variables, spec constants, and functions. 212 DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp; 213 SmallVector<SymbolOpInterface, 0> eraseList; 214 215 for (auto &op : *combinedModule.getBody()) { 216 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op); 217 if (!symbolOp) 218 continue; 219 220 // Do not support ops with operands or results. 221 // Global variables, spec constants, and functions won't have 222 // operands/results, but just for safety here. 223 if (op.getNumOperands() != 0 || op.getNumResults() != 0) 224 continue; 225 226 // Deduplicating functions are not supported yet. 227 if (isa<FuncOp>(op)) 228 continue; 229 230 auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp); 231 if (result.second) 232 continue; 233 234 SymbolOpInterface replacementSymOp = result.first->second; 235 236 if (failed(SymbolTable::replaceAllSymbolUses( 237 symbolOp, replacementSymOp.getName(), combinedModule))) { 238 symbolOp.emitError("unable to update all symbol uses for ") 239 << symbolOp.getName() << " to " << replacementSymOp.getName(); 240 return nullptr; 241 } 242 243 eraseList.push_back(symbolOp); 244 } 245 246 for (auto symbolOp : eraseList) 247 symbolOp.erase(); 248 249 return combinedModule; 250 } 251 252 } // namespace spirv 253 } // namespace mlir 254