1 //===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MLProgram/Transforms/Passes.h" 10 11 #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 12 #include "mlir/Dialect/MLProgram/Transforms/Passes.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 17 namespace mlir { 18 namespace ml_program { 19 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS 20 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc" 21 22 namespace { 23 24 class MLProgramPipelineGlobals 25 : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> { 26 public: 27 void runOnOperation() override; 28 29 private: 30 LogicalResult buildGlobalMap(ModuleOp op); 31 32 void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, 33 llvm::DenseSet<SymbolRefAttr> &symbolStore); 34 35 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap; 36 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap; 37 }; 38 39 // Traverses upwards searchign for the operation mapped by the symbol. 40 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) { 41 for (auto *op = baseOp; op; op = op->getParentOp()) { 42 auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol); 43 if (lookup) 44 return lookup; 45 } 46 return nullptr; 47 } 48 49 // Builds map from a symbol to MLProgram global symbols loaded or stored 50 // during processing. 51 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) { 52 llvm::DenseMap<SymbolRefAttr, Operation *> callableMap; 53 auto res = module->walk([&](Operation *op) { 54 if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) { 55 auto callable = caller.getCallableForCallee(); 56 // For now we do not know how to handle Value based tracing, so fail. 57 if (mlir::isa<Value>(callable)) { 58 return WalkResult::interrupt(); 59 } 60 61 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable); 62 auto *func = getFromSymbol(op, symbol); 63 callableMap[symbol] = func; 64 } 65 return WalkResult::advance(); 66 }); 67 68 if (res.wasInterrupted()) { 69 return failure(); 70 } 71 72 // First grab all symbols loaded or stored by each function. This 73 // will not handle calls initially. 74 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols; 75 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols; 76 for (auto callable : callableMap) { 77 llvm::DenseSet<SymbolRefAttr> loadSymbols; 78 llvm::DenseSet<SymbolRefAttr> storeSymbols; 79 80 callable.getSecond()->walk( 81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); }); 82 83 callable.getSecond()->walk( 84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); }); 85 86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols); 87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols); 88 } 89 90 // For each callable function we find each global loaded/stored within the 91 // function or a nested called function. This includes recursion checking to 92 // avoid infinitely recursing. 93 for (auto callable : callableMap) { 94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first); 95 llvm::SmallVector<SymbolRefAttr> work = {thisSymbol}; 96 llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol}; 97 llvm::DenseSet<SymbolRefAttr> loadSymbols; 98 llvm::DenseSet<SymbolRefAttr> storeSymbols; 99 100 for (size_t i = 0; i < work.size(); ++i) { 101 callableMap[work[i]]->walk([&](CallOpInterface call) { 102 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee()); 103 if (visited.insert(symbol).second) 104 work.push_back(symbol); 105 }); 106 107 for (auto load : opLoadSymbols[work[i]]) 108 loadSymbols.insert(load); 109 110 for (auto store : opStoreSymbols[work[i]]) 111 storeSymbols.insert(store); 112 } 113 114 loadSymbolsMap[thisSymbol] = std::move(loadSymbols); 115 storeSymbolsMap[thisSymbol] = std::move(storeSymbols); 116 } 117 118 return success(); 119 } 120 121 // Process each operation in the block deleting unneeded loads / stores, 122 // recursing on subblocks and checking function calls. 123 void MLProgramPipelineGlobals::processBlock( 124 Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, 125 llvm::DenseSet<SymbolRefAttr> &symbolStore) { 126 127 llvm::DenseMap<SymbolRefAttr, Value> previousLoads; 128 llvm::DenseMap<SymbolRefAttr, Operation *> previousStores; 129 llvm::SmallVector<Operation *> toDelete; 130 for (auto &op : block) { 131 // If this is a global load, remap to a previous value if known 132 // and delete this load. Remember that this value is the currently 133 // known load. 134 if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) { 135 auto ref = load.getGlobal(); 136 symbolLoad.insert(ref); 137 if (previousLoads.contains(ref)) { 138 toDelete.push_back(&op); 139 load.getResult().replaceAllUsesWith(previousLoads[ref]); 140 } else { 141 previousLoads[ref] = load.getResult(); 142 } 143 continue; 144 } 145 146 // Delete a previous store if it exists and is not needed, update 147 // the most recent known value for this global ref. 148 if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) { 149 auto ref = store.getGlobal(); 150 symbolStore.insert(ref); 151 auto it = previousStores.find(ref); 152 if (it != previousStores.end()) { 153 toDelete.push_back(it->getSecond()); 154 } 155 156 previousLoads[ref] = store.getValue(); 157 previousStores[ref] = &op; 158 continue; 159 } 160 161 // If a function is called, clear known values for loads/stores used by 162 // the function or its sub-functions. 163 if (auto call = mlir::dyn_cast<CallOpInterface>(op)) { 164 auto loadSymbols = 165 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; 166 auto storeSymbols = 167 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; 168 169 for (auto sym : loadSymbols) { 170 previousStores.erase(sym); 171 } 172 173 for (auto sym : storeSymbols) { 174 previousLoads.erase(sym); 175 previousStores.erase(sym); 176 } 177 continue; 178 } 179 180 // If the op has sub-regions, recurse inside. We make no guarantees whether 181 // the recursion occurs. 182 llvm::DenseSet<SymbolRefAttr> opSymbolLoad; 183 llvm::DenseSet<SymbolRefAttr> opSymbolStore; 184 for (auto ®ion : op.getRegions()) { 185 for (auto &block : region) { 186 processBlock(block, opSymbolLoad, opSymbolStore); 187 } 188 } 189 190 // Update current state from the subblock. 191 for (auto change : opSymbolLoad) { 192 symbolLoad.insert(change); 193 previousStores.erase(change); 194 } 195 196 for (auto change : opSymbolStore) { 197 symbolStore.insert(change); 198 previousLoads.erase(change); 199 previousStores.erase(change); 200 } 201 } 202 203 for (auto *op : toDelete) { 204 op->erase(); 205 } 206 } 207 208 void MLProgramPipelineGlobals::runOnOperation() { 209 auto targetOp = getOperation(); 210 if (failed(buildGlobalMap(targetOp))) { 211 return; 212 } 213 214 for (auto &funcOp : *targetOp.getBody()) { 215 for (auto ®ion : funcOp.getRegions()) { 216 for (auto &block : region.getBlocks()) { 217 llvm::DenseSet<SymbolRefAttr> symbolsLoaded; 218 llvm::DenseSet<SymbolRefAttr> symbolsStored; 219 processBlock(block, symbolsLoaded, symbolsStored); 220 } 221 } 222 } 223 } 224 225 } // namespace 226 227 std::unique_ptr<OperationPass<mlir::ModuleOp>> 228 createMLProgramPipelineGlobalsPass() { 229 return std::make_unique<MLProgramPipelineGlobals>(); 230 } 231 232 } // namespace ml_program 233 } // namespace mlir 234