xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision 367229e100eca714276253bf95a0dd3d084a9624)
1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Promotion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/LoopOps/LoopOps.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineExprVisitor.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/FoldUtils.h"
27 
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/CommandLine.h"
30 
31 using namespace mlir;
32 using namespace mlir::edsc;
33 using namespace mlir::edsc::intrinsics;
34 using namespace mlir::linalg;
35 using namespace mlir::loop;
36 
37 using llvm::SetVector;
38 
39 using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
40 using folded_linalg_range = FoldedValueBuilder<linalg::RangeOp>;
41 using folded_std_dim = FoldedValueBuilder<DimOp>;
42 using folded_std_subview = FoldedValueBuilder<SubViewOp>;
43 using folded_std_view = FoldedValueBuilder<ViewOp>;
44 
45 #define DEBUG_TYPE "linalg-promotion"
46 
47 /// If `size` comes from an AffineMinOp and one of the dimensions of AffineMin
48 /// is a constant then return a new value set to the smallest such constant.
49 /// Otherwise return size.
50 static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
51                                                  Value size) {
52   auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
53   if (!affineMinOp)
54     return size;
55   if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) {
56         return e.dyn_cast<AffineConstantExpr>();
57       }))
58     return size;
59   int64_t minConst = std::numeric_limits<int64_t>::max();
60   for (auto e : affineMinOp.getAffineMap().getResults())
61     if (auto cst = e.dyn_cast<AffineConstantExpr>())
62       minConst = std::min(minConst, cst.getValue());
63   assert(minConst != std::numeric_limits<int64_t>::max());
64   return b.create<ConstantIndexOp>(loc, minConst);
65 }
66 
67 static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
68                          OperationFolder *folder, int64_t alignment = 0) {
69   auto *ctx = size.getContext();
70   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
71   IntegerAttr alignment_attr;
72   if (alignment)
73     alignment_attr = IntegerAttr::get(IntegerType::get(64, ctx), alignment);
74   if (!dynamicBuffers)
75     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
76       return std_alloc(
77           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
78           ValueRange{}, alignment_attr);
79   Value mul =
80       folded_std_muli(folder, folded_std_constant_index(folder, width), size);
81   return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
82                    alignment_attr);
83 }
84 
85 // Performs promotion of a `subView` into a local buffer of the size of the
86 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
87 // than the actual size of the `subView` at the boundaries.
88 // This is related to the full/partial tile problem.
89 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
90 // `partialLocalView` such that:
91 //   * `buffer` is always the size of the full tile.
92 //   * `fullLocalView` is a dense contiguous view into that buffer.
93 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
94 //     that corresponds to the size of `subView` and accounting for boundary
95 //     effects.
96 // The point of the full tile buffer is that constant static tile sizes are
97 // folded and result in a buffer type with statically known size and alignment
98 // properties.
99 // To account for general boundary effects, padding must be performed on the
100 // boundary tiles. For now this is done with an unconditional `fill` op followed
101 // by a partial `copy` op.
102 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
103                                            SubViewOp subView,
104                                            bool dynamicBuffers,
105                                            int64_t alignment,
106                                            OperationFolder *folder) {
107   auto zero = folded_std_constant_index(folder, 0);
108   auto one = folded_std_constant_index(folder, 1);
109 
110   auto viewType = subView.getType();
111   auto rank = viewType.getRank();
112   Value allocSize = one;
113   SmallVector<Value, 8> fullSizes, partialSizes;
114   fullSizes.reserve(rank);
115   partialSizes.reserve(rank);
116   for (auto en : llvm::enumerate(subView.getRanges())) {
117     auto rank = en.index();
118     auto rangeValue = en.value();
119     // Try to extract a tight constant
120     Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
121     allocSize = folded_std_muli(folder, allocSize, size);
122     fullSizes.push_back(size);
123     partialSizes.push_back(folded_std_dim(folder, subView, rank));
124   }
125   SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
126   auto buffer = allocBuffer(viewType.getElementType(), allocSize,
127                             dynamicBuffers, folder, alignment);
128   auto fullLocalView = folded_std_view(
129       folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
130       fullSizes);
131   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
132   SmallVector<Value, 4> ones(fullSizes.size(), one);
133   auto partialLocalView =
134       folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
135   return PromotionInfo{buffer, fullLocalView, partialLocalView};
136 }
137 
138 SmallVector<PromotionInfo, 8>
139 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
140                               ArrayRef<Value> subViews, bool dynamicBuffers,
141                               int64_t alignment, OperationFolder *folder) {
142   if (subViews.empty())
143     return {};
144 
145   ScopedContext scope(b, loc);
146   SmallVector<PromotionInfo, 8> res;
147   res.reserve(subViews.size());
148   DenseMap<Value, PromotionInfo> promotionInfoMap;
149   for (auto v : subViews) {
150     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
151     auto promotionInfo = promoteFullTileBuffer(b, loc, subView, dynamicBuffers,
152                                                alignment, folder);
153     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
154     res.push_back(promotionInfo);
155   }
156 
157   for (auto v : subViews) {
158     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
159     auto info = promotionInfoMap.find(v);
160     if (info == promotionInfoMap.end())
161       continue;
162     Value fillVal;
163     if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
164       fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
165     else if (auto t =
166                  subView.getType().getElementType().dyn_cast<IntegerType>())
167       fillVal = folded_std_constant_int(folder, 0, t);
168     // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
169     // view that is different from the partial local view and we are on the
170     // boundary.
171     linalg_fill(info->second.fullLocalView, fillVal);
172   }
173 
174   for (auto v : subViews) {
175     auto info = promotionInfoMap.find(v);
176     if (info == promotionInfoMap.end())
177       continue;
178     linalg_copy(cast<SubViewOp>(v.getDefiningOp()),
179                 info->second.partialLocalView);
180   }
181   return res;
182 }
183 
184 LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
185                                               SetVector<Value> subViews,
186                                               bool dynamicBuffers,
187                                               int64_t alignment,
188                                               OperationFolder *folder) {
189   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
190 
191   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
192     // TODO(ntv): add a level of indirection to linalg.generic.
193     if (convOp.padding())
194       llvm_unreachable("Unexpected conv with padding");
195   }
196 
197   // 1. Promote the specified views and use them in the new op.
198   ScopedContext scope(b, op.getLoc());
199   auto promotedBufferAndViews =
200       promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers,
201                       alignment, folder);
202   SmallVector<Value, 8> opViews;
203   opViews.reserve(op.getNumInputsAndOutputs());
204   SmallVector<std::pair<Value, Value>, 8> writebackViews;
205   writebackViews.reserve(subViews.size());
206   unsigned promotedIdx = 0;
207   for (auto view : op.getInputsAndOutputBuffers()) {
208     if (subViews.count(view) != 0) {
209       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
210       writebackViews.emplace_back(std::make_pair(
211           view, promotedBufferAndViews[promotedIdx].partialLocalView));
212       promotedIdx++;
213     } else {
214       opViews.push_back(view);
215     }
216   }
217 
218   // 2. Append all other operands as they appear, this enforces that such
219   // operands are not views. This is to support cases such as FillOp taking
220   // extra scalars etc.
221   auto operands = getAssumedNonViewOperands(op);
222   opViews.append(operands.begin(), operands.end());
223   LinalgOp res = op.clone(b, op.getLoc(), opViews);
224 
225   // 3. Emit write-back for the promoted output views: copy the partial view.
226   for (auto viewAndPartialLocalView : writebackViews) {
227     // WARNING: MUST use the old op to determine whether the operand view is an
228     // output.
229     bool isOutput =
230         op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
231     if (isOutput)
232       linalg_copy(viewAndPartialLocalView.second,
233                   viewAndPartialLocalView.first);
234   }
235 
236   // 4. Dealloc local buffers.
237   for (const auto &pi : promotedBufferAndViews)
238     std_dealloc(pi.buffer);
239 
240   return res;
241 }
242 
243 static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
244   SmallVector<LinalgOp, 8> toErase;
245   OperationFolder folder(f.getContext());
246   f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
247     if (!op.hasBufferSemantics())
248       return;
249 
250     // TODO(ntv) some heuristic here to decide what to promote. Atm only float
251     // and integer buffers can be promoted.
252     SetVector<Value> subViews;
253     OpBuilder b(op);
254     for (auto it : op.getInputsAndOutputBuffers())
255       if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
256         if (sv.getType().getElementType().isSignlessIntOrFloat())
257           subViews.insert(sv);
258     if (!subViews.empty()) {
259       promoteSubViewOperands(b, op, subViews, dynamicBuffers, 0, &folder);
260       toErase.push_back(op);
261     }
262   });
263   for (auto op : toErase)
264     op.erase();
265 }
266 
267 namespace {
268 struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
269   LinalgPromotionPass() = default;
270   LinalgPromotionPass(bool dynamicBuffers) {
271     this->dynamicBuffers = dynamicBuffers;
272   }
273 
274   void runOnFunction() override {
275     promoteSubViews(getFunction(), dynamicBuffers);
276   }
277 };
278 } // namespace
279 
280 std::unique_ptr<OperationPass<FuncOp>>
281 mlir::createLinalgPromotionPass(bool dynamicBuffers) {
282   return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
283 }
284 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgPromotionPass() {
285   return std::make_unique<LinalgPromotionPass>();
286 }
287