xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision 98226e62ecf4d64323e4531daff39183691800cf)
1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements the linalg dialect Promotion pass.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
23 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
24 #include "mlir/Dialect/Linalg/Passes.h"
25 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
26 #include "mlir/Dialect/Linalg/Utils/Utils.h"
27 #include "mlir/Dialect/LoopOps/LoopOps.h"
28 #include "mlir/EDSC/Helpers.h"
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/AffineExprVisitor.h"
31 #include "mlir/IR/AffineMap.h"
32 #include "mlir/IR/OpImplementation.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Support/STLExtras.h"
36 #include "mlir/Transforms/FoldUtils.h"
37 
38 #include "llvm/ADT/SetVector.h"
39 #include "llvm/Support/CommandLine.h"
40 
41 using namespace mlir;
42 using namespace mlir::edsc;
43 using namespace mlir::edsc::intrinsics;
44 using namespace mlir::linalg;
45 using namespace mlir::linalg::intrinsics;
46 using namespace mlir::loop;
47 
48 using llvm::SetVector;
49 
50 #define DEBUG_TYPE "linalg-promotion"
51 
52 static AffineMap getAffineDifferenceMap(MLIRContext *context) {
53   AffineExpr d0(getAffineDimExpr(0, context)), d1(getAffineDimExpr(1, context));
54   return AffineMap::get(2, 0, {d0 - d1});
55 }
56 
57 // TODO(ntv): replace this with 1-D memref alloc once there is an std.view op.
58 static Value *allocBuffer(Type elementType, Value *size) {
59   if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
60     return buffer_alloc(
61         BufferType::get(size->getContext(), elementType, cst.getValue()));
62   return buffer_alloc(BufferType::get(size->getContext(), elementType), size);
63 }
64 
65 // Performs promotion of a `subView` into a local buffer of the size of the
66 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
67 // than the actual size of the `subView` at the boundaries.
68 // This is related to the full/partial tile problem.
69 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
70 // `partialLocalView` such that:
71 //   * `buffer` is always the size of the full tile.
72 //   * `fullLocalView` is a dense contiguous view into that buffer.
73 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
74 //     that corresponds to the size of `subView` and accounting for boundary
75 //     effects.
76 // The point of the full tile buffer is that constant static tile sizes are
77 // folded and result in a buffer type with statically known size and alignment
78 // properties.
79 // To account for general boundary effects, padding must be performed on the
80 // boundary tiles. For now this is done with an unconditional `fill` op followed
81 // by a partial `copy` op.
82 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
83                                            SubViewOp subView,
84                                            OperationFolder *folder) {
85   auto zero = constant_index(folder, 0);
86   auto one = constant_index(folder, 1);
87 
88   auto viewType = subView.getViewType();
89   auto rank = viewType.getRank();
90   Value *allocSize = one;
91   SmallVector<Value *, 8> fullRanges, partialRanges;
92   fullRanges.reserve(rank);
93   partialRanges.reserve(rank);
94   for (auto en : llvm::enumerate(subView.getRanges())) {
95     auto rank = en.index();
96     auto rangeValue = en.value();
97     Value *d =
98         isa<DimOp>(rangeValue.max->getDefiningOp())
99             ? rangeValue.max
100             : applyMapToValues(b, loc, getAffineDifferenceMap(b.getContext()),
101                                {rangeValue.max, rangeValue.min}, folder)
102                   .front();
103     allocSize = muli(folder, allocSize, d).getValue();
104     fullRanges.push_back(range(folder, zero, d, one));
105     partialRanges.push_back(range(folder, zero, dim(subView, rank), one));
106   }
107   auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
108   auto fullLocalView = view(buffer, fullRanges);
109   auto partialLocalView = slice(fullLocalView, partialRanges);
110   return PromotionInfo{buffer, fullLocalView, partialLocalView};
111 }
112 
113 SmallVector<PromotionInfo, 8>
114 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
115                               ArrayRef<Value *> subViews,
116                               OperationFolder *folder) {
117   if (subViews.empty())
118     return {};
119 
120   ScopedContext scope(b, loc);
121   SmallVector<PromotionInfo, 8> res;
122   res.reserve(subViews.size());
123   DenseMap<Value *, PromotionInfo> promotionInfoMap;
124   for (auto *v : subViews) {
125     SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
126     auto viewType = subView.getViewType();
127     // TODO(ntv): support more cases than just float.
128     if (!viewType.getElementType().isa<FloatType>())
129       continue;
130     auto promotionInfo = promoteFullTileBuffer(b, loc, subView, folder);
131     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
132     res.push_back(promotionInfo);
133   }
134 
135   for (auto *v : subViews) {
136     SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
137     auto info = promotionInfoMap.find(v);
138     if (info == promotionInfoMap.end())
139       continue;
140     // TODO(ntv): value to fill with should be related to the operation.
141     // For now, just use APFloat(0.0f).
142     auto t = subView.getViewType().getElementType().cast<FloatType>();
143     Value *fillVal = constant_float(folder, APFloat(0.0f), t);
144     // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
145     // view that is different from the partial local view and we are on the
146     // boundary.
147     fill(info->second.fullLocalView, fillVal);
148   }
149 
150   for (auto *v : subViews) {
151     auto info = promotionInfoMap.find(v);
152     if (info == promotionInfoMap.end())
153       continue;
154     copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView);
155   }
156   return res;
157 }
158 
159 static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
160                                    OperationFolder *folder) {
161   // 1. Promote the specified views and use them in the new op.
162   OpBuilder b(op);
163   ScopedContext scope(b, op.getLoc());
164   auto promotedBufferAndViews =
165       promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), folder);
166   SmallVector<Value *, 8> opViews;
167   opViews.reserve(op.getNumInputsAndOutputs());
168   SmallVector<std::pair<Value *, Value *>, 8> writebackViews;
169   writebackViews.reserve(subViews.size());
170   unsigned promotedIdx = 0;
171   for (auto *view : op.getInputsAndOutputs()) {
172     if (subViews.count(view) != 0) {
173       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
174       writebackViews.emplace_back(std::make_pair(
175           view, promotedBufferAndViews[promotedIdx].partialLocalView));
176       promotedIdx++;
177     } else {
178       opViews.push_back(view);
179     }
180   }
181 
182   // 2. Append all other operands as they appear, this enforces that such
183   // operands are not views. This is to support cases such as FillOp taking
184   // extra scalars etc.
185   auto operands = getAssumedNonViewOperands(op);
186   opViews.append(operands.begin(), operands.end());
187   op.clone(b, op.getLoc(), opViews);
188 
189   // 3. Emit write-back for the promoted output views: copy the partial view.
190   for (auto viewAndPartialLocalView : writebackViews) {
191     // Note: use the old op to determine whether the operand view is an output.
192     bool isOutput =
193         op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
194     if (isOutput)
195       copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first);
196   }
197 
198   // 4. Dealloc local buffers.
199   for (const auto &pi : promotedBufferAndViews)
200     buffer_dealloc(pi.buffer);
201 }
202 
203 static void promoteSubViews(FuncOp f) {
204   SmallVector<LinalgOp, 8> toErase;
205   OperationFolder folder(f.getContext());
206   f.walk([&folder, &toErase](LinalgOp op) {
207     // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
208     // nothing.
209     SetVector<Value *> subViews;
210     for (auto it : op.getInputsAndOutputs())
211       if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
212         subViews.insert(sv);
213     if (!subViews.empty()) {
214       promoteSubViewOperands(op, subViews, &folder);
215       toErase.push_back(op);
216     }
217   });
218   for (auto op : toErase)
219     op.erase();
220 }
221 
222 namespace {
223 struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> {
224   void runOnFunction() override { promoteSubViews(getFunction()); }
225 };
226 } // namespace
227 
228 std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgPromotionPass() {
229   return std::make_unique<LinalgPromotionPass>();
230 }
231 
232 static PassRegistration<LinalgPromotionPass>
233     pass("linalg-promote-subviews", "promote subview ops to local buffers");
234