xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision 92f1562f3dd158d837c66a1dd20ae745477d9c36)
1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Promotion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/LoopOps/LoopOps.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineExprVisitor.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/FoldUtils.h"
27 
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/CommandLine.h"
30 
31 using namespace mlir;
32 using namespace mlir::edsc;
33 using namespace mlir::edsc::intrinsics;
34 using namespace mlir::linalg;
35 using namespace mlir::loop;
36 
37 using llvm::SetVector;
38 
39 using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
40 using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
41 using folded_std_dim = folded::ValueBuilder<DimOp>;
42 using folded_std_subview = folded::ValueBuilder<SubViewOp>;
43 using folded_std_view = folded::ValueBuilder<ViewOp>;
44 
45 #define DEBUG_TYPE "linalg-promotion"
46 
47 /// If `size` comes from an AffineMinOp and one of the dimensions of AffineMin
48 /// is a constant then return a new value set to the smallest such constant.
49 /// Otherwise return size.
50 static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
51                                                  Value size) {
52   auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
53   if (!affineMinOp)
54     return size;
55   if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) {
56         return e.dyn_cast<AffineConstantExpr>();
57       }))
58     return size;
59   int64_t minConst = std::numeric_limits<int64_t>::max();
60   for (auto e : affineMinOp.getAffineMap().getResults())
61     if (auto cst = e.dyn_cast<AffineConstantExpr>())
62       minConst = std::min(minConst, cst.getValue());
63   assert(minConst != std::numeric_limits<int64_t>::max());
64   return b.create<ConstantIndexOp>(loc, minConst);
65 }
66 
67 static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
68                          OperationFolder *folder) {
69   auto *ctx = size.getContext();
70   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
71   if (!dynamicBuffers)
72     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
73       return std_alloc(
74           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
75   Value mul =
76       folded_std_muli(folder, folded_std_constant_index(folder, width), size);
77   return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
78 }
79 
80 // Performs promotion of a `subView` into a local buffer of the size of the
81 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
82 // than the actual size of the `subView` at the boundaries.
83 // This is related to the full/partial tile problem.
84 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
85 // `partialLocalView` such that:
86 //   * `buffer` is always the size of the full tile.
87 //   * `fullLocalView` is a dense contiguous view into that buffer.
88 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
89 //     that corresponds to the size of `subView` and accounting for boundary
90 //     effects.
91 // The point of the full tile buffer is that constant static tile sizes are
92 // folded and result in a buffer type with statically known size and alignment
93 // properties.
94 // To account for general boundary effects, padding must be performed on the
95 // boundary tiles. For now this is done with an unconditional `fill` op followed
96 // by a partial `copy` op.
97 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
98                                            SubViewOp subView,
99                                            bool dynamicBuffers,
100                                            OperationFolder *folder) {
101   auto zero = folded_std_constant_index(folder, 0);
102   auto one = folded_std_constant_index(folder, 1);
103 
104   auto viewType = subView.getType();
105   auto rank = viewType.getRank();
106   Value allocSize = one;
107   SmallVector<Value, 8> fullSizes, partialSizes;
108   fullSizes.reserve(rank);
109   partialSizes.reserve(rank);
110   for (auto en : llvm::enumerate(subView.getRanges())) {
111     auto rank = en.index();
112     auto rangeValue = en.value();
113     // Try to extract a tight constant
114     Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
115     allocSize = folded_std_muli(folder, allocSize, size).getValue();
116     fullSizes.push_back(size);
117     partialSizes.push_back(folded_std_dim(folder, subView, rank));
118   }
119   SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
120   auto buffer =
121       allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers, folder);
122   auto fullLocalView = folded_std_view(
123       folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
124       fullSizes);
125   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
126   SmallVector<Value, 4> ones(fullSizes.size(), one);
127   auto partialLocalView =
128       folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
129   return PromotionInfo{buffer, fullLocalView, partialLocalView};
130 }
131 
132 SmallVector<PromotionInfo, 8>
133 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
134                               ArrayRef<Value> subViews, bool dynamicBuffers,
135                               OperationFolder *folder) {
136   if (subViews.empty())
137     return {};
138 
139   ScopedContext scope(b, loc);
140   SmallVector<PromotionInfo, 8> res;
141   res.reserve(subViews.size());
142   DenseMap<Value, PromotionInfo> promotionInfoMap;
143   for (auto v : subViews) {
144     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
145     auto promotionInfo =
146         promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
147     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
148     res.push_back(promotionInfo);
149   }
150 
151   for (auto v : subViews) {
152     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
153     auto info = promotionInfoMap.find(v);
154     if (info == promotionInfoMap.end())
155       continue;
156     Value fillVal;
157     if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
158       fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
159     else if (auto t =
160                  subView.getType().getElementType().dyn_cast<IntegerType>())
161       fillVal = folded_std_constant_int(folder, 0, t);
162     // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
163     // view that is different from the partial local view and we are on the
164     // boundary.
165     linalg_fill(info->second.fullLocalView, fillVal);
166   }
167 
168   for (auto v : subViews) {
169     auto info = promotionInfoMap.find(v);
170     if (info == promotionInfoMap.end())
171       continue;
172     linalg_copy(cast<SubViewOp>(v.getDefiningOp()),
173                 info->second.partialLocalView);
174   }
175   return res;
176 }
177 
178 LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
179                                               SetVector<Value> subViews,
180                                               bool dynamicBuffers,
181                                               OperationFolder *folder) {
182   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
183 
184   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
185     // TODO(ntv): add a level of indirection to linalg.generic.
186     if (convOp.padding())
187       llvm_unreachable("Unexpected conv with padding");
188   }
189 
190   // 1. Promote the specified views and use them in the new op.
191   ScopedContext scope(b, op.getLoc());
192   auto promotedBufferAndViews = promoteSubViews(
193       b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
194   SmallVector<Value, 8> opViews;
195   opViews.reserve(op.getNumInputsAndOutputs());
196   SmallVector<std::pair<Value, Value>, 8> writebackViews;
197   writebackViews.reserve(subViews.size());
198   unsigned promotedIdx = 0;
199   for (auto view : op.getInputsAndOutputBuffers()) {
200     if (subViews.count(view) != 0) {
201       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
202       writebackViews.emplace_back(std::make_pair(
203           view, promotedBufferAndViews[promotedIdx].partialLocalView));
204       promotedIdx++;
205     } else {
206       opViews.push_back(view);
207     }
208   }
209 
210   // 2. Append all other operands as they appear, this enforces that such
211   // operands are not views. This is to support cases such as FillOp taking
212   // extra scalars etc.
213   auto operands = getAssumedNonViewOperands(op);
214   opViews.append(operands.begin(), operands.end());
215   LinalgOp res = op.clone(b, op.getLoc(), opViews);
216 
217   // 3. Emit write-back for the promoted output views: copy the partial view.
218   for (auto viewAndPartialLocalView : writebackViews) {
219     // WARNING: MUST use the old op to determine whether the operand view is an
220     // output.
221     bool isOutput =
222         op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
223     if (isOutput)
224       linalg_copy(viewAndPartialLocalView.second,
225                   viewAndPartialLocalView.first);
226   }
227 
228   // 4. Dealloc local buffers.
229   for (const auto &pi : promotedBufferAndViews)
230     std_dealloc(pi.buffer);
231 
232   return res;
233 }
234 
235 static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
236   SmallVector<LinalgOp, 8> toErase;
237   OperationFolder folder(f.getContext());
238   f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
239     if (!op.hasBufferSemantics())
240       return;
241 
242     // TODO(ntv) some heuristic here to decide what to promote. Atm only float
243     // and integer buffers can be promoted.
244     SetVector<Value> subViews;
245     OpBuilder b(op);
246     for (auto it : op.getInputsAndOutputBuffers())
247       if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
248         if (sv.getType().getElementType().isSignlessIntOrFloat())
249           subViews.insert(sv);
250     if (!subViews.empty()) {
251       promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
252       toErase.push_back(op);
253     }
254   });
255   for (auto op : toErase)
256     op.erase();
257 }
258 
259 namespace {
260 struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
261   LinalgPromotionPass() = default;
262   LinalgPromotionPass(bool dynamicBuffers) {
263     this->dynamicBuffers = dynamicBuffers;
264   }
265 
266   void runOnFunction() override {
267     promoteSubViews(getFunction(), dynamicBuffers);
268   }
269 };
270 } // namespace
271 
/// Creates a LinalgPromotionPass whose `dynamicBuffers` member is explicitly
/// set to `dynamicBuffers`.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgPromotionPass(bool dynamicBuffers) {
  return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
}
/// Creates a default-constructed LinalgPromotionPass. `dynamicBuffers` keeps
/// whatever default LinalgPromotionBase gives it (presumably the option's
/// declared default; confirm in the pass definition).
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgPromotionPass() {
  return std::make_unique<LinalgPromotionPass>();
}
279