xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision 56222a0694e4caf35e892d70591417c39fef1185)
1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Promotion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/LoopOps/LoopOps.h"
19 #include "mlir/EDSC/Helpers.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineExprVisitor.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Support/STLExtras.h"
27 #include "mlir/Transforms/FoldUtils.h"
28 
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/Support/CommandLine.h"
31 
32 using namespace mlir;
33 using namespace mlir::edsc;
34 using namespace mlir::edsc::intrinsics;
35 using namespace mlir::linalg;
36 using namespace mlir::linalg::intrinsics;
37 using namespace mlir::loop;
38 
39 using llvm::SetVector;
40 
41 #define DEBUG_TYPE "linalg-promotion"
42 
43 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
44 static llvm::cl::opt<bool> clPromoteDynamic(
45     "test-linalg-promote-dynamic",
46     llvm::cl::desc("Test generation of dynamic promoted buffers"),
47     llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
48 
49 static ValuePtr allocBuffer(Type elementType, ValuePtr size,
50                             bool dynamicBuffers) {
51   auto *ctx = size->getContext();
52   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
53   if (!dynamicBuffers)
54     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
55       return alloc(
56           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
57   ValuePtr mul = muli(constant_index(width), size);
58   return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
59 }
60 
61 // Performs promotion of a `subView` into a local buffer of the size of the
62 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
63 // than the actual size of the `subView` at the boundaries.
64 // This is related to the full/partial tile problem.
65 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
66 // `partialLocalView` such that:
67 //   * `buffer` is always the size of the full tile.
68 //   * `fullLocalView` is a dense contiguous view into that buffer.
69 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
70 //     that corresponds to the size of `subView` and accounting for boundary
71 //     effects.
72 // The point of the full tile buffer is that constant static tile sizes are
73 // folded and result in a buffer type with statically known size and alignment
74 // properties.
75 // To account for general boundary effects, padding must be performed on the
76 // boundary tiles. For now this is done with an unconditional `fill` op followed
77 // by a partial `copy` op.
78 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
79                                            SubViewOp subView,
80                                            bool dynamicBuffers,
81                                            OperationFolder *folder) {
82   auto zero = constant_index(folder, 0);
83   auto one = constant_index(folder, 1);
84 
85   auto viewType = subView.getType();
86   auto rank = viewType.getRank();
87   ValuePtr allocSize = one;
88   SmallVector<ValuePtr, 8> fullRanges, partialRanges;
89   fullRanges.reserve(rank);
90   partialRanges.reserve(rank);
91   for (auto en : llvm::enumerate(subView.getRanges())) {
92     auto rank = en.index();
93     auto rangeValue = en.value();
94     ValuePtr d = rangeValue.size;
95     allocSize = muli(folder, allocSize, d).getValue();
96     fullRanges.push_back(d);
97     partialRanges.push_back(range(folder, zero, dim(subView, rank), one));
98   }
99   SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
100   auto buffer =
101       allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
102   auto fullLocalView = view(
103       MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
104   auto partialLocalView = slice(fullLocalView, partialRanges);
105   return PromotionInfo{buffer, fullLocalView, partialLocalView};
106 }
107 
108 SmallVector<PromotionInfo, 8>
109 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
110                               ArrayRef<ValuePtr> subViews, bool dynamicBuffers,
111                               OperationFolder *folder) {
112   if (subViews.empty())
113     return {};
114 
115   ScopedContext scope(b, loc);
116   SmallVector<PromotionInfo, 8> res;
117   res.reserve(subViews.size());
118   DenseMap<ValuePtr, PromotionInfo> promotionInfoMap;
119   for (auto v : subViews) {
120     SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
121     auto viewType = subView.getType();
122     // TODO(ntv): support more cases than just float.
123     if (!viewType.getElementType().isa<FloatType>())
124       continue;
125     auto promotionInfo =
126         promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
127     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
128     res.push_back(promotionInfo);
129   }
130 
131   for (auto v : subViews) {
132     SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
133     auto info = promotionInfoMap.find(v);
134     if (info == promotionInfoMap.end())
135       continue;
136     // TODO(ntv): value to fill with should be related to the operation.
137     // For now, just use APFloat(0.0f).
138     auto t = subView.getType().getElementType().cast<FloatType>();
139     ValuePtr fillVal = constant_float(folder, APFloat(0.0f), t);
140     // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
141     // view that is different from the partial local view and we are on the
142     // boundary.
143     fill(info->second.fullLocalView, fillVal);
144   }
145 
146   for (auto v : subViews) {
147     auto info = promotionInfoMap.find(v);
148     if (info == promotionInfoMap.end())
149       continue;
150     copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView);
151   }
152   return res;
153 }
154 
155 LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
156                                               SetVector<ValuePtr> subViews,
157                                               bool dynamicBuffers,
158                                               OperationFolder *folder) {
159   // 1. Promote the specified views and use them in the new op.
160   ScopedContext scope(b, op.getLoc());
161   auto promotedBufferAndViews = promoteSubViews(
162       b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
163   SmallVector<ValuePtr, 8> opViews;
164   opViews.reserve(op.getNumInputsAndOutputs());
165   SmallVector<std::pair<ValuePtr, ValuePtr>, 8> writebackViews;
166   writebackViews.reserve(subViews.size());
167   unsigned promotedIdx = 0;
168   for (auto view : op.getInputsAndOutputs()) {
169     if (subViews.count(view) != 0) {
170       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
171       writebackViews.emplace_back(std::make_pair(
172           view, promotedBufferAndViews[promotedIdx].partialLocalView));
173       promotedIdx++;
174     } else {
175       opViews.push_back(view);
176     }
177   }
178 
179   // 2. Append all other operands as they appear, this enforces that such
180   // operands are not views. This is to support cases such as FillOp taking
181   // extra scalars etc.
182   auto operands = getAssumedNonViewOperands(op);
183   opViews.append(operands.begin(), operands.end());
184   LinalgOp res = op.clone(b, op.getLoc(), opViews);
185 
186   // 3. Emit write-back for the promoted output views: copy the partial view.
187   for (auto viewAndPartialLocalView : writebackViews) {
188     // WARNING: MUST use the old op to determine whether the operand view is an
189     // output.
190     bool isOutput =
191         op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
192     if (isOutput)
193       copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first);
194   }
195 
196   // 4. Dealloc local buffers.
197   for (const auto &pi : promotedBufferAndViews)
198     dealloc(pi.buffer);
199 
200   return res;
201 }
202 
203 static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
204   SmallVector<LinalgOp, 8> toErase;
205   OperationFolder folder(f.getContext());
206   f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
207     // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
208     // nothing.
209     SetVector<ValuePtr> subViews;
210     OpBuilder b(op);
211     for (auto it : op.getInputsAndOutputs())
212       if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
213         subViews.insert(sv);
214     if (!subViews.empty()) {
215       promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
216       toErase.push_back(op);
217     }
218   });
219   for (auto op : toErase)
220     op.erase();
221 }
222 
223 namespace {
224 struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> {
225   LinalgPromotionPass() = default;
226   LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {}
227 
228   void runOnFunction() override {
229     promoteSubViews(getFunction(), dynamicBuffers);
230   }
231 
232   bool dynamicBuffers;
233 };
234 } // namespace
235 
236 std::unique_ptr<OpPassBase<FuncOp>>
237 mlir::linalg::createLinalgPromotionPass(bool dynamicBuffers) {
238   return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
239 }
240 
241 static PassRegistration<LinalgPromotionPass>
242     pass("linalg-promote-subviews", "promote subview ops to local buffers", [] {
243       return std::make_unique<LinalgPromotionPass>(clPromoteDynamic);
244     });
245