xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (revision 2ff2e871f5e632ea493efaf4f2192f8b18a54ab1)
1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for buffer optimization passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14 
15 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
16 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Interfaces/ControlFlowInterfaces.h"
21 #include "mlir/Interfaces/LoopLikeInterface.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/SetOperations.h"
24 #include "llvm/ADT/SmallString.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::bufferization;
29 
30 //===----------------------------------------------------------------------===//
31 // BufferPlacementAllocs
32 //===----------------------------------------------------------------------===//
33 
34 /// Get the start operation to place the given alloc value withing the
35 // specified placement block.
36 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
37                                                     Block *placementBlock,
38                                                     const Liveness &liveness) {
39   // We have to ensure that we place the alloc before its first use in this
40   // block.
41   const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
42   Operation *startOperation = livenessInfo.getStartOperation(allocValue);
43   // Check whether the start operation lies in the desired placement block.
44   // If not, we will use the terminator as this is the last operation in
45   // this block.
46   if (startOperation->getBlock() != placementBlock) {
47     Operation *opInPlacementBlock =
48         placementBlock->findAncestorOpInBlock(*startOperation);
49     startOperation = opInPlacementBlock ? opInPlacementBlock
50                                         : placementBlock->getTerminator();
51   }
52 
53   return startOperation;
54 }
55 
56 /// Initializes the internal list by discovering all supported allocation
57 /// nodes.
58 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
59 
60 /// Searches for and registers all supported allocation entries.
61 void BufferPlacementAllocs::build(Operation *op) {
62   op->walk([&](MemoryEffectOpInterface opInterface) {
63     // Try to find a single allocation result.
64     SmallVector<MemoryEffects::EffectInstance, 2> effects;
65     opInterface.getEffects(effects);
66 
67     SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
68     llvm::copy_if(
69         effects, std::back_inserter(allocateResultEffects),
70         [=](MemoryEffects::EffectInstance &it) {
71           Value value = it.getValue();
72           return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
73                  isa<OpResult>(value) &&
74                  it.getResource() !=
75                      SideEffects::AutomaticAllocationScopeResource::get();
76         });
77     // If there is one result only, we will be able to move the allocation and
78     // (possibly existing) deallocation ops.
79     if (allocateResultEffects.size() != 1)
80       return;
81     // Get allocation result.
82     Value allocValue = allocateResultEffects[0].getValue();
83     // Find the associated dealloc value and register the allocation entry.
84     std::optional<Operation *> dealloc = memref::findDealloc(allocValue);
85     // If the allocation has > 1 dealloc associated with it, skip handling it.
86     if (!dealloc)
87       return;
88     allocs.push_back(std::make_tuple(allocValue, *dealloc));
89   });
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // BufferPlacementTransformationBase
94 //===----------------------------------------------------------------------===//
95 
96 /// Constructs a new transformation base using the given root operation.
97 BufferPlacementTransformationBase::BufferPlacementTransformationBase(
98     Operation *op)
99     : aliases(op), allocs(op), liveness(op) {}
100 
101 //===----------------------------------------------------------------------===//
// Global memrefs for tensor constants
103 //===----------------------------------------------------------------------===//
104 
105 FailureOr<memref::GlobalOp>
106 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
107                             Attribute memorySpace) {
108   auto type = cast<RankedTensorType>(constantOp.getType());
109   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
110   if (!moduleOp)
111     return failure();
112 
113   // If we already have a global for this constant value, no need to do
114   // anything else.
115   for (Operation &op : moduleOp.getRegion().getOps()) {
116     auto globalOp = dyn_cast<memref::GlobalOp>(&op);
117     if (!globalOp)
118       continue;
119     if (!globalOp.getInitialValue().has_value())
120       continue;
121     uint64_t opAlignment = globalOp.getAlignment().value_or(0);
122     Attribute initialValue = globalOp.getInitialValue().value();
123     if (opAlignment == alignment && initialValue == constantOp.getValue())
124       return globalOp;
125   }
126 
127   // Create a builder without an insertion point. We will insert using the
128   // symbol table to guarantee unique names.
129   OpBuilder globalBuilder(moduleOp.getContext());
130   SymbolTable symbolTable(moduleOp);
131 
132   // Create a pretty name.
133   SmallString<64> buf;
134   llvm::raw_svector_ostream os(buf);
135   interleave(type.getShape(), os, "x");
136   os << "x" << type.getElementType();
137 
138   // Add an optional alignment to the global memref.
139   IntegerAttr memrefAlignment =
140       alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
141                     : IntegerAttr();
142 
143   // Memref globals always have an identity layout.
144   auto memrefType =
145       cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type));
146   if (memorySpace)
147     memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
148   auto global = globalBuilder.create<memref::GlobalOp>(
149       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
150       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
151       /*type=*/memrefType,
152       /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
153       /*constant=*/true,
154       /*alignment=*/memrefAlignment);
155   symbolTable.insert(global);
156   // The symbol table inserts at the end of the module, but globals are a bit
157   // nicer if they are at the beginning.
158   global->moveBefore(&moduleOp.front());
159   return global;
160 }
161