1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Promotion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/LoopOps/LoopOps.h" 19 #include "mlir/EDSC/Helpers.h" 20 #include "mlir/IR/AffineExpr.h" 21 #include "mlir/IR/AffineExprVisitor.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Support/STLExtras.h" 27 #include "mlir/Transforms/FoldUtils.h" 28 29 #include "llvm/ADT/SetVector.h" 30 #include "llvm/Support/CommandLine.h" 31 32 using namespace mlir; 33 using namespace mlir::edsc; 34 using namespace mlir::edsc::intrinsics; 35 using namespace mlir::linalg; 36 using namespace mlir::linalg::intrinsics; 37 using namespace mlir::loop; 38 39 using llvm::SetVector; 40 41 #define DEBUG_TYPE "linalg-promotion" 42 43 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 44 static llvm::cl::opt<bool> clPromoteDynamic( 45 "test-linalg-promote-dynamic", 46 llvm::cl::desc("Test generation of dynamic promoted buffers"), 47 llvm::cl::cat(clOptionsCategory), llvm::cl::init(false)); 48 49 static ValuePtr allocBuffer(Type elementType, ValuePtr size, 50 bool dynamicBuffers) { 51 auto *ctx = size->getContext(); 52 auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); 53 if (!dynamicBuffers) 54 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp())) 55 return alloc( 56 MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); 57 ValuePtr mul = muli(constant_index(width), size); 58 return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); 59 } 60 61 // Performs promotion of a `subView` into a local buffer of the size of the 62 // *ranges* of the `subView`. This produces a buffer whose size may be bigger 63 // than the actual size of the `subView` at the boundaries. 64 // This is related to the full/partial tile problem. 65 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and 66 // `partialLocalView` such that: 67 // * `buffer` is always the size of the full tile. 68 // * `fullLocalView` is a dense contiguous view into that buffer. 69 // * `partialLocalView` is a dense non-contiguous slice of `fullLocalView` 70 // that corresponds to the size of `subView` and accounting for boundary 71 // effects. 72 // The point of the full tile buffer is that constant static tile sizes are 73 // folded and result in a buffer type with statically known size and alignment 74 // properties. 75 // To account for general boundary effects, padding must be performed on the 76 // boundary tiles. For now this is done with an unconditional `fill` op followed 77 // by a partial `copy` op. 78 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, 79 SubViewOp subView, 80 bool dynamicBuffers, 81 OperationFolder *folder) { 82 auto zero = constant_index(folder, 0); 83 auto one = constant_index(folder, 1); 84 85 auto viewType = subView.getType(); 86 auto rank = viewType.getRank(); 87 ValuePtr allocSize = one; 88 SmallVector<ValuePtr, 8> fullRanges, partialRanges; 89 fullRanges.reserve(rank); 90 partialRanges.reserve(rank); 91 for (auto en : llvm::enumerate(subView.getRanges())) { 92 auto rank = en.index(); 93 auto rangeValue = en.value(); 94 ValuePtr d = rangeValue.size; 95 allocSize = muli(folder, allocSize, d).getValue(); 96 fullRanges.push_back(d); 97 partialRanges.push_back(range(folder, zero, dim(subView, rank), one)); 98 } 99 SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1); 100 auto buffer = 101 allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers); 102 auto fullLocalView = view( 103 MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges); 104 auto partialLocalView = slice(fullLocalView, partialRanges); 105 return PromotionInfo{buffer, fullLocalView, partialLocalView}; 106 } 107 108 SmallVector<PromotionInfo, 8> 109 mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, 110 ArrayRef<ValuePtr> subViews, bool dynamicBuffers, 111 OperationFolder *folder) { 112 if (subViews.empty()) 113 return {}; 114 115 ScopedContext scope(b, loc); 116 SmallVector<PromotionInfo, 8> res; 117 res.reserve(subViews.size()); 118 DenseMap<ValuePtr, PromotionInfo> promotionInfoMap; 119 for (auto v : subViews) { 120 SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); 121 auto viewType = subView.getType(); 122 // TODO(ntv): support more cases than just float. 123 if (!viewType.getElementType().isa<FloatType>()) 124 continue; 125 auto promotionInfo = 126 promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder); 127 promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo)); 128 res.push_back(promotionInfo); 129 } 130 131 for (auto v : subViews) { 132 SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); 133 auto info = promotionInfoMap.find(v); 134 if (info == promotionInfoMap.end()) 135 continue; 136 // TODO(ntv): value to fill with should be related to the operation. 137 // For now, just use APFloat(0.0f). 138 auto t = subView.getType().getElementType().cast<FloatType>(); 139 ValuePtr fillVal = constant_float(folder, APFloat(0.0f), t); 140 // TODO(ntv): fill is only necessary if `promotionInfo` has a full local 141 // view that is different from the partial local view and we are on the 142 // boundary. 143 fill(info->second.fullLocalView, fillVal); 144 } 145 146 for (auto v : subViews) { 147 auto info = promotionInfoMap.find(v); 148 if (info == promotionInfoMap.end()) 149 continue; 150 copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView); 151 } 152 return res; 153 } 154 155 LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, 156 SetVector<ValuePtr> subViews, 157 bool dynamicBuffers, 158 OperationFolder *folder) { 159 // 1. Promote the specified views and use them in the new op. 160 ScopedContext scope(b, op.getLoc()); 161 auto promotedBufferAndViews = promoteSubViews( 162 b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); 163 SmallVector<ValuePtr, 8> opViews; 164 opViews.reserve(op.getNumInputsAndOutputs()); 165 SmallVector<std::pair<ValuePtr, ValuePtr>, 8> writebackViews; 166 writebackViews.reserve(subViews.size()); 167 unsigned promotedIdx = 0; 168 for (auto view : op.getInputsAndOutputs()) { 169 if (subViews.count(view) != 0) { 170 opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); 171 writebackViews.emplace_back(std::make_pair( 172 view, promotedBufferAndViews[promotedIdx].partialLocalView)); 173 promotedIdx++; 174 } else { 175 opViews.push_back(view); 176 } 177 } 178 179 // 2. Append all other operands as they appear, this enforces that such 180 // operands are not views. This is to support cases such as FillOp taking 181 // extra scalars etc. 182 auto operands = getAssumedNonViewOperands(op); 183 opViews.append(operands.begin(), operands.end()); 184 LinalgOp res = op.clone(b, op.getLoc(), opViews); 185 186 // 3. Emit write-back for the promoted output views: copy the partial view. 187 for (auto viewAndPartialLocalView : writebackViews) { 188 // WARNING: MUST use the old op to determine whether the operand view is an 189 // output. 190 bool isOutput = 191 op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); 192 if (isOutput) 193 copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first); 194 } 195 196 // 4. Dealloc local buffers. 197 for (const auto &pi : promotedBufferAndViews) 198 dealloc(pi.buffer); 199 200 return res; 201 } 202 203 static void promoteSubViews(FuncOp f, bool dynamicBuffers) { 204 SmallVector<LinalgOp, 8> toErase; 205 OperationFolder folder(f.getContext()); 206 f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { 207 // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or 208 // nothing. 209 SetVector<ValuePtr> subViews; 210 OpBuilder b(op); 211 for (auto it : op.getInputsAndOutputs()) 212 if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) 213 subViews.insert(sv); 214 if (!subViews.empty()) { 215 promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder); 216 toErase.push_back(op); 217 } 218 }); 219 for (auto op : toErase) 220 op.erase(); 221 } 222 223 namespace { 224 struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> { 225 LinalgPromotionPass() = default; 226 LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {} 227 228 void runOnFunction() override { 229 promoteSubViews(getFunction(), dynamicBuffers); 230 } 231 232 bool dynamicBuffers; 233 }; 234 } // namespace 235 236 std::unique_ptr<OpPassBase<FuncOp>> 237 mlir::linalg::createLinalgPromotionPass(bool dynamicBuffers) { 238 return std::make_unique<LinalgPromotionPass>(dynamicBuffers); 239 } 240 241 static PassRegistration<LinalgPromotionPass> 242 pass("linalg-promote-subviews", "promote subview ops to local buffers", [] { 243 return std::make_unique<LinalgPromotionPass>(clPromoteDynamic); 244 }); 245