xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (revision aa8feeefd3ac6c78ee8f67bf033976fc7d68bc6d)
1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for buffer optimization passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/LoopLikeInterface.h"
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallString.h"
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 
27 //===----------------------------------------------------------------------===//
28 // BufferPlacementAllocs
29 //===----------------------------------------------------------------------===//
30 
31 /// Get the start operation to place the given alloc value withing the
32 // specified placement block.
33 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
34                                                     Block *placementBlock,
35                                                     const Liveness &liveness) {
36   // We have to ensure that we place the alloc before its first use in this
37   // block.
38   const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
39   Operation *startOperation = livenessInfo.getStartOperation(allocValue);
40   // Check whether the start operation lies in the desired placement block.
41   // If not, we will use the terminator as this is the last operation in
42   // this block.
43   if (startOperation->getBlock() != placementBlock) {
44     Operation *opInPlacementBlock =
45         placementBlock->findAncestorOpInBlock(*startOperation);
46     startOperation = opInPlacementBlock ? opInPlacementBlock
47                                         : placementBlock->getTerminator();
48   }
49 
50   return startOperation;
51 }
52 
53 /// Initializes the internal list by discovering all supported allocation
54 /// nodes.
55 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
56 
57 /// Searches for and registers all supported allocation entries.
58 void BufferPlacementAllocs::build(Operation *op) {
59   op->walk([&](MemoryEffectOpInterface opInterface) {
60     // Try to find a single allocation result.
61     SmallVector<MemoryEffects::EffectInstance, 2> effects;
62     opInterface.getEffects(effects);
63 
64     SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
65     llvm::copy_if(
66         effects, std::back_inserter(allocateResultEffects),
67         [=](MemoryEffects::EffectInstance &it) {
68           Value value = it.getValue();
69           return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
70                  value.isa<OpResult>() &&
71                  it.getResource() !=
72                      SideEffects::AutomaticAllocationScopeResource::get();
73         });
74     // If there is one result only, we will be able to move the allocation and
75     // (possibly existing) deallocation ops.
76     if (allocateResultEffects.size() != 1)
77       return;
78     // Get allocation result.
79     Value allocValue = allocateResultEffects[0].getValue();
80     // Find the associated dealloc value and register the allocation entry.
81     llvm::Optional<Operation *> dealloc = memref::findDealloc(allocValue);
82     // If the allocation has > 1 dealloc associated with it, skip handling it.
83     if (!dealloc)
84       return;
85     allocs.push_back(std::make_tuple(allocValue, *dealloc));
86   });
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // BufferPlacementTransformationBase
91 //===----------------------------------------------------------------------===//
92 
93 /// Constructs a new transformation base using the given root operation.
94 BufferPlacementTransformationBase::BufferPlacementTransformationBase(
95     Operation *op)
96     : aliases(op), allocs(op), liveness(op) {}
97 
98 /// Returns true if the given operation represents a loop by testing whether it
99 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
100 /// the case of a `RegionBranchOpInterface`, it checks all region-based control-
101 /// flow edges for cycles.
102 bool BufferPlacementTransformationBase::isLoop(Operation *op) {
103   // If the operation implements the `LoopLikeOpInterface` it can be considered
104   // a loop.
105   if (isa<LoopLikeOpInterface>(op))
106     return true;
107 
108   // If the operation does not implement the `RegionBranchOpInterface`, it is
109   // (currently) not possible to detect a loop.
110   RegionBranchOpInterface regionInterface;
111   if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
112     return false;
113 
114   // Recurses into a region using the current region interface to find potential
115   // cycles.
116   SmallPtrSet<Region *, 4> visitedRegions;
117   std::function<bool(Region *)> recurse = [&](Region *current) {
118     if (!current)
119       return false;
120     // If we have found a back edge, the parent operation induces a loop.
121     if (!visitedRegions.insert(current).second)
122       return true;
123     // Recurses into all region successors.
124     SmallVector<RegionSuccessor, 2> successors;
125     regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
126     for (RegionSuccessor &regionEntry : successors)
127       if (recurse(regionEntry.getSuccessor()))
128         return true;
129     return false;
130   };
131 
132   // Start with all entry regions and test whether they induce a loop.
133   SmallVector<RegionSuccessor, 2> successorRegions;
134   regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions);
135   for (RegionSuccessor &regionEntry : successorRegions) {
136     if (recurse(regionEntry.getSuccessor()))
137       return true;
138     visitedRegions.clear();
139   }
140 
141   return false;
142 }
143 
144 //===----------------------------------------------------------------------===//
145 // BufferPlacementTransformationBase
146 //===----------------------------------------------------------------------===//
147 
148 FailureOr<memref::GlobalOp>
149 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
150   auto type = constantOp.getType().cast<RankedTensorType>();
151   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
152   if (!moduleOp)
153     return failure();
154 
155   // If we already have a global for this constant value, no need to do
156   // anything else.
157   for (Operation &op : moduleOp.getRegion().getOps()) {
158     auto globalOp = dyn_cast<memref::GlobalOp>(&op);
159     if (!globalOp)
160       continue;
161     if (!globalOp.initial_value().has_value())
162       continue;
163     uint64_t opAlignment =
164         globalOp.alignment().has_value() ? globalOp.alignment().value() : 0;
165     Attribute initialValue = globalOp.initial_value().value();
166     if (opAlignment == alignment && initialValue == constantOp.getValue())
167       return globalOp;
168   }
169 
170   // Create a builder without an insertion point. We will insert using the
171   // symbol table to guarantee unique names.
172   OpBuilder globalBuilder(moduleOp.getContext());
173   SymbolTable symbolTable(moduleOp);
174 
175   // Create a pretty name.
176   SmallString<64> buf;
177   llvm::raw_svector_ostream os(buf);
178   interleave(type.getShape(), os, "x");
179   os << "x" << type.getElementType();
180 
181   // Add an optional alignment to the global memref.
182   IntegerAttr memrefAlignment =
183       alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
184                     : IntegerAttr();
185 
186   BufferizeTypeConverter typeConverter;
187   auto global = globalBuilder.create<memref::GlobalOp>(
188       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
189       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
190       /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
191       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
192       /*constant=*/true,
193       /*alignment=*/memrefAlignment);
194   symbolTable.insert(global);
195   // The symbol table inserts at the end of the module, but globals are a bit
196   // nicer if they are at the beginning.
197   global->moveBefore(&moduleOp.front());
198   return global;
199 }
200