1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for buffer optimization passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 14 15 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 16 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/Interfaces/ControlFlowInterfaces.h" 21 #include "mlir/Interfaces/LoopLikeInterface.h" 22 #include "mlir/Pass/Pass.h" 23 #include "llvm/ADT/SetOperations.h" 24 #include "llvm/ADT/SmallString.h" 25 #include <optional> 26 27 using namespace mlir; 28 using namespace mlir::bufferization; 29 30 //===----------------------------------------------------------------------===// 31 // BufferPlacementAllocs 32 //===----------------------------------------------------------------------===// 33 34 /// Get the start operation to place the given alloc value withing the 35 // specified placement block. 36 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, 37 Block *placementBlock, 38 const Liveness &liveness) { 39 // We have to ensure that we place the alloc before its first use in this 40 // block. 41 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock); 42 Operation *startOperation = livenessInfo.getStartOperation(allocValue); 43 // Check whether the start operation lies in the desired placement block. 44 // If not, we will use the terminator as this is the last operation in 45 // this block. 46 if (startOperation->getBlock() != placementBlock) { 47 Operation *opInPlacementBlock = 48 placementBlock->findAncestorOpInBlock(*startOperation); 49 startOperation = opInPlacementBlock ? opInPlacementBlock 50 : placementBlock->getTerminator(); 51 } 52 53 return startOperation; 54 } 55 56 /// Initializes the internal list by discovering all supported allocation 57 /// nodes. 58 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } 59 60 /// Searches for and registers all supported allocation entries. 61 void BufferPlacementAllocs::build(Operation *op) { 62 op->walk([&](MemoryEffectOpInterface opInterface) { 63 // Try to find a single allocation result. 64 SmallVector<MemoryEffects::EffectInstance, 2> effects; 65 opInterface.getEffects(effects); 66 67 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; 68 llvm::copy_if( 69 effects, std::back_inserter(allocateResultEffects), 70 [=](MemoryEffects::EffectInstance &it) { 71 Value value = it.getValue(); 72 return isa<MemoryEffects::Allocate>(it.getEffect()) && value && 73 isa<OpResult>(value) && 74 it.getResource() != 75 SideEffects::AutomaticAllocationScopeResource::get(); 76 }); 77 // If there is one result only, we will be able to move the allocation and 78 // (possibly existing) deallocation ops. 79 if (allocateResultEffects.size() != 1) 80 return; 81 // Get allocation result. 82 Value allocValue = allocateResultEffects[0].getValue(); 83 // Find the associated dealloc value and register the allocation entry. 84 std::optional<Operation *> dealloc = memref::findDealloc(allocValue); 85 // If the allocation has > 1 dealloc associated with it, skip handling it. 86 if (!dealloc) 87 return; 88 allocs.push_back(std::make_tuple(allocValue, *dealloc)); 89 }); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // BufferPlacementTransformationBase 94 //===----------------------------------------------------------------------===// 95 96 /// Constructs a new transformation base using the given root operation. 97 BufferPlacementTransformationBase::BufferPlacementTransformationBase( 98 Operation *op) 99 : aliases(op), allocs(op), liveness(op) {} 100 101 //===----------------------------------------------------------------------===// 102 // BufferPlacementTransformationBase 103 //===----------------------------------------------------------------------===// 104 105 FailureOr<memref::GlobalOp> 106 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, 107 Attribute memorySpace) { 108 auto type = cast<RankedTensorType>(constantOp.getType()); 109 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 110 if (!moduleOp) 111 return failure(); 112 113 // If we already have a global for this constant value, no need to do 114 // anything else. 115 for (Operation &op : moduleOp.getRegion().getOps()) { 116 auto globalOp = dyn_cast<memref::GlobalOp>(&op); 117 if (!globalOp) 118 continue; 119 if (!globalOp.getInitialValue().has_value()) 120 continue; 121 uint64_t opAlignment = globalOp.getAlignment().value_or(0); 122 Attribute initialValue = globalOp.getInitialValue().value(); 123 if (opAlignment == alignment && initialValue == constantOp.getValue()) 124 return globalOp; 125 } 126 127 // Create a builder without an insertion point. We will insert using the 128 // symbol table to guarantee unique names. 129 OpBuilder globalBuilder(moduleOp.getContext()); 130 SymbolTable symbolTable(moduleOp); 131 132 // Create a pretty name. 133 SmallString<64> buf; 134 llvm::raw_svector_ostream os(buf); 135 interleave(type.getShape(), os, "x"); 136 os << "x" << type.getElementType(); 137 138 // Add an optional alignment to the global memref. 139 IntegerAttr memrefAlignment = 140 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) 141 : IntegerAttr(); 142 143 // Memref globals always have an identity layout. 144 auto memrefType = 145 cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type)); 146 if (memorySpace) 147 memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); 148 auto global = globalBuilder.create<memref::GlobalOp>( 149 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), 150 /*sym_visibility=*/globalBuilder.getStringAttr("private"), 151 /*type=*/memrefType, 152 /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()), 153 /*constant=*/true, 154 /*alignment=*/memrefAlignment); 155 symbolTable.insert(global); 156 // The symbol table inserts at the end of the module, but globals are a bit 157 // nicer if they are at the beginning. 158 global->moveBefore(&moduleOp.front()); 159 return global; 160 } 161