xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision 307cfdf5338641e3a895857ef02dc9da35cd0eb6)
1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Promotion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/LoopOps/LoopOps.h"
22 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineExprVisitor.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/FoldUtils.h"
28 
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/Support/CommandLine.h"
31 
32 using namespace mlir;
33 using namespace mlir::edsc;
34 using namespace mlir::edsc::intrinsics;
35 using namespace mlir::linalg;
36 using namespace mlir::loop;
37 
38 using llvm::SetVector;
39 
40 using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
41 using folded_linalg_range = FoldedValueBuilder<linalg::RangeOp>;
42 using folded_std_dim = FoldedValueBuilder<DimOp>;
43 using folded_std_subview = FoldedValueBuilder<SubViewOp>;
44 using folded_std_view = FoldedValueBuilder<ViewOp>;
45 
46 #define DEBUG_TYPE "linalg-promotion"
47 
48 /// If `size` comes from an AffineMinOp and one of the dimensions of AffineMin
49 /// is a constant then return a new value set to the smallest such constant.
50 /// Otherwise return size.
51 static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
52                                                  Value size) {
53   auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
54   if (!affineMinOp)
55     return size;
56   if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) {
57         return e.dyn_cast<AffineConstantExpr>();
58       }))
59     return size;
60   int64_t minConst = std::numeric_limits<int64_t>::max();
61   for (auto e : affineMinOp.getAffineMap().getResults())
62     if (auto cst = e.dyn_cast<AffineConstantExpr>())
63       minConst = std::min(minConst, cst.getValue());
64   assert(minConst != std::numeric_limits<int64_t>::max());
65   return b.create<ConstantIndexOp>(loc, minConst);
66 }
67 
68 static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
69                          OperationFolder *folder, int64_t alignment = 0) {
70   auto *ctx = size.getContext();
71   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
72   IntegerAttr alignment_attr;
73   if (alignment)
74     alignment_attr = IntegerAttr::get(IntegerType::get(64, ctx), alignment);
75   if (!dynamicBuffers)
76     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
77       return std_alloc(
78           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
79           ValueRange{}, alignment_attr);
80   Value mul =
81       folded_std_muli(folder, folded_std_constant_index(folder, width), size);
82   return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
83                    alignment_attr);
84 }
85 
86 // Performs promotion of a `subView` into a local buffer of the size of the
87 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
88 // than the actual size of the `subView` at the boundaries.
89 // This is related to the full/partial tile problem.
90 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
91 // `partialLocalView` such that:
92 //   * `buffer` is always the size of the full tile.
93 //   * `fullLocalView` is a dense contiguous view into that buffer.
94 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
95 //     that corresponds to the size of `subView` and accounting for boundary
96 //     effects.
97 // The point of the full tile buffer is that constant static tile sizes are
98 // folded and result in a buffer type with statically known size and alignment
99 // properties.
100 // To account for general boundary effects, padding must be performed on the
101 // boundary tiles. For now this is done with an unconditional `fill` op followed
102 // by a partial `copy` op.
103 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
104                                            SubViewOp subView,
105                                            bool dynamicBuffers,
106                                            int64_t alignment,
107                                            OperationFolder *folder) {
108   auto zero = folded_std_constant_index(folder, 0);
109   auto one = folded_std_constant_index(folder, 1);
110 
111   auto viewType = subView.getType();
112   auto rank = viewType.getRank();
113   Value allocSize = one;
114   SmallVector<Value, 8> fullSizes, partialSizes;
115   fullSizes.reserve(rank);
116   partialSizes.reserve(rank);
117   for (auto en : llvm::enumerate(subView.getRanges())) {
118     auto rank = en.index();
119     auto rangeValue = en.value();
120     // Try to extract a tight constant
121     Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
122     allocSize = folded_std_muli(folder, allocSize, size);
123     fullSizes.push_back(size);
124     partialSizes.push_back(folded_std_dim(folder, subView, rank));
125   }
126   SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
127   auto buffer = allocBuffer(viewType.getElementType(), allocSize,
128                             dynamicBuffers, folder, alignment);
129   auto fullLocalView = folded_std_view(
130       folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
131       fullSizes);
132   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
133   SmallVector<Value, 4> ones(fullSizes.size(), one);
134   auto partialLocalView =
135       folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
136   return PromotionInfo{buffer, fullLocalView, partialLocalView};
137 }
138 
139 SmallVector<PromotionInfo, 8>
140 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
141                               ArrayRef<Value> subViews, bool dynamicBuffers,
142                               int64_t alignment, OperationFolder *folder) {
143   if (subViews.empty())
144     return {};
145 
146   ScopedContext scope(b, loc);
147   SmallVector<PromotionInfo, 8> res;
148   res.reserve(subViews.size());
149   DenseMap<Value, PromotionInfo> promotionInfoMap;
150   for (auto v : subViews) {
151     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
152     auto promotionInfo = promoteFullTileBuffer(b, loc, subView, dynamicBuffers,
153                                                alignment, folder);
154     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
155     res.push_back(promotionInfo);
156   }
157 
158   for (auto v : subViews) {
159     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
160     auto info = promotionInfoMap.find(v);
161     if (info == promotionInfoMap.end())
162       continue;
163     Value fillVal;
164     if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
165       fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
166     else if (auto t =
167                  subView.getType().getElementType().dyn_cast<IntegerType>())
168       fillVal = folded_std_constant_int(folder, 0, t);
169     // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
170     // view that is different from the partial local view and we are on the
171     // boundary.
172     linalg_fill(info->second.fullLocalView, fillVal);
173   }
174 
175   for (auto v : subViews) {
176     auto info = promotionInfoMap.find(v);
177     if (info == promotionInfoMap.end())
178       continue;
179     linalg_copy(cast<SubViewOp>(v.getDefiningOp()),
180                 info->second.partialLocalView);
181   }
182   return res;
183 }
184 
185 LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
186                                               SetVector<Value> subViews,
187                                               bool dynamicBuffers,
188                                               int64_t alignment,
189                                               OperationFolder *folder) {
190   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
191 
192   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
193     // TODO(ntv): add a level of indirection to linalg.generic.
194     if (convOp.padding())
195       llvm_unreachable("Unexpected conv with padding");
196   }
197 
198   // 1. Promote the specified views and use them in the new op.
199   ScopedContext scope(b, op.getLoc());
200   auto promotedBufferAndViews =
201       promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers,
202                       alignment, folder);
203   SmallVector<Value, 8> opViews;
204   opViews.reserve(op.getNumInputsAndOutputs());
205   SmallVector<std::pair<Value, Value>, 8> writebackViews;
206   writebackViews.reserve(subViews.size());
207   unsigned promotedIdx = 0;
208   for (auto view : op.getInputsAndOutputBuffers()) {
209     if (subViews.count(view) != 0) {
210       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
211       writebackViews.emplace_back(std::make_pair(
212           view, promotedBufferAndViews[promotedIdx].partialLocalView));
213       promotedIdx++;
214     } else {
215       opViews.push_back(view);
216     }
217   }
218 
219   // 2. Append all other operands as they appear, this enforces that such
220   // operands are not views. This is to support cases such as FillOp taking
221   // extra scalars etc.
222   auto operands = getAssumedNonViewOperands(op);
223   opViews.append(operands.begin(), operands.end());
224   LinalgOp res = op.clone(b, op.getLoc(), opViews);
225 
226   // 3. Emit write-back for the promoted output views: copy the partial view.
227   for (auto viewAndPartialLocalView : writebackViews) {
228     // WARNING: MUST use the old op to determine whether the operand view is an
229     // output.
230     bool isOutput =
231         op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
232     if (isOutput)
233       linalg_copy(viewAndPartialLocalView.second,
234                   viewAndPartialLocalView.first);
235   }
236 
237   // 4. Dealloc local buffers.
238   for (const auto &pi : promotedBufferAndViews)
239     std_dealloc(pi.buffer);
240 
241   return res;
242 }
243 
244 static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
245   SmallVector<LinalgOp, 8> toErase;
246   OperationFolder folder(f.getContext());
247   f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
248     if (!op.hasBufferSemantics())
249       return;
250 
251     // TODO(ntv) some heuristic here to decide what to promote. Atm only float
252     // and integer buffers can be promoted.
253     SetVector<Value> subViews;
254     OpBuilder b(op);
255     for (auto it : op.getInputsAndOutputBuffers())
256       if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
257         if (sv.getType().getElementType().isSignlessIntOrFloat())
258           subViews.insert(sv);
259     if (!subViews.empty()) {
260       promoteSubViewOperands(b, op, subViews, dynamicBuffers, 0, &folder);
261       toErase.push_back(op);
262     }
263   });
264   for (auto op : toErase)
265     op.erase();
266 }
267 
268 LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(
269     Operation *op, llvm::Optional<DenseSet<unsigned>> operandIndicesToPromote) {
270   LinalgOp linOp = dyn_cast<LinalgOp>(op);
271   // Transformation applies to buffers only.
272   if (!linOp || !linOp.hasBufferSemantics())
273     return failure();
274   for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) {
275     auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
276     if (sv && (!operandIndicesToPromote.hasValue() ||
277                operandIndicesToPromote->count(en.index())))
278       return success();
279   }
280   return failure();
281 }
282 
283 namespace {
284 struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
285   LinalgPromotionPass() = default;
286   LinalgPromotionPass(bool dynamicBuffers) {
287     this->dynamicBuffers = dynamicBuffers;
288   }
289 
290   void runOnFunction() override {
291     promoteSubViews(getFunction(), dynamicBuffers);
292   }
293 };
294 } // namespace
295 
296 std::unique_ptr<OperationPass<FuncOp>>
297 mlir::createLinalgPromotionPass(bool dynamicBuffers) {
298   return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
299 }
300 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgPromotionPass() {
301   return std::make_unique<LinalgPromotionPass>();
302 }
303