1cbd47504SRob Suderman //===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===// 2cbd47504SRob Suderman // 3cbd47504SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cbd47504SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5cbd47504SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cbd47504SRob Suderman // 7cbd47504SRob Suderman //===----------------------------------------------------------------------===// 8cbd47504SRob Suderman 9cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/Transforms/Passes.h" 10cbd47504SRob Suderman 11cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 12cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/Transforms/Passes.h" 13cbd47504SRob Suderman #include "mlir/IR/BuiltinOps.h" 14cbd47504SRob Suderman #include "mlir/Pass/Pass.h" 15cbd47504SRob Suderman #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16cbd47504SRob Suderman 17cbd47504SRob Suderman namespace mlir { 18cbd47504SRob Suderman namespace ml_program { 19cbd47504SRob Suderman #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS 20cbd47504SRob Suderman #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc" 21cbd47504SRob Suderman 22cbd47504SRob Suderman namespace { 23cbd47504SRob Suderman 24cbd47504SRob Suderman class MLProgramPipelineGlobals 25cbd47504SRob Suderman : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> { 26cbd47504SRob Suderman public: 27cbd47504SRob Suderman void runOnOperation() override; 28cbd47504SRob Suderman 29cbd47504SRob Suderman private: 30cbd47504SRob Suderman LogicalResult buildGlobalMap(ModuleOp op); 31cbd47504SRob Suderman 324e713357SMehdi Amini void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, 33cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> &symbolStore); 34cbd47504SRob Suderman 35cbd47504SRob Suderman llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap; 36cbd47504SRob Suderman llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap; 37cbd47504SRob Suderman }; 38cbd47504SRob Suderman 39cbd47504SRob Suderman // Traverses upwards searchign for the operation mapped by the symbol. 40cbd47504SRob Suderman static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) { 419cebb285SMehdi Amini for (auto *op = baseOp; op; op = op->getParentOp()) { 429cebb285SMehdi Amini auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol); 43cbd47504SRob Suderman if (lookup) 44cbd47504SRob Suderman return lookup; 45cbd47504SRob Suderman } 46cbd47504SRob Suderman return nullptr; 47cbd47504SRob Suderman } 48cbd47504SRob Suderman 49cbd47504SRob Suderman // Builds map from a symbol to MLProgram global symbols loaded or stored 50cbd47504SRob Suderman // during processing. 51cbd47504SRob Suderman LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) { 52cbd47504SRob Suderman llvm::DenseMap<SymbolRefAttr, Operation *> callableMap; 53cbd47504SRob Suderman auto res = module->walk([&](Operation *op) { 54cbd47504SRob Suderman if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) { 55cbd47504SRob Suderman auto callable = caller.getCallableForCallee(); 56cbd47504SRob Suderman // For now we do not know how to handle Value based tracing, so fail. 57cbd47504SRob Suderman if (mlir::isa<Value>(callable)) { 58cbd47504SRob Suderman return WalkResult::interrupt(); 59cbd47504SRob Suderman } 60cbd47504SRob Suderman 61cbd47504SRob Suderman auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable); 629cebb285SMehdi Amini auto *func = getFromSymbol(op, symbol); 63cbd47504SRob Suderman callableMap[symbol] = func; 64cbd47504SRob Suderman } 65cbd47504SRob Suderman return WalkResult::advance(); 66cbd47504SRob Suderman }); 67cbd47504SRob Suderman 68cbd47504SRob Suderman if (res.wasInterrupted()) { 69cbd47504SRob Suderman return failure(); 70cbd47504SRob Suderman } 71cbd47504SRob Suderman 72cbd47504SRob Suderman // First grab all symbols loaded or stored by each function. This 73cbd47504SRob Suderman // will not handle calls initially. 74cbd47504SRob Suderman llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols; 75cbd47504SRob Suderman llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols; 76cbd47504SRob Suderman for (auto callable : callableMap) { 77cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> loadSymbols; 78cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> storeSymbols; 79cbd47504SRob Suderman 80cbd47504SRob Suderman callable.getSecond()->walk( 81cbd47504SRob Suderman [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); }); 82cbd47504SRob Suderman 83cbd47504SRob Suderman callable.getSecond()->walk( 84cbd47504SRob Suderman [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); }); 85cbd47504SRob Suderman 86cbd47504SRob Suderman opLoadSymbols[callable.getFirst()] = std::move(loadSymbols); 87cbd47504SRob Suderman opStoreSymbols[callable.getFirst()] = std::move(storeSymbols); 88cbd47504SRob Suderman } 89cbd47504SRob Suderman 90cbd47504SRob Suderman // For each callable function we find each global loaded/stored within the 91cbd47504SRob Suderman // function or a nested called function. This includes recursion checking to 92cbd47504SRob Suderman // avoid infinitely recursing. 93cbd47504SRob Suderman for (auto callable : callableMap) { 94cbd47504SRob Suderman SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first); 95cbd47504SRob Suderman llvm::SmallVector<SymbolRefAttr> work = {thisSymbol}; 96cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol}; 97cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> loadSymbols; 98cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> storeSymbols; 99cbd47504SRob Suderman 100cbd47504SRob Suderman for (size_t i = 0; i < work.size(); ++i) { 101cbd47504SRob Suderman callableMap[work[i]]->walk([&](CallOpInterface call) { 102cbd47504SRob Suderman auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee()); 10371a39ecaSKazu Hirata if (visited.insert(symbol).second) 104cbd47504SRob Suderman work.push_back(symbol); 105cbd47504SRob Suderman }); 106cbd47504SRob Suderman 107cbd47504SRob Suderman for (auto load : opLoadSymbols[work[i]]) 108cbd47504SRob Suderman loadSymbols.insert(load); 109cbd47504SRob Suderman 110cbd47504SRob Suderman for (auto store : opStoreSymbols[work[i]]) 111cbd47504SRob Suderman storeSymbols.insert(store); 112cbd47504SRob Suderman } 113cbd47504SRob Suderman 114cbd47504SRob Suderman loadSymbolsMap[thisSymbol] = std::move(loadSymbols); 115cbd47504SRob Suderman storeSymbolsMap[thisSymbol] = std::move(storeSymbols); 116cbd47504SRob Suderman } 117cbd47504SRob Suderman 118cbd47504SRob Suderman return success(); 119cbd47504SRob Suderman } 120cbd47504SRob Suderman 121cbd47504SRob Suderman // Process each operation in the block deleting unneeded loads / stores, 122cbd47504SRob Suderman // recursing on subblocks and checking function calls. 1234e713357SMehdi Amini void MLProgramPipelineGlobals::processBlock( 124cbd47504SRob Suderman Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, 125cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> &symbolStore) { 126cbd47504SRob Suderman 127cbd47504SRob Suderman llvm::DenseMap<SymbolRefAttr, Value> previousLoads; 128cbd47504SRob Suderman llvm::DenseMap<SymbolRefAttr, Operation *> previousStores; 129cbd47504SRob Suderman llvm::SmallVector<Operation *> toDelete; 130cbd47504SRob Suderman for (auto &op : block) { 131cbd47504SRob Suderman // If this is a global load, remap to a previous value if known 132cbd47504SRob Suderman // and delete this load. Remember that this value is the currently 133cbd47504SRob Suderman // known load. 134cbd47504SRob Suderman if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) { 135cbd47504SRob Suderman auto ref = load.getGlobal(); 136cbd47504SRob Suderman symbolLoad.insert(ref); 137cbd47504SRob Suderman if (previousLoads.contains(ref)) { 138cbd47504SRob Suderman toDelete.push_back(&op); 139cbd47504SRob Suderman load.getResult().replaceAllUsesWith(previousLoads[ref]); 140cbd47504SRob Suderman } else { 141cbd47504SRob Suderman previousLoads[ref] = load.getResult(); 142cbd47504SRob Suderman } 143cbd47504SRob Suderman continue; 144cbd47504SRob Suderman } 145cbd47504SRob Suderman 146cbd47504SRob Suderman // Delete a previous store if it exists and is not needed, update 147cbd47504SRob Suderman // the most recent known value for this global ref. 148cbd47504SRob Suderman if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) { 149cbd47504SRob Suderman auto ref = store.getGlobal(); 150cbd47504SRob Suderman symbolStore.insert(ref); 151*0a20ab90SKazu Hirata auto it = previousStores.find(ref); 152*0a20ab90SKazu Hirata if (it != previousStores.end()) { 153*0a20ab90SKazu Hirata toDelete.push_back(it->getSecond()); 154cbd47504SRob Suderman } 155cbd47504SRob Suderman 156cbd47504SRob Suderman previousLoads[ref] = store.getValue(); 157cbd47504SRob Suderman previousStores[ref] = &op; 158cbd47504SRob Suderman continue; 159cbd47504SRob Suderman } 160cbd47504SRob Suderman 161cbd47504SRob Suderman // If a function is called, clear known values for loads/stores used by 162cbd47504SRob Suderman // the function or its sub-functions. 163cbd47504SRob Suderman if (auto call = mlir::dyn_cast<CallOpInterface>(op)) { 164cbd47504SRob Suderman auto loadSymbols = 165cbd47504SRob Suderman loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; 166cbd47504SRob Suderman auto storeSymbols = 167cbd47504SRob Suderman storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; 168cbd47504SRob Suderman 169cbd47504SRob Suderman for (auto sym : loadSymbols) { 170cbd47504SRob Suderman previousStores.erase(sym); 171cbd47504SRob Suderman } 172cbd47504SRob Suderman 173cbd47504SRob Suderman for (auto sym : storeSymbols) { 174cbd47504SRob Suderman previousLoads.erase(sym); 175cbd47504SRob Suderman previousStores.erase(sym); 176cbd47504SRob Suderman } 177cbd47504SRob Suderman continue; 178cbd47504SRob Suderman } 179cbd47504SRob Suderman 180cbd47504SRob Suderman // If the op has sub-regions, recurse inside. We make no guarantees whether 181cbd47504SRob Suderman // the recursion occurs. 182cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> opSymbolLoad; 183cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> opSymbolStore; 184cbd47504SRob Suderman for (auto ®ion : op.getRegions()) { 185cbd47504SRob Suderman for (auto &block : region) { 1864e713357SMehdi Amini processBlock(block, opSymbolLoad, opSymbolStore); 187cbd47504SRob Suderman } 188cbd47504SRob Suderman } 189cbd47504SRob Suderman 190cbd47504SRob Suderman // Update current state from the subblock. 191cbd47504SRob Suderman for (auto change : opSymbolLoad) { 192cbd47504SRob Suderman symbolLoad.insert(change); 193cbd47504SRob Suderman previousStores.erase(change); 194cbd47504SRob Suderman } 195cbd47504SRob Suderman 196cbd47504SRob Suderman for (auto change : opSymbolStore) { 197cbd47504SRob Suderman symbolStore.insert(change); 198cbd47504SRob Suderman previousLoads.erase(change); 199cbd47504SRob Suderman previousStores.erase(change); 200cbd47504SRob Suderman } 201cbd47504SRob Suderman } 202cbd47504SRob Suderman 2039cebb285SMehdi Amini for (auto *op : toDelete) { 204cbd47504SRob Suderman op->erase(); 205cbd47504SRob Suderman } 206cbd47504SRob Suderman } 207cbd47504SRob Suderman 208cbd47504SRob Suderman void MLProgramPipelineGlobals::runOnOperation() { 209cbd47504SRob Suderman auto targetOp = getOperation(); 210cbd47504SRob Suderman if (failed(buildGlobalMap(targetOp))) { 211cbd47504SRob Suderman return; 212cbd47504SRob Suderman } 213cbd47504SRob Suderman 214cbd47504SRob Suderman for (auto &funcOp : *targetOp.getBody()) { 215cbd47504SRob Suderman for (auto ®ion : funcOp.getRegions()) { 216cbd47504SRob Suderman for (auto &block : region.getBlocks()) { 217cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> symbolsLoaded; 218cbd47504SRob Suderman llvm::DenseSet<SymbolRefAttr> symbolsStored; 2194e713357SMehdi Amini processBlock(block, symbolsLoaded, symbolsStored); 220cbd47504SRob Suderman } 221cbd47504SRob Suderman } 222cbd47504SRob Suderman } 223cbd47504SRob Suderman } 224cbd47504SRob Suderman 225cbd47504SRob Suderman } // namespace 226cbd47504SRob Suderman 227cbd47504SRob Suderman std::unique_ptr<OperationPass<mlir::ModuleOp>> 228cbd47504SRob Suderman createMLProgramPipelineGlobalsPass() { 229cbd47504SRob Suderman return std::make_unique<MLProgramPipelineGlobals>(); 230cbd47504SRob Suderman } 231cbd47504SRob Suderman 232cbd47504SRob Suderman } // namespace ml_program 233cbd47504SRob Suderman } // namespace mlir 234