xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (revision 024f562da67180b7be1663048c960b26c2cc16f8)
1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for buffer optimization passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/LoopLikeInterface.h"
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallString.h"
23 #include <optional>
24 
25 using namespace mlir;
26 using namespace mlir::bufferization;
27 
28 //===----------------------------------------------------------------------===//
29 // BufferPlacementAllocs
30 //===----------------------------------------------------------------------===//
31 
32 /// Get the start operation to place the given alloc value withing the
33 // specified placement block.
34 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
35                                                     Block *placementBlock,
36                                                     const Liveness &liveness) {
37   // We have to ensure that we place the alloc before its first use in this
38   // block.
39   const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
40   Operation *startOperation = livenessInfo.getStartOperation(allocValue);
41   // Check whether the start operation lies in the desired placement block.
42   // If not, we will use the terminator as this is the last operation in
43   // this block.
44   if (startOperation->getBlock() != placementBlock) {
45     Operation *opInPlacementBlock =
46         placementBlock->findAncestorOpInBlock(*startOperation);
47     startOperation = opInPlacementBlock ? opInPlacementBlock
48                                         : placementBlock->getTerminator();
49   }
50 
51   return startOperation;
52 }
53 
54 /// Initializes the internal list by discovering all supported allocation
55 /// nodes.
56 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
57 
58 /// Searches for and registers all supported allocation entries.
59 void BufferPlacementAllocs::build(Operation *op) {
60   op->walk([&](MemoryEffectOpInterface opInterface) {
61     // Try to find a single allocation result.
62     SmallVector<MemoryEffects::EffectInstance, 2> effects;
63     opInterface.getEffects(effects);
64 
65     SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
66     llvm::copy_if(
67         effects, std::back_inserter(allocateResultEffects),
68         [=](MemoryEffects::EffectInstance &it) {
69           Value value = it.getValue();
70           return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
71                  isa<OpResult>(value) &&
72                  it.getResource() !=
73                      SideEffects::AutomaticAllocationScopeResource::get();
74         });
75     // If there is one result only, we will be able to move the allocation and
76     // (possibly existing) deallocation ops.
77     if (allocateResultEffects.size() != 1)
78       return;
79     // Get allocation result.
80     Value allocValue = allocateResultEffects[0].getValue();
81     // Find the associated dealloc value and register the allocation entry.
82     std::optional<Operation *> dealloc = memref::findDealloc(allocValue);
83     // If the allocation has > 1 dealloc associated with it, skip handling it.
84     if (!dealloc)
85       return;
86     allocs.push_back(std::make_tuple(allocValue, *dealloc));
87   });
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // BufferPlacementTransformationBase
92 //===----------------------------------------------------------------------===//
93 
94 /// Constructs a new transformation base using the given root operation.
95 BufferPlacementTransformationBase::BufferPlacementTransformationBase(
96     Operation *op)
97     : aliases(op), allocs(op), liveness(op) {}
98 
99 /// Returns true if the given operation represents a loop by testing whether it
100 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
101 /// the case of a `RegionBranchOpInterface`, it checks all region-based control-
102 /// flow edges for cycles.
103 bool BufferPlacementTransformationBase::isLoop(Operation *op) {
104   // If the operation implements the `LoopLikeOpInterface` it can be considered
105   // a loop.
106   if (isa<LoopLikeOpInterface>(op))
107     return true;
108 
109   // If the operation does not implement the `RegionBranchOpInterface`, it is
110   // (currently) not possible to detect a loop.
111   RegionBranchOpInterface regionInterface;
112   if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
113     return false;
114 
115   // Recurses into a region using the current region interface to find potential
116   // cycles.
117   SmallPtrSet<Region *, 4> visitedRegions;
118   std::function<bool(Region *)> recurse = [&](Region *current) {
119     if (!current)
120       return false;
121     // If we have found a back edge, the parent operation induces a loop.
122     if (!visitedRegions.insert(current).second)
123       return true;
124     // Recurses into all region successors.
125     SmallVector<RegionSuccessor, 2> successors;
126     regionInterface.getSuccessorRegions(current, successors);
127     for (RegionSuccessor &regionEntry : successors)
128       if (recurse(regionEntry.getSuccessor()))
129         return true;
130     return false;
131   };
132 
133   // Start with all entry regions and test whether they induce a loop.
134   SmallVector<RegionSuccessor, 2> successorRegions;
135   regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
136                                       successorRegions);
137   for (RegionSuccessor &regionEntry : successorRegions) {
138     if (recurse(regionEntry.getSuccessor()))
139       return true;
140     visitedRegions.clear();
141   }
142 
143   return false;
144 }
145 
146 //===----------------------------------------------------------------------===//
// Global memrefs for tensor constants
148 //===----------------------------------------------------------------------===//
149 
150 FailureOr<memref::GlobalOp>
151 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
152                             Attribute memorySpace) {
153   auto type = cast<RankedTensorType>(constantOp.getType());
154   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
155   if (!moduleOp)
156     return failure();
157 
158   // If we already have a global for this constant value, no need to do
159   // anything else.
160   for (Operation &op : moduleOp.getRegion().getOps()) {
161     auto globalOp = dyn_cast<memref::GlobalOp>(&op);
162     if (!globalOp)
163       continue;
164     if (!globalOp.getInitialValue().has_value())
165       continue;
166     uint64_t opAlignment = globalOp.getAlignment().value_or(0);
167     Attribute initialValue = globalOp.getInitialValue().value();
168     if (opAlignment == alignment && initialValue == constantOp.getValue())
169       return globalOp;
170   }
171 
172   // Create a builder without an insertion point. We will insert using the
173   // symbol table to guarantee unique names.
174   OpBuilder globalBuilder(moduleOp.getContext());
175   SymbolTable symbolTable(moduleOp);
176 
177   // Create a pretty name.
178   SmallString<64> buf;
179   llvm::raw_svector_ostream os(buf);
180   interleave(type.getShape(), os, "x");
181   os << "x" << type.getElementType();
182 
183   // Add an optional alignment to the global memref.
184   IntegerAttr memrefAlignment =
185       alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
186                     : IntegerAttr();
187 
188   BufferizeTypeConverter typeConverter;
189   auto memrefType = cast<MemRefType>(typeConverter.convertType(type));
190   if (memorySpace)
191     memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
192   auto global = globalBuilder.create<memref::GlobalOp>(
193       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
194       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
195       /*type=*/memrefType,
196       /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
197       /*constant=*/true,
198       /*alignment=*/memrefAlignment);
199   symbolTable.insert(global);
200   // The symbol table inserts at the end of the module, but globals are a bit
201   // nicer if they are at the beginning.
202   global->moveBefore(&moduleOp.front());
203   return global;
204 }
205