1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for buffer optimization passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Interfaces/ControlFlowInterfaces.h" 19 #include "mlir/Interfaces/LoopLikeInterface.h" 20 #include "mlir/Pass/Pass.h" 21 #include "llvm/ADT/SetOperations.h" 22 #include "llvm/ADT/SmallString.h" 23 #include <optional> 24 25 using namespace mlir; 26 using namespace mlir::bufferization; 27 28 //===----------------------------------------------------------------------===// 29 // BufferPlacementAllocs 30 //===----------------------------------------------------------------------===// 31 32 /// Get the start operation to place the given alloc value withing the 33 // specified placement block. 34 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, 35 Block *placementBlock, 36 const Liveness &liveness) { 37 // We have to ensure that we place the alloc before its first use in this 38 // block. 39 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock); 40 Operation *startOperation = livenessInfo.getStartOperation(allocValue); 41 // Check whether the start operation lies in the desired placement block. 42 // If not, we will use the terminator as this is the last operation in 43 // this block. 44 if (startOperation->getBlock() != placementBlock) { 45 Operation *opInPlacementBlock = 46 placementBlock->findAncestorOpInBlock(*startOperation); 47 startOperation = opInPlacementBlock ? opInPlacementBlock 48 : placementBlock->getTerminator(); 49 } 50 51 return startOperation; 52 } 53 54 /// Initializes the internal list by discovering all supported allocation 55 /// nodes. 56 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } 57 58 /// Searches for and registers all supported allocation entries. 59 void BufferPlacementAllocs::build(Operation *op) { 60 op->walk([&](MemoryEffectOpInterface opInterface) { 61 // Try to find a single allocation result. 62 SmallVector<MemoryEffects::EffectInstance, 2> effects; 63 opInterface.getEffects(effects); 64 65 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; 66 llvm::copy_if( 67 effects, std::back_inserter(allocateResultEffects), 68 [=](MemoryEffects::EffectInstance &it) { 69 Value value = it.getValue(); 70 return isa<MemoryEffects::Allocate>(it.getEffect()) && value && 71 value.isa<OpResult>() && 72 it.getResource() != 73 SideEffects::AutomaticAllocationScopeResource::get(); 74 }); 75 // If there is one result only, we will be able to move the allocation and 76 // (possibly existing) deallocation ops. 77 if (allocateResultEffects.size() != 1) 78 return; 79 // Get allocation result. 80 Value allocValue = allocateResultEffects[0].getValue(); 81 // Find the associated dealloc value and register the allocation entry. 82 std::optional<Operation *> dealloc = memref::findDealloc(allocValue); 83 // If the allocation has > 1 dealloc associated with it, skip handling it. 84 if (!dealloc) 85 return; 86 allocs.push_back(std::make_tuple(allocValue, *dealloc)); 87 }); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // BufferPlacementTransformationBase 92 //===----------------------------------------------------------------------===// 93 94 /// Constructs a new transformation base using the given root operation. 95 BufferPlacementTransformationBase::BufferPlacementTransformationBase( 96 Operation *op) 97 : aliases(op), allocs(op), liveness(op) {} 98 99 /// Returns true if the given operation represents a loop by testing whether it 100 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In 101 /// the case of a `RegionBranchOpInterface`, it checks all region-based control- 102 /// flow edges for cycles. 103 bool BufferPlacementTransformationBase::isLoop(Operation *op) { 104 // If the operation implements the `LoopLikeOpInterface` it can be considered 105 // a loop. 106 if (isa<LoopLikeOpInterface>(op)) 107 return true; 108 109 // If the operation does not implement the `RegionBranchOpInterface`, it is 110 // (currently) not possible to detect a loop. 111 RegionBranchOpInterface regionInterface; 112 if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op))) 113 return false; 114 115 // Recurses into a region using the current region interface to find potential 116 // cycles. 117 SmallPtrSet<Region *, 4> visitedRegions; 118 std::function<bool(Region *)> recurse = [&](Region *current) { 119 if (!current) 120 return false; 121 // If we have found a back edge, the parent operation induces a loop. 122 if (!visitedRegions.insert(current).second) 123 return true; 124 // Recurses into all region successors. 125 SmallVector<RegionSuccessor, 2> successors; 126 regionInterface.getSuccessorRegions(current->getRegionNumber(), successors); 127 for (RegionSuccessor ®ionEntry : successors) 128 if (recurse(regionEntry.getSuccessor())) 129 return true; 130 return false; 131 }; 132 133 // Start with all entry regions and test whether they induce a loop. 134 SmallVector<RegionSuccessor, 2> successorRegions; 135 regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions); 136 for (RegionSuccessor ®ionEntry : successorRegions) { 137 if (recurse(regionEntry.getSuccessor())) 138 return true; 139 visitedRegions.clear(); 140 } 141 142 return false; 143 } 144 145 //===----------------------------------------------------------------------===// 146 // BufferPlacementTransformationBase 147 //===----------------------------------------------------------------------===// 148 149 FailureOr<memref::GlobalOp> 150 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, 151 Attribute memorySpace) { 152 auto type = constantOp.getType().cast<RankedTensorType>(); 153 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 154 if (!moduleOp) 155 return failure(); 156 157 // If we already have a global for this constant value, no need to do 158 // anything else. 159 for (Operation &op : moduleOp.getRegion().getOps()) { 160 auto globalOp = dyn_cast<memref::GlobalOp>(&op); 161 if (!globalOp) 162 continue; 163 if (!globalOp.getInitialValue().has_value()) 164 continue; 165 uint64_t opAlignment = globalOp.getAlignment().value_or(0); 166 Attribute initialValue = globalOp.getInitialValue().value(); 167 if (opAlignment == alignment && initialValue == constantOp.getValue()) 168 return globalOp; 169 } 170 171 // Create a builder without an insertion point. We will insert using the 172 // symbol table to guarantee unique names. 173 OpBuilder globalBuilder(moduleOp.getContext()); 174 SymbolTable symbolTable(moduleOp); 175 176 // Create a pretty name. 177 SmallString<64> buf; 178 llvm::raw_svector_ostream os(buf); 179 interleave(type.getShape(), os, "x"); 180 os << "x" << type.getElementType(); 181 182 // Add an optional alignment to the global memref. 183 IntegerAttr memrefAlignment = 184 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) 185 : IntegerAttr(); 186 187 BufferizeTypeConverter typeConverter; 188 auto memrefType = typeConverter.convertType(type).cast<MemRefType>(); 189 if (memorySpace) 190 memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); 191 auto global = globalBuilder.create<memref::GlobalOp>( 192 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), 193 /*sym_visibility=*/globalBuilder.getStringAttr("private"), 194 /*type=*/memrefType, 195 /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), 196 /*constant=*/true, 197 /*alignment=*/memrefAlignment); 198 symbolTable.insert(global); 199 // The symbol table inserts at the end of the module, but globals are a bit 200 // nicer if they are at the beginning. 201 global->moveBefore(&moduleOp.front()); 202 return global; 203 } 204