1 //===- Promotion.cpp - Implementation of linalg Promotion -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Promotion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" 15 #include "mlir/Dialect/Complex/IR/Complex.h" 16 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/MemRef/EDSC/Intrinsics.h" 23 #include "mlir/Dialect/SCF/SCF.h" 24 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineExprVisitor.h" 27 #include "mlir/IR/AffineMap.h" 28 #include "mlir/IR/ImplicitLocOpBuilder.h" 29 #include "mlir/Support/LLVM.h" 30 #include "mlir/Transforms/FoldUtils.h" 31 #include "llvm/ADT/MapVector.h" 32 #include "llvm/Support/CommandLine.h" 33 34 using namespace mlir; 35 using namespace mlir::edsc; 36 using namespace mlir::edsc::intrinsics; 37 using namespace mlir::linalg; 38 using namespace mlir::scf; 39 40 using llvm::MapVector; 41 42 #define DEBUG_TYPE "linalg-promotion" 43 44 /// Alloc a new buffer of `size` * `width` i8; where `width` is given by the 45 /// data `layout` for `elementType`. 46 /// Use AllocOp or AllocaOp depending on `options`. 47 /// Take an optional alignment. 48 static Value allocBuffer(ImplicitLocOpBuilder &b, 49 const LinalgPromotionOptions &options, 50 Type elementType, Value allocSize, DataLayout &layout, 51 Optional<unsigned> alignment = None) { 52 auto width = layout.getTypeSize(elementType); 53 54 IntegerAttr alignmentAttr; 55 if (alignment.hasValue()) 56 alignmentAttr = b.getI64IntegerAttr(alignment.getValue()); 57 58 // Static buffer. 59 if (auto cst = allocSize.getDefiningOp<ConstantIndexOp>()) { 60 auto staticBufferType = 61 MemRefType::get(width * cst.getValue(), b.getIntegerType(8)); 62 if (options.useAlloca) { 63 return b.createOrFold<memref::AllocaOp>(staticBufferType, ValueRange{}, 64 alignmentAttr); 65 } 66 return b.createOrFold<memref::AllocOp>(staticBufferType, ValueRange{}, 67 alignmentAttr); 68 } 69 70 // Fallback dynamic buffer. 71 auto dynamicBufferType = MemRefType::get(-1, b.getIntegerType(8)); 72 Value mul = 73 b.createOrFold<MulIOp>(b.create<ConstantIndexOp>(width), allocSize); 74 if (options.useAlloca) 75 return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr); 76 return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr); 77 } 78 79 /// Default allocation callback function. This allocates a promoted buffer when 80 /// no call back to do so is provided. The default is to allocate a 81 /// memref<..xi8> and return a view to get a memref type of shape 82 /// boundingSubViewSize. 83 static Optional<Value> defaultAllocBufferCallBack( 84 const LinalgPromotionOptions &options, OpBuilder &builder, 85 memref::SubViewOp subView, ArrayRef<Value> boundingSubViewSize, 86 bool dynamicBuffers, Optional<unsigned> alignment, DataLayout &layout) { 87 ShapedType viewType = subView.getType(); 88 ImplicitLocOpBuilder b(subView.getLoc(), builder); 89 auto zero = b.createOrFold<ConstantIndexOp>(0); 90 auto one = b.createOrFold<ConstantIndexOp>(1); 91 92 Value allocSize = one; 93 for (auto size : llvm::enumerate(boundingSubViewSize)) 94 allocSize = b.createOrFold<MulIOp>(allocSize, size.value()); 95 Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize, 96 layout, alignment); 97 SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(), 98 ShapedType::kDynamicSize); 99 Value view = b.createOrFold<memref::ViewOp>( 100 MemRefType::get(dynSizes, viewType.getElementType()), buffer, zero, 101 boundingSubViewSize); 102 return view; 103 } 104 105 /// Default implementation of deallocation of the buffer use for promotion. It 106 /// expects to get the same value that the default allocation method returned, 107 /// i.e. result of a ViewOp. 108 static LogicalResult 109 defaultDeallocBufferCallBack(const LinalgPromotionOptions &options, 110 OpBuilder &b, Value fullLocalView) { 111 auto viewOp = fullLocalView.getDefiningOp<memref::ViewOp>(); 112 assert(viewOp && "expected full local view to be a ViewOp"); 113 if (!options.useAlloca) 114 memref_dealloc(viewOp.source()); 115 return success(); 116 } 117 118 namespace { 119 120 /// Helper struct that captures the information required to apply the 121 /// transformation on each op. This bridges the abstraction gap with the 122 /// user-facing API which exposes positional arguments to control which operands 123 /// are promoted. 124 struct LinalgOpInstancePromotionOptions { 125 LinalgOpInstancePromotionOptions(LinalgOp op, 126 const LinalgPromotionOptions &options); 127 /// SubViews to promote. 128 MapVector<unsigned, Value> subViews; 129 /// True if the full view should be used for the promoted buffer. 130 DenseMap<Value, bool> useFullTileBuffers; 131 132 /// Callback functions for allocation and deallocation of promoted buffers, as 133 /// well as to copy the data into and out of these buffers. 134 AllocBufferCallbackFn allocationFn; 135 DeallocBufferCallbackFn deallocationFn; 136 CopyCallbackFn copyInFn; 137 CopyCallbackFn copyOutFn; 138 139 /// Allow the use of dynamically-sized buffers. 140 bool dynamicBuffers; 141 /// Alignment of promoted buffer. 142 Optional<unsigned> alignment; 143 }; 144 } // namespace 145 146 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( 147 LinalgOp linalgOp, const LinalgPromotionOptions &options) 148 : subViews(), dynamicBuffers(options.dynamicBuffers), 149 alignment(options.alignment) { 150 assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand"); 151 unsigned nBuffers = linalgOp.getNumShapedOperands(); 152 auto vUseFullTileBuffers = 153 options.useFullTileBuffers.getValueOr(llvm::SmallBitVector()); 154 vUseFullTileBuffers.resize(nBuffers, options.useFullTileBuffersDefault); 155 156 for (unsigned idx = 0; idx != nBuffers; ++idx) { 157 if (options.operandsToPromote && !options.operandsToPromote->count(idx)) 158 continue; 159 auto *op = linalgOp.getShapedOperand(idx).getDefiningOp(); 160 if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) { 161 subViews[idx] = sv; 162 useFullTileBuffers[sv] = vUseFullTileBuffers[idx]; 163 } 164 } 165 166 allocationFn = (options.allocationFn 167 ? *(options.allocationFn) 168 : [&](OpBuilder &builder, memref::SubViewOp subViewOp, 169 ArrayRef<Value> boundingSubViewSize, 170 DataLayout &layout) -> Optional<Value> { 171 return defaultAllocBufferCallBack(options, builder, subViewOp, 172 boundingSubViewSize, dynamicBuffers, 173 alignment, layout); 174 }); 175 deallocationFn = 176 (options.deallocationFn 177 ? *(options.deallocationFn) 178 : [&](OpBuilder &b, Value buffer) { 179 return defaultDeallocBufferCallBack(options, b, buffer); 180 }); 181 auto defaultCopyCallBack = [&](OpBuilder &builder, Value src, 182 Value dst) -> LogicalResult { 183 linalg_copy(src, dst); 184 return success(); 185 }; 186 copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack); 187 copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack); 188 } 189 190 // Performs promotion of a `subView` into a local buffer of the size of the 191 // *ranges* of the `subView`. This produces a buffer whose size may be bigger 192 // than the actual size of the `subView` at the boundaries. 193 // This is related to the full/partial tile problem. 194 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and 195 // `partialLocalView` such that: 196 // * `buffer` is always the size of the full tile. 197 // * `fullLocalView` is a dense contiguous view into that buffer. 198 // * `partialLocalView` is a dense non-contiguous slice of `fullLocalView` 199 // that corresponds to the size of `subView` and accounting for boundary 200 // effects. 201 // The point of the full tile buffer is that constant static tile sizes are 202 // folded and result in a buffer type with statically known size and alignment 203 // properties. 204 // To account for general boundary effects, padding must be performed on the 205 // boundary tiles. For now this is done with an unconditional `fill` op followed 206 // by a partial `copy` op. 207 Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer( 208 OpBuilder &b, Location loc, memref::SubViewOp subView, 209 AllocBufferCallbackFn allocationFn, DataLayout &layout) { 210 ScopedContext scopedContext(b, loc); 211 auto viewType = subView.getType(); 212 auto rank = viewType.getRank(); 213 SmallVector<Value, 4> fullSizes; 214 SmallVector<OpFoldResult> partialSizes; 215 fullSizes.reserve(rank); 216 partialSizes.reserve(rank); 217 for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) { 218 auto rangeValue = en.value(); 219 // Try to extract a tight constant. 220 LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n"); 221 IntegerAttr sizeAttr = getSmallestBoundingIndex(rangeValue.size); 222 Value size = 223 (!sizeAttr) ? rangeValue.size : b.create<ConstantOp>(loc, sizeAttr); 224 LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); 225 fullSizes.push_back(size); 226 partialSizes.push_back(memref_dim(subView, en.index()).value); 227 } 228 SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1); 229 // If a callback is not specified, then use the default implementation for 230 // allocating the promoted buffer. 231 Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, layout); 232 if (!fullLocalView) 233 return {}; 234 SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0)); 235 SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1)); 236 auto partialLocalView = b.createOrFold<memref::SubViewOp>( 237 loc, *fullLocalView, zeros, partialSizes, ones); 238 return PromotionInfo{*fullLocalView, partialLocalView}; 239 } 240 241 static Optional<MapVector<unsigned, PromotionInfo>> 242 promoteSubViews(OpBuilder &b, Location loc, 243 LinalgOpInstancePromotionOptions options, DataLayout &layout) { 244 if (options.subViews.empty()) 245 return {}; 246 247 ScopedContext scope(b, loc); 248 MapVector<unsigned, PromotionInfo> promotionInfoMap; 249 250 for (auto v : options.subViews) { 251 memref::SubViewOp subView = 252 cast<memref::SubViewOp>(v.second.getDefiningOp()); 253 Optional<PromotionInfo> promotionInfo = promoteSubviewAsNewBuffer( 254 b, loc, subView, options.allocationFn, layout); 255 if (!promotionInfo) 256 return {}; 257 promotionInfoMap[v.first] = *promotionInfo; 258 259 // Only fill the buffer if the full local view is used 260 if (!options.useFullTileBuffers[v.second]) 261 continue; 262 Value fillVal; 263 if (auto t = subView.getType().getElementType().dyn_cast<FloatType>()) { 264 fillVal = std_constant(FloatAttr::get(t, 0.0)); 265 } else if (auto t = 266 subView.getType().getElementType().dyn_cast<IntegerType>()) { 267 fillVal = std_constant_int(0, t); 268 } else if (auto t = 269 subView.getType().getElementType().dyn_cast<ComplexType>()) { 270 if (auto et = t.getElementType().dyn_cast<FloatType>()) 271 fillVal = std_constant(FloatAttr::get(et, 0.0)); 272 else if (auto et = t.getElementType().cast<IntegerType>()) 273 fillVal = std_constant_int(0, et); 274 fillVal = b.create<complex::CreateOp>(loc, t, fillVal, fillVal); 275 } else { 276 return {}; 277 } 278 linalg_fill(promotionInfo->fullLocalView, fillVal); 279 } 280 281 // Copy data into the promoted buffers. Use callback if provided. 282 for (auto v : options.subViews) { 283 auto info = promotionInfoMap.find(v.first); 284 if (info == promotionInfoMap.end()) 285 continue; 286 if (failed(options.copyInFn( 287 b, cast<memref::SubViewOp>(v.second.getDefiningOp()), 288 info->second.partialLocalView))) 289 return {}; 290 } 291 return promotionInfoMap; 292 } 293 294 static Optional<LinalgOp> 295 promoteSubViews(OpBuilder &b, LinalgOp op, 296 LinalgOpInstancePromotionOptions options, DataLayout &layout) { 297 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 298 299 if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) { 300 // TODO: add a level of indirection to linalg.generic. 301 if (convOp.padding()) 302 return {}; 303 } 304 305 // 1. Promote the specified views and use them in the new op. 306 auto loc = op.getLoc(); 307 auto promotedBuffersAndViews = promoteSubViews(b, loc, options, layout); 308 if (!promotedBuffersAndViews || 309 promotedBuffersAndViews->size() != options.subViews.size()) 310 return {}; 311 312 // 2. Append all other operands as they appear, this enforces that such 313 // operands are not views. This is to support cases such as FillOp taking 314 // extra scalars etc. Keep a reference to output buffers; 315 SmallVector<Value, 8> opViews; 316 opViews.reserve(op.getNumShapedOperands()); 317 SmallVector<std::pair<Value, Value>, 8> writebackViews; 318 writebackViews.reserve(promotedBuffersAndViews->size()); 319 for (auto view : llvm::enumerate(op.getShapedOperands())) { 320 if (options.subViews.count(view.index()) != 0) { 321 if (options.useFullTileBuffers[view.value()]) 322 opViews.push_back( 323 (*promotedBuffersAndViews)[view.index()].fullLocalView); 324 else 325 opViews.push_back( 326 (*promotedBuffersAndViews)[view.index()].partialLocalView); 327 if (view.index() >= op.getNumInputs()) 328 writebackViews.emplace_back(std::make_pair( 329 view.value(), 330 (*promotedBuffersAndViews)[view.index()].partialLocalView)); 331 } else { 332 opViews.push_back(view.value()); 333 } 334 } 335 op->setOperands(0, opViews.size(), opViews); 336 337 OpBuilder::InsertionGuard guard(b); 338 b.setInsertionPointAfter(op); 339 ScopedContext scope(b, loc); 340 // 3. Emit write-back for the promoted output views: copy the partial view. 341 for (auto viewAndPartialLocalView : writebackViews) { 342 if (failed(options.copyOutFn(b, viewAndPartialLocalView.second, 343 viewAndPartialLocalView.first))) 344 return {}; 345 } 346 347 // 4. Dealloc all local buffers. 348 for (const auto &pi : *promotedBuffersAndViews) 349 (void)options.deallocationFn(b, pi.second.fullLocalView); 350 return op; 351 } 352 353 LogicalResult 354 mlir::linalg::promoteSubviewsPrecondition(Operation *op, 355 LinalgPromotionOptions options) { 356 LinalgOp linOp = dyn_cast<LinalgOp>(op); 357 // Transformation applies to buffers only. 358 if (!linOp || !linOp.hasBufferSemantics()) 359 return failure(); 360 // Check that at least one of the requested operands is indeed a subview. 361 for (auto en : llvm::enumerate(linOp.getShapedOperands())) { 362 auto sv = isa_and_nonnull<memref::SubViewOp>(en.value().getDefiningOp()); 363 if (sv) { 364 if (!options.operandsToPromote.hasValue() || 365 options.operandsToPromote->count(en.index())) 366 return success(); 367 } 368 } 369 // TODO: Check all subviews requested are bound by a static constant. 370 // TODO: Check that the total footprint fits within a given size. 371 return failure(); 372 } 373 374 Optional<LinalgOp> 375 mlir::linalg::promoteSubViews(OpBuilder &b, LinalgOp linalgOp, 376 LinalgPromotionOptions options) { 377 LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options); 378 auto layout = DataLayout::closest(linalgOp); 379 return ::promoteSubViews(b, linalgOp, linalgOptions, layout); 380 } 381 382 namespace { 383 struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> { 384 LinalgPromotionPass() = default; 385 LinalgPromotionPass(bool dynamicBuffers, bool useAlloca) { 386 this->dynamicBuffers = dynamicBuffers; 387 this->useAlloca = useAlloca; 388 } 389 390 void runOnFunction() override { 391 getFunction().walk([this](LinalgOp op) { 392 auto options = LinalgPromotionOptions() 393 .setDynamicBuffers(dynamicBuffers) 394 .setUseAlloca(useAlloca); 395 if (failed(promoteSubviewsPrecondition(op, options))) 396 return; 397 LLVM_DEBUG(llvm::dbgs() << "Promote: " << *(op.getOperation()) << "\n"); 398 OpBuilder b(op); 399 promoteSubViews(b, op, options); 400 }); 401 } 402 }; 403 } // namespace 404 405 // TODO: support more transformation options in the pass. 406 std::unique_ptr<OperationPass<FuncOp>> 407 mlir::createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca) { 408 return std::make_unique<LinalgPromotionPass>(dynamicBuffers, useAlloca); 409 } 410 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgPromotionPass() { 411 return std::make_unique<LinalgPromotionPass>(); 412 } 413