1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the the SPIR-V module combiner library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/ModuleCombiner.h" 14 15 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/SymbolTable.h" 19 #include "llvm/ADT/ArrayRef.h" 20 #include "llvm/ADT/Hashing.h" 21 #include "llvm/ADT/StringExtras.h" 22 23 using namespace mlir; 24 25 static constexpr unsigned maxFreeID = 1 << 20; 26 27 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, 28 spirv::ModuleOp combinedModule) { 29 SmallString<64> newSymName(oldSymName); 30 newSymName.push_back('_'); 31 32 while (lastUsedID < maxFreeID) { 33 std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); 34 35 if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) { 36 newSymName += llvm::utostr(lastUsedID); 37 break; 38 } 39 } 40 41 return newSymName; 42 } 43 44 /// Check if a symbol with the same name as op already exists in source. If so, 45 /// rename op and update all its references in target. 46 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, 47 spirv::ModuleOp target, 48 spirv::ModuleOp source, 49 unsigned &lastUsedID) { 50 if (!SymbolTable::lookupSymbolIn(source, op.getName())) 51 return success(); 52 53 StringRef oldSymName = op.getName(); 54 SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); 55 56 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) 57 return op.emitError("unable to update all symbol uses for ") 58 << oldSymName << " to " << newSymName; 59 60 SymbolTable::setSymbolName(op, newSymName); 61 return success(); 62 } 63 64 template <typename KeyTy, typename SymbolOpTy> 65 static SymbolOpTy 66 emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp, 67 DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) { 68 auto result = deduplicationMap.try_emplace(key, symbolOp); 69 70 if (result.second) 71 return SymbolOpTy(); 72 73 return result.first->second; 74 } 75 76 /// Computes a hash code to represent the argument SymbolOpInterface based on 77 /// all the Op's attributes except for the symbol name. 78 /// 79 /// \return the hash code computed from the Op's attributes as described above. 80 /// 81 /// Note: We use the operation's name (not the symbol name) as part of the hash 82 /// computation. This prevents, for example, mistakenly considering a global 83 /// variable and a spec constant as duplicates because their descriptor set + 84 /// binding and spec_id, repectively, happen to hash to the same value. 85 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { 86 llvm::hash_code hashCode(0); 87 hashCode = llvm::hash_combine(symbolOp->getName()); 88 89 for (auto attr : symbolOp->getAttrs()) { 90 if (attr.first == SymbolTable::getSymbolAttrName()) 91 continue; 92 hashCode = llvm::hash_combine(hashCode, attr); 93 } 94 95 return hashCode; 96 } 97 98 /// Computes a hash code from the argument Block. 99 llvm::hash_code computeHash(Block *block) { 100 // TODO: Consider extracting BlockEquivalenceData into a common header and 101 // re-using it here. 102 llvm::hash_code hash(0); 103 104 for (Operation &op : *block) { 105 // TODO: Properly handle operations with regions. 106 if (op.getNumRegions() > 0) 107 return 0; 108 109 hash = llvm::hash_combine( 110 hash, OperationEquivalence::computeHash( 111 &op, OperationEquivalence::Flags::IgnoreOperands)); 112 } 113 114 return hash; 115 } 116 117 namespace mlir { 118 namespace spirv { 119 120 // TODO Properly test symbol rename listener mechanism. 121 122 OwningSPIRVModuleRef 123 combine(llvm::MutableArrayRef<spirv::ModuleOp> modules, 124 OpBuilder &combinedModuleBuilder, 125 llvm::function_ref<void(ModuleOp, StringRef, StringRef)> 126 symRenameListener) { 127 unsigned lastUsedID = 0; 128 129 if (modules.empty()) 130 return nullptr; 131 132 auto addressingModel = modules[0].addressing_model(); 133 auto memoryModel = modules[0].memory_model(); 134 135 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( 136 modules[0].getLoc(), addressingModel, memoryModel); 137 combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody()); 138 139 // In some cases, a symbol in the (current state of the) combined module is 140 // renamed in order to maintain the conflicting symbol in the input module 141 // being merged. For example, if the conflict is between a global variable in 142 // the current combined module and a function in the input module, the global 143 // varaible is renamed. In order to notify listeners of the symbol updates in 144 // such cases, we need to keep track of the module from which the renamed 145 // symbol in the combined module originated. This map keeps such information. 146 DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap; 147 148 for (auto module : modules) { 149 if (module.addressing_model() != addressingModel || 150 module.memory_model() != memoryModel) { 151 module.emitError( 152 "input modules differ in addressing model and/or memory model"); 153 return nullptr; 154 } 155 156 spirv::ModuleOp moduleClone = module.clone(); 157 158 // In the combined module, rename all symbols that conflict with symbols 159 // from the current input module. This renmaing applies to all ops except 160 // for spv.funcs. This way, if the conflicting op in the input module is 161 // non-spv.func, we rename that symbol instead and maintain the spv.func in 162 // the combined module name as it is. 163 for (auto &op : combinedModule.getBlock().without_terminator()) { 164 if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { 165 StringRef oldSymName = symbolOp.getName(); 166 167 if (!isa<FuncOp>(op) && 168 failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, 169 lastUsedID))) 170 return nullptr; 171 172 StringRef newSymName = symbolOp.getName(); 173 174 if (symRenameListener && oldSymName != newSymName) { 175 spirv::ModuleOp originalModule = 176 symNameToModuleMap.lookup(oldSymName); 177 178 if (!originalModule) { 179 module.emitError("unable to find original ModuleOp for symbol ") 180 << oldSymName; 181 return nullptr; 182 } 183 184 symRenameListener(originalModule, oldSymName, newSymName); 185 186 // Since the symbol name is updated, there is no need to maintain the 187 // entry that assocaites the old symbol name with the original module. 188 symNameToModuleMap.erase(oldSymName); 189 // Instead, add a new entry to map the new symbol name to the original 190 // module in case it gets renamed again later. 191 symNameToModuleMap[newSymName] = originalModule; 192 } 193 } 194 } 195 196 // In the current input module, rename all symbols that conflict with 197 // symbols from the combined module. This includes renaming spv.funcs. 198 for (auto &op : moduleClone.getBlock().without_terminator()) { 199 if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { 200 StringRef oldSymName = symbolOp.getName(); 201 202 if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, 203 lastUsedID))) 204 return nullptr; 205 206 StringRef newSymName = symbolOp.getName(); 207 208 if (symRenameListener && oldSymName != newSymName) { 209 symRenameListener(module, oldSymName, newSymName); 210 211 // Insert the module associated with the symbol name. 212 auto emplaceResult = 213 symNameToModuleMap.try_emplace(symbolOp.getName(), module); 214 215 // If an entry with the same symbol name is already present, this must 216 // be a problem with the implementation, specially clean-up of the map 217 // while iterating over the combined module above. 218 if (!emplaceResult.second) { 219 module.emitError("did not expect to find an entry for symbol ") 220 << symbolOp.getName(); 221 return nullptr; 222 } 223 } 224 } 225 } 226 227 // Clone all the module's ops to the combined module. 228 for (auto &op : moduleClone.getBlock().without_terminator()) 229 combinedModuleBuilder.insert(op.clone()); 230 } 231 232 // Deduplicate identical global variables, spec constants, and functions. 233 DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp; 234 SmallVector<SymbolOpInterface, 0> eraseList; 235 236 for (auto &op : combinedModule.getBlock().without_terminator()) { 237 llvm::hash_code hashCode(0); 238 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op); 239 240 if (!symbolOp) 241 continue; 242 243 hashCode = computeHash(symbolOp); 244 245 // A 0 hash code means the op is not suitable for deduplication and should 246 // be skipped. An example of this is when a function has ops with regions 247 // which are not properly supported yet. 248 if (!hashCode) 249 continue; 250 251 if (auto funcOp = dyn_cast<FuncOp>(op)) 252 for (auto &blk : funcOp) 253 hashCode = llvm::hash_combine(hashCode, computeHash(&blk)); 254 255 SymbolOpInterface replacementSymOp = 256 emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp); 257 258 if (!replacementSymOp) 259 continue; 260 261 if (failed(SymbolTable::replaceAllSymbolUses( 262 symbolOp, replacementSymOp.getName(), combinedModule))) { 263 symbolOp.emitError("unable to update all symbol uses for ") 264 << symbolOp.getName() << " to " << replacementSymOp.getName(); 265 return nullptr; 266 } 267 268 eraseList.push_back(symbolOp); 269 } 270 271 for (auto symbolOp : eraseList) 272 symbolOp.erase(); 273 274 return combinedModule; 275 } 276 277 } // namespace spirv 278 } // namespace mlir 279