//===- IndependenceTransforms.cpp - Make ops independent of values --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::memref;

/// Make the given OpFoldResult independent of all independencies.
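///
/// Illustrative sketch (names and the resulting bound are examples only): if
/// `ofr` is the result of
///   %sz = affine.min affine_map<(d0)[s0] -> (s0 - d0, 4)>(%i)[%ub]
/// and `%i` is one of the independencies, the analysis may materialize a
/// constant closed upper bound such as `%c4 = arith.constant 4 : index`, which
/// no longer depends on `%i`. The exact bound depends on what the constraint
/// system can prove.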
static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                               OpFoldResult ofr,
                                               ValueRange independencies) {
  if (isa<Attribute>(ofr))
    return ofr;
  AffineMap boundMap;
  ValueDimList mapOperands;
  if (failed(ValueBoundsConstraintSet::computeIndependentBound(
          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
          /*closedUB=*/true)))
    return failure();
  return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}

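/// Rough before/after sketch of buildIndependentOp (the bound value of 4 is an
/// assumption for illustration; the actual size is whatever bound the analysis
/// proves):
///
///   %alloca = memref.alloca(%sz) : memref<?xf32>
///
/// may become an allocation of the independent upper-bound size plus a subview
/// of the original size:
///
///   %new = memref.alloca() : memref<4xf32>
///   %subview = memref.subview %new[0] [%sz] [1]
///       : memref<4xf32> to memref<?xf32, strided<[1]>>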
FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
                                            memref::AllocaOp allocaOp,
                                            ValueRange independencies) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(allocaOp);
  Location loc = allocaOp.getLoc();

  SmallVector<OpFoldResult> newSizes;
  for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
    auto ub = makeIndependent(b, loc, ofr, independencies);
    if (failed(ub))
      return failure();
    newSizes.push_back(*ub);
  }

  // Return existing memref::AllocaOp if nothing has changed.
  if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
    return allocaOp.getResult();

  // Create a new memref::AllocaOp.
  Value newAllocaOp =
      b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());

  // Create a memref::SubViewOp.
  SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
  return b
      .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(),
                         strides)
      .getResult();
}

/// Push down an UnrealizedConversionCastOp past a SubViewOp.
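///
/// Sketch (types abbreviated): a cast that feeds a subview, e.g.
///
///   %cast = builtin.unrealized_conversion_cast %new : ... to memref<?xf32>
///   %sv = memref.subview %cast ...
///
/// is rewritten so that the subview is recomputed on %new directly and a new
/// cast is created on the subview result instead, i.e., the cast moves below
/// the subview.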
static UnrealizedConversionCastOp
propagateSubViewOp(RewriterBase &rewriter,
                   UnrealizedConversionCastOp conversionOp, SubViewOp op) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
      op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
      op.getMixedSizes(), op.getMixedStrides()));
  Value newSubview = rewriter.create<SubViewOp>(
      op.getLoc(), newResultType, conversionOp.getOperand(0),
      op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
  auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
      op.getLoc(), op.getType(), newSubview);
  rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
  return newConversionOp;
}

/// Given an original op and a new, modified op with the same number of results,
/// whose memref return types may differ, replace all uses of the original op
/// with the new op and propagate the new memref types through the IR.
///
/// Example:
/// %from = memref.alloca(%sz) : memref<?xf32>
/// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
/// memref.store %cst, %from[%c0] : memref<?xf32>
///
/// In the above example, all uses of %from are replaced with %to. This can be
/// done directly for ops such as memref.store. For ops that have memref results
/// (e.g., memref.subview), the result type may depend on the operand type, so
/// we cannot just replace all uses. There is special handling for common memref
/// ops. For all other ops, unrealized_conversion_cast is inserted.
static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
                                          Operation *from, Operation *to) {
  assert(from->getNumResults() == to->getNumResults() &&
         "expected same number of results");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(to);

  // Wrap new results in unrealized_conversion_cast and replace all uses of the
  // original op.
  SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
  for (const auto &it :
       llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
    unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
        to->getLoc(), std::get<0>(it.value()).getType(),
        std::get<1>(it.value())));
    rewriter.replaceAllUsesWith(from->getResult(it.index()),
                                unrealizedConversions.back()->getResult(0));
  }

  // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
  // wrap results instead of operands in a cast.
  for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
    UnrealizedConversionCastOp conversion = unrealizedConversions[i];
    assert(conversion->getNumOperands() == 1 &&
           conversion->getNumResults() == 1 &&
           "expected single operand and single result");
    SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers());
    for (Operation *user : users) {
      // Handle common memref dialect ops that produce new memrefs and must
      // be recreated with the new result type.
      if (auto subviewOp = dyn_cast<SubViewOp>(user)) {
        unrealizedConversions.push_back(
            propagateSubViewOp(rewriter, conversion, subviewOp));
        continue;
      }

      // TODO: Other memref ops such as memref.collapse_shape/expand_shape
      // should also be handled here.

      // Skip any ops that produce a MemRef result or have MemRef region block
      // arguments. These may need special handling (e.g., scf.for).
      if (llvm::any_of(user->getResultTypes(),
                       [](Type t) { return isa<MemRefType>(t); }))
        continue;
      if (llvm::any_of(user->getRegions(), [](Region &r) {
            return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
              return isa<MemRefType>(bbArg.getType());
            });
          }))
        continue;

      // For all other ops, we assume that we can directly replace the operand.
      // This may have to be revised in the future; e.g., there may be ops that
      // do not support non-identity layout maps.
      for (OpOperand &operand : user->getOpOperands()) {
        if ([[maybe_unused]] auto castOp =
                operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
          rewriter.modifyOpInPlace(
              user, [&]() { operand.set(conversion->getOperand(0)); });
        }
      }
    }
  }

  // Erase all unrealized_conversion_cast ops without uses.
  for (auto op : unrealizedConversions)
    if (op->getUses().empty())
      rewriter.eraseOp(op);
}

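/// Illustrative end-to-end sketch (assuming a constant bound of 4 is proven
/// for %sz independently of `independencies`):
///
///   %a = memref.alloca(%sz) : memref<?xf32>
///   memref.store %cst, %a[%idx] : memref<?xf32>
///
/// may become
///
///   %new = memref.alloca() : memref<4xf32>
///   %v = memref.subview %new[0] [%sz] [1]
///       : memref<4xf32> to memref<?xf32, strided<[1]>>
///   memref.store %cst, %v[%idx] : memref<?xf32, strided<[1]>>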
FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
                                                  memref::AllocaOp allocaOp,
                                                  ValueRange independencies) {
  auto replacement =
      memref::buildIndependentOp(rewriter, allocaOp, independencies);
  if (failed(replacement))
    return failure();
  replaceAndPropagateMemRefType(rewriter, allocaOp,
                                replacement->getDefiningOp());
  return replacement;
}

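/// Sketch of allocToAlloca: an alloc whose matching dealloc is found later in
/// the same block (and accepted by the optional filter), e.g.
///
///   %0 = memref.alloc() : memref<16xf32>
///   ...
///   memref.dealloc %0 : memref<16xf32>
///
/// is rewritten to a stack allocation, with the dealloc erased:
///
///   %0 = memref.alloca() : memref<16xf32>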
memref::AllocaOp memref::allocToAlloca(
    RewriterBase &rewriter, memref::AllocOp alloc,
    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
  memref::DeallocOp dealloc = nullptr;
  for (Operation &candidate :
       llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
    dealloc = dyn_cast<memref::DeallocOp>(candidate);
    if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
        (!filter || filter(alloc, dealloc))) {
      break;
    }
  }

  if (!dealloc)
    return nullptr;

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(alloc);
  auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
      alloc, alloc.getMemref().getType(), alloc.getOperands());
  rewriter.eraseOp(dealloc);
  return alloca;
}