10e9a4a3bSRiver Riddle //===- BufferUtils.cpp - buffer transformation utilities ------------------===// 20e9a4a3bSRiver Riddle // 30e9a4a3bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40e9a4a3bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 50e9a4a3bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60e9a4a3bSRiver Riddle // 70e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 80e9a4a3bSRiver Riddle // 90e9a4a3bSRiver Riddle // This file implements utilities for buffer optimization passes. 100e9a4a3bSRiver Riddle // 110e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 120e9a4a3bSRiver Riddle 130e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 14*2ff2e871SMatthias Springer 15*2ff2e871SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 160e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 170e9a4a3bSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 180e9a4a3bSRiver Riddle #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 190e9a4a3bSRiver Riddle #include "mlir/IR/Operation.h" 200e9a4a3bSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h" 210e9a4a3bSRiver Riddle #include "mlir/Interfaces/LoopLikeInterface.h" 220e9a4a3bSRiver Riddle #include "mlir/Pass/Pass.h" 230e9a4a3bSRiver Riddle #include "llvm/ADT/SetOperations.h" 2406057248SRiver Riddle #include "llvm/ADT/SmallString.h" 25a1fe1f5fSKazu Hirata #include <optional> 260e9a4a3bSRiver Riddle 270e9a4a3bSRiver Riddle using namespace mlir; 280e9a4a3bSRiver Riddle using namespace mlir::bufferization; 290e9a4a3bSRiver Riddle 300e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 310e9a4a3bSRiver Riddle // BufferPlacementAllocs 320e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 330e9a4a3bSRiver Riddle 340e9a4a3bSRiver Riddle /// Get the start operation to place the given alloc value withing the 350e9a4a3bSRiver Riddle // specified placement block. 360e9a4a3bSRiver Riddle Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, 370e9a4a3bSRiver Riddle Block *placementBlock, 380e9a4a3bSRiver Riddle const Liveness &liveness) { 390e9a4a3bSRiver Riddle // We have to ensure that we place the alloc before its first use in this 400e9a4a3bSRiver Riddle // block. 410e9a4a3bSRiver Riddle const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock); 420e9a4a3bSRiver Riddle Operation *startOperation = livenessInfo.getStartOperation(allocValue); 430e9a4a3bSRiver Riddle // Check whether the start operation lies in the desired placement block. 440e9a4a3bSRiver Riddle // If not, we will use the terminator as this is the last operation in 450e9a4a3bSRiver Riddle // this block. 460e9a4a3bSRiver Riddle if (startOperation->getBlock() != placementBlock) { 470e9a4a3bSRiver Riddle Operation *opInPlacementBlock = 480e9a4a3bSRiver Riddle placementBlock->findAncestorOpInBlock(*startOperation); 490e9a4a3bSRiver Riddle startOperation = opInPlacementBlock ? opInPlacementBlock 500e9a4a3bSRiver Riddle : placementBlock->getTerminator(); 510e9a4a3bSRiver Riddle } 520e9a4a3bSRiver Riddle 530e9a4a3bSRiver Riddle return startOperation; 540e9a4a3bSRiver Riddle } 550e9a4a3bSRiver Riddle 560e9a4a3bSRiver Riddle /// Initializes the internal list by discovering all supported allocation 570e9a4a3bSRiver Riddle /// nodes. 580e9a4a3bSRiver Riddle BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } 590e9a4a3bSRiver Riddle 600e9a4a3bSRiver Riddle /// Searches for and registers all supported allocation entries. 610e9a4a3bSRiver Riddle void BufferPlacementAllocs::build(Operation *op) { 620e9a4a3bSRiver Riddle op->walk([&](MemoryEffectOpInterface opInterface) { 630e9a4a3bSRiver Riddle // Try to find a single allocation result. 640e9a4a3bSRiver Riddle SmallVector<MemoryEffects::EffectInstance, 2> effects; 650e9a4a3bSRiver Riddle opInterface.getEffects(effects); 660e9a4a3bSRiver Riddle 670e9a4a3bSRiver Riddle SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; 680e9a4a3bSRiver Riddle llvm::copy_if( 690e9a4a3bSRiver Riddle effects, std::back_inserter(allocateResultEffects), 700e9a4a3bSRiver Riddle [=](MemoryEffects::EffectInstance &it) { 710e9a4a3bSRiver Riddle Value value = it.getValue(); 720e9a4a3bSRiver Riddle return isa<MemoryEffects::Allocate>(it.getEffect()) && value && 735550c821STres Popp isa<OpResult>(value) && 740e9a4a3bSRiver Riddle it.getResource() != 750e9a4a3bSRiver Riddle SideEffects::AutomaticAllocationScopeResource::get(); 760e9a4a3bSRiver Riddle }); 770e9a4a3bSRiver Riddle // If there is one result only, we will be able to move the allocation and 780e9a4a3bSRiver Riddle // (possibly existing) deallocation ops. 790e9a4a3bSRiver Riddle if (allocateResultEffects.size() != 1) 800e9a4a3bSRiver Riddle return; 810e9a4a3bSRiver Riddle // Get allocation result. 820e9a4a3bSRiver Riddle Value allocValue = allocateResultEffects[0].getValue(); 830e9a4a3bSRiver Riddle // Find the associated dealloc value and register the allocation entry. 840a81ace0SKazu Hirata std::optional<Operation *> dealloc = memref::findDealloc(allocValue); 850e9a4a3bSRiver Riddle // If the allocation has > 1 dealloc associated with it, skip handling it. 86037f0995SKazu Hirata if (!dealloc) 870e9a4a3bSRiver Riddle return; 880e9a4a3bSRiver Riddle allocs.push_back(std::make_tuple(allocValue, *dealloc)); 890e9a4a3bSRiver Riddle }); 900e9a4a3bSRiver Riddle } 910e9a4a3bSRiver Riddle 920e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 930e9a4a3bSRiver Riddle // BufferPlacementTransformationBase 940e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 950e9a4a3bSRiver Riddle 960e9a4a3bSRiver Riddle /// Constructs a new transformation base using the given root operation. 970e9a4a3bSRiver Riddle BufferPlacementTransformationBase::BufferPlacementTransformationBase( 980e9a4a3bSRiver Riddle Operation *op) 990e9a4a3bSRiver Riddle : aliases(op), allocs(op), liveness(op) {} 1000e9a4a3bSRiver Riddle 1010e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 1020e9a4a3bSRiver Riddle // BufferPlacementTransformationBase 1030e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===// 1040e9a4a3bSRiver Riddle 105ab47418dSMatthias Springer FailureOr<memref::GlobalOp> 1068bb5ca58SMaya Amrami bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, 1078bb5ca58SMaya Amrami Attribute memorySpace) { 1085550c821STres Popp auto type = cast<RankedTensorType>(constantOp.getType()); 109ab47418dSMatthias Springer auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 110ab47418dSMatthias Springer if (!moduleOp) 111ab47418dSMatthias Springer return failure(); 1120e9a4a3bSRiver Riddle 1130e9a4a3bSRiver Riddle // If we already have a global for this constant value, no need to do 1140e9a4a3bSRiver Riddle // anything else. 115ab47418dSMatthias Springer for (Operation &op : moduleOp.getRegion().getOps()) { 116ab47418dSMatthias Springer auto globalOp = dyn_cast<memref::GlobalOp>(&op); 117ab47418dSMatthias Springer if (!globalOp) 118ab47418dSMatthias Springer continue; 119491d2701SKazu Hirata if (!globalOp.getInitialValue().has_value()) 120ab47418dSMatthias Springer continue; 1213b0dce5bSKazu Hirata uint64_t opAlignment = globalOp.getAlignment().value_or(0); 122c27d8152SKazu Hirata Attribute initialValue = globalOp.getInitialValue().value(); 123ab47418dSMatthias Springer if (opAlignment == alignment && initialValue == constantOp.getValue()) 124ab47418dSMatthias Springer return globalOp; 125ab47418dSMatthias Springer } 1260e9a4a3bSRiver Riddle 1270e9a4a3bSRiver Riddle // Create a builder without an insertion point. We will insert using the 1280e9a4a3bSRiver Riddle // symbol table to guarantee unique names. 1290e9a4a3bSRiver Riddle OpBuilder globalBuilder(moduleOp.getContext()); 1300e9a4a3bSRiver Riddle SymbolTable symbolTable(moduleOp); 1310e9a4a3bSRiver Riddle 1320e9a4a3bSRiver Riddle // Create a pretty name. 1330e9a4a3bSRiver Riddle SmallString<64> buf; 1340e9a4a3bSRiver Riddle llvm::raw_svector_ostream os(buf); 1350e9a4a3bSRiver Riddle interleave(type.getShape(), os, "x"); 1360e9a4a3bSRiver Riddle os << "x" << type.getElementType(); 1370e9a4a3bSRiver Riddle 1380e9a4a3bSRiver Riddle // Add an optional alignment to the global memref. 1390e9a4a3bSRiver Riddle IntegerAttr memrefAlignment = 1400e9a4a3bSRiver Riddle alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) 1410e9a4a3bSRiver Riddle : IntegerAttr(); 1420e9a4a3bSRiver Riddle 143*2ff2e871SMatthias Springer // Memref globals always have an identity layout. 144*2ff2e871SMatthias Springer auto memrefType = 145*2ff2e871SMatthias Springer cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type)); 1468bb5ca58SMaya Amrami if (memorySpace) 1478bb5ca58SMaya Amrami memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); 1480e9a4a3bSRiver Riddle auto global = globalBuilder.create<memref::GlobalOp>( 1490e9a4a3bSRiver Riddle constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), 1500e9a4a3bSRiver Riddle /*sym_visibility=*/globalBuilder.getStringAttr("private"), 1518bb5ca58SMaya Amrami /*type=*/memrefType, 1525550c821STres Popp /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()), 1530e9a4a3bSRiver Riddle /*constant=*/true, 1540e9a4a3bSRiver Riddle /*alignment=*/memrefAlignment); 1550e9a4a3bSRiver Riddle symbolTable.insert(global); 1560e9a4a3bSRiver Riddle // The symbol table inserts at the end of the module, but globals are a bit 1570e9a4a3bSRiver Riddle // nicer if they are at the beginning. 1580e9a4a3bSRiver Riddle global->moveBefore(&moduleOp.front()); 1590e9a4a3bSRiver Riddle return global; 1600e9a4a3bSRiver Riddle } 161