xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision 92f7e8133ae98e1f300bad164c4099b2e609bae7)
1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Promotion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/AffineOps/EDSC/Intrinsics.h"
14 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/LoopOps/LoopOps.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineExprVisitor.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Support/STLExtras.h"
28 #include "mlir/Transforms/FoldUtils.h"
29 
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/CommandLine.h"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 using namespace mlir::loop;
38 
39 using llvm::SetVector;
40 
41 using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
42 using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
43 
44 #define DEBUG_TYPE "linalg-promotion"
45 
46 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
47 static llvm::cl::opt<bool> clPromoteDynamic(
48     "test-linalg-promote-dynamic",
49     llvm::cl::desc("Test generation of dynamic promoted buffers"),
50     llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
51 
52 static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
53   auto *ctx = size.getContext();
54   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
55   if (!dynamicBuffers)
56     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
57       return std_alloc(
58           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
59   Value mul = std_muli(std_constant_index(width), size);
60   return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
61 }
62 
63 // Performs promotion of a `subView` into a local buffer of the size of the
64 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
65 // than the actual size of the `subView` at the boundaries.
66 // This is related to the full/partial tile problem.
67 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
68 // `partialLocalView` such that:
69 //   * `buffer` is always the size of the full tile.
70 //   * `fullLocalView` is a dense contiguous view into that buffer.
71 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
72 //     that corresponds to the size of `subView` and accounting for boundary
73 //     effects.
74 // The point of the full tile buffer is that constant static tile sizes are
75 // folded and result in a buffer type with statically known size and alignment
76 // properties.
77 // To account for general boundary effects, padding must be performed on the
78 // boundary tiles. For now this is done with an unconditional `fill` op followed
79 // by a partial `copy` op.
80 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
81                                            SubViewOp subView,
82                                            bool dynamicBuffers,
83                                            OperationFolder *folder) {
84   auto zero = folded_std_constant_index(folder, 0);
85   auto one = folded_std_constant_index(folder, 1);
86 
87   auto viewType = subView.getType();
88   auto rank = viewType.getRank();
89   Value allocSize = one;
90   SmallVector<Value, 8> fullRanges, partialRanges;
91   fullRanges.reserve(rank);
92   partialRanges.reserve(rank);
93   for (auto en : llvm::enumerate(subView.getRanges())) {
94     auto rank = en.index();
95     auto rangeValue = en.value();
96     Value d = rangeValue.size;
97     allocSize = folded_std_muli(folder, allocSize, d).getValue();
98     fullRanges.push_back(d);
99     partialRanges.push_back(
100         folded_linalg_range(folder, zero, std_dim(subView, rank), one));
101   }
102   SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
103   auto buffer =
104       allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
105   auto fullLocalView = std_view(
106       MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
107   auto partialLocalView = linalg_slice(fullLocalView, partialRanges);
108   return PromotionInfo{buffer, fullLocalView, partialLocalView};
109 }
110 
111 SmallVector<PromotionInfo, 8>
112 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
113                               ArrayRef<Value> subViews, bool dynamicBuffers,
114                               OperationFolder *folder) {
115   if (subViews.empty())
116     return {};
117 
118   ScopedContext scope(b, loc);
119   SmallVector<PromotionInfo, 8> res;
120   res.reserve(subViews.size());
121   DenseMap<Value, PromotionInfo> promotionInfoMap;
122   for (auto v : subViews) {
123     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
124     auto promotionInfo =
125         promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
126     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
127     res.push_back(promotionInfo);
128   }
129 
130   for (auto v : subViews) {
131     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
132     auto info = promotionInfoMap.find(v);
133     if (info == promotionInfoMap.end())
134       continue;
135     Value fillVal;
136     if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
137       fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
138     else if (auto t =
139                  subView.getType().getElementType().dyn_cast<IntegerType>())
140       fillVal = folded_std_constant_int(folder, 0, t);
141     // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
142     // view that is different from the partial local view and we are on the
143     // boundary.
144     linalg_fill(info->second.fullLocalView, fillVal);
145   }
146 
147   for (auto v : subViews) {
148     auto info = promotionInfoMap.find(v);
149     if (info == promotionInfoMap.end())
150       continue;
151     linalg_copy(cast<SubViewOp>(v.getDefiningOp()),
152                 info->second.partialLocalView);
153   }
154   return res;
155 }
156 
157 LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
158                                               SetVector<Value> subViews,
159                                               bool dynamicBuffers,
160                                               OperationFolder *folder) {
161   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
162 
163   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
164     // TODO(ntv): add a level of indirection to linalg.generic.
165     if (convOp.padding())
166       llvm_unreachable("Unexpected conv with padding");
167   }
168 
169   // 1. Promote the specified views and use them in the new op.
170   ScopedContext scope(b, op.getLoc());
171   auto promotedBufferAndViews = promoteSubViews(
172       b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
173   SmallVector<Value, 8> opViews;
174   opViews.reserve(op.getNumInputsAndOutputs());
175   SmallVector<std::pair<Value, Value>, 8> writebackViews;
176   writebackViews.reserve(subViews.size());
177   unsigned promotedIdx = 0;
178   for (auto view : op.getInputsAndOutputBuffers()) {
179     if (subViews.count(view) != 0) {
180       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
181       writebackViews.emplace_back(std::make_pair(
182           view, promotedBufferAndViews[promotedIdx].partialLocalView));
183       promotedIdx++;
184     } else {
185       opViews.push_back(view);
186     }
187   }
188 
189   // 2. Append all other operands as they appear, this enforces that such
190   // operands are not views. This is to support cases such as FillOp taking
191   // extra scalars etc.
192   auto operands = getAssumedNonViewOperands(op);
193   opViews.append(operands.begin(), operands.end());
194   LinalgOp res = op.clone(b, op.getLoc(), opViews);
195 
196   // 3. Emit write-back for the promoted output views: copy the partial view.
197   for (auto viewAndPartialLocalView : writebackViews) {
198     // WARNING: MUST use the old op to determine whether the operand view is an
199     // output.
200     bool isOutput =
201         op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
202     if (isOutput)
203       linalg_copy(viewAndPartialLocalView.second,
204                   viewAndPartialLocalView.first);
205   }
206 
207   // 4. Dealloc local buffers.
208   for (const auto &pi : promotedBufferAndViews)
209     std_dealloc(pi.buffer);
210 
211   return res;
212 }
213 
214 static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
215   SmallVector<LinalgOp, 8> toErase;
216   OperationFolder folder(f.getContext());
217   f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
218     if (!op.hasBufferSemantics())
219       return;
220 
221     // TODO(ntv) some heuristic here to decide what to promote. Atm only float
222     // and integer buffers can be promoted.
223     SetVector<Value> subViews;
224     OpBuilder b(op);
225     for (auto it : op.getInputsAndOutputBuffers())
226       if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
227         if (sv.getType().getElementType().isSignlessIntOrFloat())
228           subViews.insert(sv);
229     if (!subViews.empty()) {
230       promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
231       toErase.push_back(op);
232     }
233   });
234   for (auto op : toErase)
235     op.erase();
236 }
237 
238 namespace {
239 struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> {
240   LinalgPromotionPass() = default;
241   LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {}
242 
243   void runOnFunction() override {
244     promoteSubViews(getFunction(), dynamicBuffers);
245   }
246 
247   bool dynamicBuffers;
248 };
249 } // namespace
250 
251 std::unique_ptr<OpPassBase<FuncOp>>
252 mlir::createLinalgPromotionPass(bool dynamicBuffers) {
253   return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
254 }
255 
256 static PassRegistration<LinalgPromotionPass>
257     pass("linalg-promote-subviews", "promote subview ops to local buffers", [] {
258       return std::make_unique<LinalgPromotionPass>(clPromoteDynamic);
259     });
260