1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for buffer optimization passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Interfaces/ControlFlowInterfaces.h" 19 #include "mlir/Interfaces/LoopLikeInterface.h" 20 #include "mlir/Pass/Pass.h" 21 #include "llvm/ADT/SetOperations.h" 22 #include "llvm/ADT/SmallString.h" 23 #include <optional> 24 25 using namespace mlir; 26 using namespace mlir::bufferization; 27 28 //===----------------------------------------------------------------------===// 29 // BufferPlacementAllocs 30 //===----------------------------------------------------------------------===// 31 32 /// Get the start operation to place the given alloc value withing the 33 // specified placement block. 34 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, 35 Block *placementBlock, 36 const Liveness &liveness) { 37 // We have to ensure that we place the alloc before its first use in this 38 // block. 39 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock); 40 Operation *startOperation = livenessInfo.getStartOperation(allocValue); 41 // Check whether the start operation lies in the desired placement block. 42 // If not, we will use the terminator as this is the last operation in 43 // this block. 44 if (startOperation->getBlock() != placementBlock) { 45 Operation *opInPlacementBlock = 46 placementBlock->findAncestorOpInBlock(*startOperation); 47 startOperation = opInPlacementBlock ? opInPlacementBlock 48 : placementBlock->getTerminator(); 49 } 50 51 return startOperation; 52 } 53 54 /// Initializes the internal list by discovering all supported allocation 55 /// nodes. 56 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } 57 58 /// Searches for and registers all supported allocation entries. 59 void BufferPlacementAllocs::build(Operation *op) { 60 op->walk([&](MemoryEffectOpInterface opInterface) { 61 // Try to find a single allocation result. 62 SmallVector<MemoryEffects::EffectInstance, 2> effects; 63 opInterface.getEffects(effects); 64 65 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; 66 llvm::copy_if( 67 effects, std::back_inserter(allocateResultEffects), 68 [=](MemoryEffects::EffectInstance &it) { 69 Value value = it.getValue(); 70 return isa<MemoryEffects::Allocate>(it.getEffect()) && value && 71 isa<OpResult>(value) && 72 it.getResource() != 73 SideEffects::AutomaticAllocationScopeResource::get(); 74 }); 75 // If there is one result only, we will be able to move the allocation and 76 // (possibly existing) deallocation ops. 77 if (allocateResultEffects.size() != 1) 78 return; 79 // Get allocation result. 80 Value allocValue = allocateResultEffects[0].getValue(); 81 // Find the associated dealloc value and register the allocation entry. 82 std::optional<Operation *> dealloc = memref::findDealloc(allocValue); 83 // If the allocation has > 1 dealloc associated with it, skip handling it. 84 if (!dealloc) 85 return; 86 allocs.push_back(std::make_tuple(allocValue, *dealloc)); 87 }); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // BufferPlacementTransformationBase 92 //===----------------------------------------------------------------------===// 93 94 /// Constructs a new transformation base using the given root operation. 95 BufferPlacementTransformationBase::BufferPlacementTransformationBase( 96 Operation *op) 97 : aliases(op), allocs(op), liveness(op) {} 98 99 //===----------------------------------------------------------------------===// 100 // BufferPlacementTransformationBase 101 //===----------------------------------------------------------------------===// 102 103 FailureOr<memref::GlobalOp> 104 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, 105 Attribute memorySpace) { 106 auto type = cast<RankedTensorType>(constantOp.getType()); 107 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 108 if (!moduleOp) 109 return failure(); 110 111 // If we already have a global for this constant value, no need to do 112 // anything else. 113 for (Operation &op : moduleOp.getRegion().getOps()) { 114 auto globalOp = dyn_cast<memref::GlobalOp>(&op); 115 if (!globalOp) 116 continue; 117 if (!globalOp.getInitialValue().has_value()) 118 continue; 119 uint64_t opAlignment = globalOp.getAlignment().value_or(0); 120 Attribute initialValue = globalOp.getInitialValue().value(); 121 if (opAlignment == alignment && initialValue == constantOp.getValue()) 122 return globalOp; 123 } 124 125 // Create a builder without an insertion point. We will insert using the 126 // symbol table to guarantee unique names. 127 OpBuilder globalBuilder(moduleOp.getContext()); 128 SymbolTable symbolTable(moduleOp); 129 130 // Create a pretty name. 131 SmallString<64> buf; 132 llvm::raw_svector_ostream os(buf); 133 interleave(type.getShape(), os, "x"); 134 os << "x" << type.getElementType(); 135 136 // Add an optional alignment to the global memref. 137 IntegerAttr memrefAlignment = 138 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) 139 : IntegerAttr(); 140 141 BufferizeTypeConverter typeConverter; 142 auto memrefType = cast<MemRefType>(typeConverter.convertType(type)); 143 if (memorySpace) 144 memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); 145 auto global = globalBuilder.create<memref::GlobalOp>( 146 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), 147 /*sym_visibility=*/globalBuilder.getStringAttr("private"), 148 /*type=*/memrefType, 149 /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()), 150 /*constant=*/true, 151 /*alignment=*/memrefAlignment); 152 symbolTable.insert(global); 153 // The symbol table inserts at the end of the module, but globals are a bit 154 // nicer if they are at the beginning. 155 global->moveBefore(&moduleOp.front()); 156 return global; 157 } 158