xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision 74df89f67f17f1e95c249831ce2d9c9d9830e496)
1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Promotion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/LoopOps/LoopOps.h"
19 #include "mlir/EDSC/Helpers.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineExprVisitor.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Support/STLExtras.h"
27 #include "mlir/Transforms/FoldUtils.h"
28 
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/Support/CommandLine.h"
31 
32 using namespace mlir;
33 using namespace mlir::edsc;
34 using namespace mlir::edsc::intrinsics;
35 using namespace mlir::linalg;
36 using namespace mlir::loop;
37 
38 using llvm::SetVector;
39 
40 #define DEBUG_TYPE "linalg-promotion"
41 
42 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
43 static llvm::cl::opt<bool> clPromoteDynamic(
44     "test-linalg-promote-dynamic",
45     llvm::cl::desc("Test generation of dynamic promoted buffers"),
46     llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
47 
48 static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
49   auto *ctx = size.getContext();
50   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
51   if (!dynamicBuffers)
52     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
53       return alloc(
54           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
55   Value mul = muli(constant_index(width), size);
56   return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
57 }
58 
59 // Performs promotion of a `subView` into a local buffer of the size of the
60 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
61 // than the actual size of the `subView` at the boundaries.
62 // This is related to the full/partial tile problem.
63 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
64 // `partialLocalView` such that:
65 //   * `buffer` is always the size of the full tile.
66 //   * `fullLocalView` is a dense contiguous view into that buffer.
67 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
68 //     that corresponds to the size of `subView` and accounting for boundary
69 //     effects.
70 // The point of the full tile buffer is that constant static tile sizes are
71 // folded and result in a buffer type with statically known size and alignment
72 // properties.
73 // To account for general boundary effects, padding must be performed on the
74 // boundary tiles. For now this is done with an unconditional `fill` op followed
75 // by a partial `copy` op.
76 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
77                                            SubViewOp subView,
78                                            bool dynamicBuffers,
79                                            OperationFolder *folder) {
80   auto zero = constant_index(folder, 0);
81   auto one = constant_index(folder, 1);
82 
83   auto viewType = subView.getType();
84   auto rank = viewType.getRank();
85   Value allocSize = one;
86   SmallVector<Value, 8> fullRanges, partialRanges;
87   fullRanges.reserve(rank);
88   partialRanges.reserve(rank);
89   for (auto en : llvm::enumerate(subView.getRanges())) {
90     auto rank = en.index();
91     auto rangeValue = en.value();
92     Value d = rangeValue.size;
93     allocSize = muli(folder, allocSize, d).getValue();
94     fullRanges.push_back(d);
95     partialRanges.push_back(
96         linalg_range(folder, zero, dim(subView, rank), one));
97   }
98   SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
99   auto buffer =
100       allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
101   auto fullLocalView = view(
102       MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
103   auto partialLocalView = linalg_slice(fullLocalView, partialRanges);
104   return PromotionInfo{buffer, fullLocalView, partialLocalView};
105 }
106 
107 SmallVector<PromotionInfo, 8>
108 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
109                               ArrayRef<Value> subViews, bool dynamicBuffers,
110                               OperationFolder *folder) {
111   if (subViews.empty())
112     return {};
113 
114   ScopedContext scope(b, loc);
115   SmallVector<PromotionInfo, 8> res;
116   res.reserve(subViews.size());
117   DenseMap<Value, PromotionInfo> promotionInfoMap;
118   for (auto v : subViews) {
119     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
120     auto viewType = subView.getType();
121     // TODO(ntv): support more cases than just float.
122     if (!viewType.getElementType().isa<FloatType>())
123       continue;
124     auto promotionInfo =
125         promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
126     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
127     res.push_back(promotionInfo);
128   }
129 
130   for (auto v : subViews) {
131     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
132     auto info = promotionInfoMap.find(v);
133     if (info == promotionInfoMap.end())
134       continue;
135     // TODO(ntv): value to fill with should be related to the operation.
136     // For now, just use APFloat(0.0f).
137     auto t = subView.getType().getElementType().cast<FloatType>();
138     Value fillVal = constant_float(folder, APFloat(0.0f), t);
139     // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
140     // view that is different from the partial local view and we are on the
141     // boundary.
142     linalg_fill(info->second.fullLocalView, fillVal);
143   }
144 
145   for (auto v : subViews) {
146     auto info = promotionInfoMap.find(v);
147     if (info == promotionInfoMap.end())
148       continue;
149     linalg_copy(cast<SubViewOp>(v.getDefiningOp()),
150                 info->second.partialLocalView);
151   }
152   return res;
153 }
154 
155 LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
156                                               SetVector<Value> subViews,
157                                               bool dynamicBuffers,
158                                               OperationFolder *folder) {
159   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
160 
161   // 1. Promote the specified views and use them in the new op.
162   ScopedContext scope(b, op.getLoc());
163   auto promotedBufferAndViews = promoteSubViews(
164       b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
165   SmallVector<Value, 8> opViews;
166   opViews.reserve(op.getNumInputsAndOutputs());
167   SmallVector<std::pair<Value, Value>, 8> writebackViews;
168   writebackViews.reserve(subViews.size());
169   unsigned promotedIdx = 0;
170   for (auto view : op.getInputsAndOutputBuffers()) {
171     if (subViews.count(view) != 0) {
172       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
173       writebackViews.emplace_back(std::make_pair(
174           view, promotedBufferAndViews[promotedIdx].partialLocalView));
175       promotedIdx++;
176     } else {
177       opViews.push_back(view);
178     }
179   }
180 
181   // 2. Append all other operands as they appear, this enforces that such
182   // operands are not views. This is to support cases such as FillOp taking
183   // extra scalars etc.
184   auto operands = getAssumedNonViewOperands(op);
185   opViews.append(operands.begin(), operands.end());
186   LinalgOp res = op.clone(b, op.getLoc(), opViews);
187 
188   // 3. Emit write-back for the promoted output views: copy the partial view.
189   for (auto viewAndPartialLocalView : writebackViews) {
190     // WARNING: MUST use the old op to determine whether the operand view is an
191     // output.
192     bool isOutput =
193         op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
194     if (isOutput)
195       linalg_copy(viewAndPartialLocalView.second,
196                   viewAndPartialLocalView.first);
197   }
198 
199   // 4. Dealloc local buffers.
200   for (const auto &pi : promotedBufferAndViews)
201     dealloc(pi.buffer);
202 
203   return res;
204 }
205 
206 static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
207   SmallVector<LinalgOp, 8> toErase;
208   OperationFolder folder(f.getContext());
209   f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
210     if (!op.hasBufferSemantics())
211       return;
212 
213     // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
214     // nothing.
215     SetVector<Value> subViews;
216     OpBuilder b(op);
217     for (auto it : op.getInputsAndOutputBuffers())
218       if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
219         subViews.insert(sv);
220     if (!subViews.empty()) {
221       promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
222       toErase.push_back(op);
223     }
224   });
225   for (auto op : toErase)
226     op.erase();
227 }
228 
229 namespace {
230 struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> {
231   LinalgPromotionPass() = default;
232   LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {}
233 
234   void runOnFunction() override {
235     promoteSubViews(getFunction(), dynamicBuffers);
236   }
237 
238   bool dynamicBuffers;
239 };
240 } // namespace
241 
242 std::unique_ptr<OpPassBase<FuncOp>>
243 mlir::createLinalgPromotionPass(bool dynamicBuffers) {
244   return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
245 }
246 
247 static PassRegistration<LinalgPromotionPass>
248     pass("linalg-promote-subviews", "promote subview ops to local buffers", [] {
249       return std::make_unique<LinalgPromotionPass>(clPromoteDynamic);
250     });
251