1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for buffer optimization passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Interfaces/ControlFlowInterfaces.h" 19 #include "mlir/Interfaces/LoopLikeInterface.h" 20 #include "mlir/Pass/Pass.h" 21 #include "llvm/ADT/SetOperations.h" 22 #include "llvm/ADT/SmallString.h" 23 24 using namespace mlir; 25 using namespace mlir::bufferization; 26 27 //===----------------------------------------------------------------------===// 28 // BufferPlacementAllocs 29 //===----------------------------------------------------------------------===// 30 31 /// Get the start operation to place the given alloc value withing the 32 // specified placement block. 33 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, 34 Block *placementBlock, 35 const Liveness &liveness) { 36 // We have to ensure that we place the alloc before its first use in this 37 // block. 38 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock); 39 Operation *startOperation = livenessInfo.getStartOperation(allocValue); 40 // Check whether the start operation lies in the desired placement block. 41 // If not, we will use the terminator as this is the last operation in 42 // this block. 43 if (startOperation->getBlock() != placementBlock) { 44 Operation *opInPlacementBlock = 45 placementBlock->findAncestorOpInBlock(*startOperation); 46 startOperation = opInPlacementBlock ? opInPlacementBlock 47 : placementBlock->getTerminator(); 48 } 49 50 return startOperation; 51 } 52 53 /// Initializes the internal list by discovering all supported allocation 54 /// nodes. 55 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } 56 57 /// Searches for and registers all supported allocation entries. 58 void BufferPlacementAllocs::build(Operation *op) { 59 op->walk([&](MemoryEffectOpInterface opInterface) { 60 // Try to find a single allocation result. 61 SmallVector<MemoryEffects::EffectInstance, 2> effects; 62 opInterface.getEffects(effects); 63 64 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; 65 llvm::copy_if( 66 effects, std::back_inserter(allocateResultEffects), 67 [=](MemoryEffects::EffectInstance &it) { 68 Value value = it.getValue(); 69 return isa<MemoryEffects::Allocate>(it.getEffect()) && value && 70 value.isa<OpResult>() && 71 it.getResource() != 72 SideEffects::AutomaticAllocationScopeResource::get(); 73 }); 74 // If there is one result only, we will be able to move the allocation and 75 // (possibly existing) deallocation ops. 76 if (allocateResultEffects.size() != 1) 77 return; 78 // Get allocation result. 79 Value allocValue = allocateResultEffects[0].getValue(); 80 // Find the associated dealloc value and register the allocation entry. 81 llvm::Optional<Operation *> dealloc = memref::findDealloc(allocValue); 82 // If the allocation has > 1 dealloc associated with it, skip handling it. 83 if (!dealloc) 84 return; 85 allocs.push_back(std::make_tuple(allocValue, *dealloc)); 86 }); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // BufferPlacementTransformationBase 91 //===----------------------------------------------------------------------===// 92 93 /// Constructs a new transformation base using the given root operation. 94 BufferPlacementTransformationBase::BufferPlacementTransformationBase( 95 Operation *op) 96 : aliases(op), allocs(op), liveness(op) {} 97 98 /// Returns true if the given operation represents a loop by testing whether it 99 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In 100 /// the case of a `RegionBranchOpInterface`, it checks all region-based control- 101 /// flow edges for cycles. 102 bool BufferPlacementTransformationBase::isLoop(Operation *op) { 103 // If the operation implements the `LoopLikeOpInterface` it can be considered 104 // a loop. 105 if (isa<LoopLikeOpInterface>(op)) 106 return true; 107 108 // If the operation does not implement the `RegionBranchOpInterface`, it is 109 // (currently) not possible to detect a loop. 110 RegionBranchOpInterface regionInterface; 111 if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op))) 112 return false; 113 114 // Recurses into a region using the current region interface to find potential 115 // cycles. 116 SmallPtrSet<Region *, 4> visitedRegions; 117 std::function<bool(Region *)> recurse = [&](Region *current) { 118 if (!current) 119 return false; 120 // If we have found a back edge, the parent operation induces a loop. 121 if (!visitedRegions.insert(current).second) 122 return true; 123 // Recurses into all region successors. 124 SmallVector<RegionSuccessor, 2> successors; 125 regionInterface.getSuccessorRegions(current->getRegionNumber(), successors); 126 for (RegionSuccessor ®ionEntry : successors) 127 if (recurse(regionEntry.getSuccessor())) 128 return true; 129 return false; 130 }; 131 132 // Start with all entry regions and test whether they induce a loop. 133 SmallVector<RegionSuccessor, 2> successorRegions; 134 regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions); 135 for (RegionSuccessor ®ionEntry : successorRegions) { 136 if (recurse(regionEntry.getSuccessor())) 137 return true; 138 visitedRegions.clear(); 139 } 140 141 return false; 142 } 143 144 //===----------------------------------------------------------------------===// 145 // BufferPlacementTransformationBase 146 //===----------------------------------------------------------------------===// 147 148 FailureOr<memref::GlobalOp> 149 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) { 150 auto type = constantOp.getType().cast<RankedTensorType>(); 151 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 152 if (!moduleOp) 153 return failure(); 154 155 // If we already have a global for this constant value, no need to do 156 // anything else. 157 for (Operation &op : moduleOp.getRegion().getOps()) { 158 auto globalOp = dyn_cast<memref::GlobalOp>(&op); 159 if (!globalOp) 160 continue; 161 if (!globalOp.initial_value().has_value()) 162 continue; 163 uint64_t opAlignment = 164 globalOp.alignment().has_value() ? globalOp.alignment().value() : 0; 165 Attribute initialValue = globalOp.initial_value().value(); 166 if (opAlignment == alignment && initialValue == constantOp.getValue()) 167 return globalOp; 168 } 169 170 // Create a builder without an insertion point. We will insert using the 171 // symbol table to guarantee unique names. 172 OpBuilder globalBuilder(moduleOp.getContext()); 173 SymbolTable symbolTable(moduleOp); 174 175 // Create a pretty name. 176 SmallString<64> buf; 177 llvm::raw_svector_ostream os(buf); 178 interleave(type.getShape(), os, "x"); 179 os << "x" << type.getElementType(); 180 181 // Add an optional alignment to the global memref. 182 IntegerAttr memrefAlignment = 183 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) 184 : IntegerAttr(); 185 186 BufferizeTypeConverter typeConverter; 187 auto global = globalBuilder.create<memref::GlobalOp>( 188 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), 189 /*sym_visibility=*/globalBuilder.getStringAttr("private"), 190 /*type=*/typeConverter.convertType(type).cast<MemRefType>(), 191 /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), 192 /*constant=*/true, 193 /*alignment=*/memrefAlignment); 194 symbolTable.insert(global); 195 // The symbol table inserts at the end of the module, but globals are a bit 196 // nicer if they are at the beginning. 197 global->moveBefore(&moduleOp.front()); 198 return global; 199 } 200