1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SPIR-V module combiner library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/SymbolTable.h" 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/Hashing.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/StringMap.h" 26 27 using namespace mlir; 28 29 static constexpr unsigned maxFreeID = 1 << 20; 30 31 /// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric 32 /// suffix in `lastUsedID`. 33 static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, 34 spirv::ModuleOp module) { 35 SmallString<64> newSymName(oldSymName); 36 newSymName.push_back('_'); 37 38 MLIRContext *ctx = module->getContext(); 39 40 while (lastUsedID < maxFreeID) { 41 auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID)); 42 if (!SymbolTable::lookupSymbolIn(module, possible)) 43 return possible; 44 } 45 46 return StringAttr::get(ctx, newSymName); 47 } 48 49 /// Checks if a symbol with the same name as `op` already exists in `source`. 50 /// If so, renames `op` and updates all its references in `target`. 51 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, 52 spirv::ModuleOp target, 53 spirv::ModuleOp source, 54 unsigned &lastUsedID) { 55 if (!SymbolTable::lookupSymbolIn(source, op.getName())) 56 return success(); 57 58 StringRef oldSymName = op.getName(); 59 StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target); 60 61 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) 62 return op.emitError("unable to update all symbol uses for ") 63 << oldSymName << " to " << newSymName; 64 65 SymbolTable::setSymbolName(op, newSymName); 66 return success(); 67 } 68 69 /// Computes a hash code to represent `symbolOp` based on all its attributes 70 /// except for the symbol name. 71 /// 72 /// Note: We use the operation's name (not the symbol name) as part of the hash 73 /// computation. This prevents, for example, mistakenly considering a global 74 /// variable and a spec constant as duplicates because their descriptor set + 75 /// binding and spec_id, respectively, happen to hash to the same value. 76 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { 77 auto range = 78 llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) { 79 return attr.first != SymbolTable::getSymbolAttrName(); 80 }); 81 82 return llvm::hash_combine( 83 symbolOp->getName(), 84 llvm::hash_combine_range(range.begin(), range.end())); 85 } 86 87 namespace mlir { 88 namespace spirv { 89 90 OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules, 91 OpBuilder &combinedModuleBuilder, 92 SymbolRenameListener symRenameListener) { 93 if (inputModules.empty()) 94 return nullptr; 95 96 spirv::ModuleOp firstModule = inputModules.front(); 97 auto addressingModel = firstModule.addressing_model(); 98 auto memoryModel = firstModule.memory_model(); 99 auto vceTriple = firstModule.vce_triple(); 100 101 // First check whether there are conflicts between addressing/memory model. 102 // Return early if so. 103 for (auto module : inputModules) { 104 if (module.addressing_model() != addressingModel || 105 module.memory_model() != memoryModel || 106 module.vce_triple() != vceTriple) { 107 module.emitError("input modules differ in addressing model, memory " 108 "model, and/or VCE triple"); 109 return nullptr; 110 } 111 } 112 113 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( 114 firstModule.getLoc(), addressingModel, memoryModel, vceTriple); 115 combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody()); 116 117 // In some cases, a symbol in the (current state of the) combined module is 118 // renamed in order to enable the conflicting symbol in the input module 119 // being merged. For example, if the conflict is between a global variable in 120 // the current combined module and a function in the input module, the global 121 // variable is renamed. In order to notify listeners of the symbol updates in 122 // such cases, we need to keep track of the module from which the renamed 123 // symbol in the combined module originated. This map keeps such information. 124 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap; 125 126 unsigned lastUsedID = 0; 127 128 for (auto inputModule : inputModules) { 129 OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone(); 130 131 // In the combined module, rename all symbols that conflict with symbols 132 // from the current input module. This renaming applies to all ops except 133 // for spv.funcs. This way, if the conflicting op in the input module is 134 // non-spv.func, we rename that symbol instead and maintain the spv.func in 135 // the combined module name as it is. 136 for (auto &op : *combinedModule.getBody()) { 137 auto symbolOp = dyn_cast<SymbolOpInterface>(op); 138 if (!symbolOp) 139 continue; 140 141 StringRef oldSymName = symbolOp.getName(); 142 143 if (!isa<FuncOp>(op) && 144 failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone, 145 lastUsedID))) 146 return nullptr; 147 148 StringRef newSymName = symbolOp.getName(); 149 150 if (symRenameListener && oldSymName != newSymName) { 151 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName); 152 153 if (!originalModule) { 154 inputModule.emitError( 155 "unable to find original spirv::ModuleOp for symbol ") 156 << oldSymName; 157 return nullptr; 158 } 159 160 symRenameListener(originalModule, oldSymName, newSymName); 161 162 // Since the symbol name is updated, there is no need to maintain the 163 // entry that associates the old symbol name with the original module. 164 symNameToModuleMap.erase(oldSymName); 165 // Instead, add a new entry to map the new symbol name to the original 166 // module in case it gets renamed again later. 167 symNameToModuleMap[newSymName] = originalModule; 168 } 169 } 170 171 // In the current input module, rename all symbols that conflict with 172 // symbols from the combined module. This includes renaming spv.funcs. 173 for (auto &op : *moduleClone->getBody()) { 174 auto symbolOp = dyn_cast<SymbolOpInterface>(op); 175 if (!symbolOp) 176 continue; 177 178 StringRef oldSymName = symbolOp.getName(); 179 180 if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule, 181 lastUsedID))) 182 return nullptr; 183 184 StringRef newSymName = symbolOp.getName(); 185 186 if (symRenameListener) { 187 if (oldSymName != newSymName) 188 symRenameListener(inputModule, oldSymName, newSymName); 189 190 // Insert the module associated with the symbol name. 191 auto emplaceResult = 192 symNameToModuleMap.try_emplace(newSymName, inputModule); 193 194 // If an entry with the same symbol name is already present, this must 195 // be a problem with the implementation, specially clean-up of the map 196 // while iterating over the combined module above. 197 if (!emplaceResult.second) { 198 inputModule.emitError("did not expect to find an entry for symbol ") 199 << symbolOp.getName(); 200 return nullptr; 201 } 202 } 203 } 204 205 // Clone all the module's ops to the combined module. 206 for (auto &op : *moduleClone->getBody()) 207 combinedModuleBuilder.insert(op.clone()); 208 } 209 210 // Deduplicate identical global variables, spec constants, and functions. 211 DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp; 212 SmallVector<SymbolOpInterface, 0> eraseList; 213 214 for (auto &op : *combinedModule.getBody()) { 215 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op); 216 if (!symbolOp) 217 continue; 218 219 // Do not support ops with operands or results. 220 // Global variables, spec constants, and functions won't have 221 // operands/results, but just for safety here. 222 if (op.getNumOperands() != 0 || op.getNumResults() != 0) 223 continue; 224 225 // Deduplicating functions are not supported yet. 226 if (isa<FuncOp>(op)) 227 continue; 228 229 auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp); 230 if (result.second) 231 continue; 232 233 SymbolOpInterface replacementSymOp = result.first->second; 234 235 if (failed(SymbolTable::replaceAllSymbolUses( 236 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) { 237 symbolOp.emitError("unable to update all symbol uses for ") 238 << symbolOp.getName() << " to " << replacementSymOp.getName(); 239 return nullptr; 240 } 241 242 eraseList.push_back(symbolOp); 243 } 244 245 for (auto symbolOp : eraseList) 246 symbolOp.erase(); 247 248 return combinedModule; 249 } 250 251 } // namespace spirv 252 } // namespace mlir 253