15b03e692SNicolas Vasilache //===- Promotion.cpp - Implementation of linalg Promotion -----------------===// 25b03e692SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65b03e692SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85b03e692SNicolas Vasilache // 95b03e692SNicolas Vasilache // This file implements the linalg dialect Promotion pass. 105b03e692SNicolas Vasilache // 115b03e692SNicolas Vasilache //===----------------------------------------------------------------------===// 125b03e692SNicolas Vasilache 13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h" 1542e5f422STres Popp #include "mlir/Dialect/Complex/IR/Complex.h" 16115711c1SAmir Mohammad Tavakkoli #include "mlir/Dialect/Func/IR/FuncOps.h" 17115711c1SAmir Mohammad Tavakkoli #include "mlir/Dialect/GPU/IR/GPUDialect.h" 18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 195b03e692SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h" 20307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 218b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 225b03e692SNicolas Vasilache #include "mlir/IR/AffineExpr.h" 235b03e692SNicolas Vasilache #include "mlir/IR/AffineExprVisitor.h" 245b03e692SNicolas Vasilache #include "mlir/IR/AffineMap.h" 25ef33c6e3SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h" 26eabb6ccdSMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 275b03e692SNicolas Vasilache #include "mlir/Support/LLVM.h" 285b03e692SNicolas Vasilache #include "mlir/Transforms/FoldUtils.h" 290ed2d4c7SMaheshRavishankar #include "llvm/ADT/MapVector.h" 3099069ab2SChristopher Bate #include "llvm/ADT/SmallBitVector.h" 3170604222SAviad Cohen #include "llvm/ADT/SmallSet.h" 324519ca3dSNicolas Vasilache #include "llvm/ADT/TypeSwitch.h" 335b03e692SNicolas Vasilache #include "llvm/Support/CommandLine.h" 34f89bb3c0SAlexander Belyaev #include "llvm/Support/Debug.h" 355b03e692SNicolas Vasilache 365b03e692SNicolas Vasilache using namespace mlir; 375b03e692SNicolas Vasilache using namespace mlir::linalg; 38c25b20c0SAlex Zinenko using namespace mlir::scf; 395b03e692SNicolas Vasilache 400ed2d4c7SMaheshRavishankar using llvm::MapVector; 415b03e692SNicolas Vasilache 425b03e692SNicolas Vasilache #define DEBUG_TYPE "linalg-promotion" 435b03e692SNicolas Vasilache 44ef33c6e3SNicolas Vasilache /// Alloc a new buffer of `size` * `width` i8; where `width` is given by the 45ef33c6e3SNicolas Vasilache /// data `layout` for `elementType`. 46ef33c6e3SNicolas Vasilache /// Use AllocOp or AllocaOp depending on `options`. 47ef33c6e3SNicolas Vasilache /// Take an optional alignment. 48ef33c6e3SNicolas Vasilache static Value allocBuffer(ImplicitLocOpBuilder &b, 49ef33c6e3SNicolas Vasilache const LinalgPromotionOptions &options, 50ef33c6e3SNicolas Vasilache Type elementType, Value allocSize, DataLayout &layout, 5122426110SRamkumar Ramachandra std::optional<unsigned> alignment = std::nullopt) { 528134a8fcSOleksandr "Alex" Zinenko llvm::TypeSize width = layout.getTypeSize(elementType); 538134a8fcSOleksandr "Alex" Zinenko assert(!width.isScalable() && "cannot allocate buffer for a scalable vector"); 54ef33c6e3SNicolas Vasilache 55ef33c6e3SNicolas Vasilache IntegerAttr alignmentAttr; 56491d2701SKazu Hirata if (alignment.has_value()) 57c27d8152SKazu Hirata alignmentAttr = b.getI64IntegerAttr(alignment.value()); 58ef33c6e3SNicolas Vasilache 59d6a2014eSAviad Cohen Attribute memorySpaceAttr; 60d6a2014eSAviad Cohen if (options.memorySpace.has_value()) 61d6a2014eSAviad Cohen memorySpaceAttr = *options.memorySpace; 62d6a2014eSAviad Cohen 63ef33c6e3SNicolas Vasilache // Static buffer. 64cb7bda2aSMatthias Springer if (std::optional<int64_t> cst = getConstantIntValue(allocSize)) { 658134a8fcSOleksandr "Alex" Zinenko auto staticBufferType = MemRefType::get(width.getFixedValue() * cst.value(), 668134a8fcSOleksandr "Alex" Zinenko b.getIntegerType(8)); 67d6a2014eSAviad Cohen staticBufferType = 68d6a2014eSAviad Cohen MemRefType::Builder(staticBufferType).setMemorySpace(memorySpaceAttr); 69ef33c6e3SNicolas Vasilache if (options.useAlloca) { 70b23c8225SMatthias Springer return b.create<memref::AllocaOp>(staticBufferType, ValueRange{}, 71ef33c6e3SNicolas Vasilache alignmentAttr); 72ef33c6e3SNicolas Vasilache } 73b23c8225SMatthias Springer return b.create<memref::AllocOp>(staticBufferType, ValueRange{}, 74ef33c6e3SNicolas Vasilache alignmentAttr); 75ef33c6e3SNicolas Vasilache } 76ef33c6e3SNicolas Vasilache 77ef33c6e3SNicolas Vasilache // Fallback dynamic buffer. 78fb4cedccSAliia Khasanova auto dynamicBufferType = 79399638f9SAliia Khasanova MemRefType::get(ShapedType::kDynamic, b.getIntegerType(8)); 80d6a2014eSAviad Cohen dynamicBufferType = 81d6a2014eSAviad Cohen MemRefType::Builder(dynamicBufferType).setMemorySpace(memorySpaceAttr); 82a54f4eaeSMogball Value mul = b.createOrFold<arith::MulIOp>( 83a54f4eaeSMogball b.create<arith::ConstantIndexOp>(width), allocSize); 84ef33c6e3SNicolas Vasilache if (options.useAlloca) 85ef33c6e3SNicolas Vasilache return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr); 86ef33c6e3SNicolas Vasilache return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr); 875b03e692SNicolas Vasilache } 885b03e692SNicolas Vasilache 890ed2d4c7SMaheshRavishankar /// Default allocation callback function. This allocates a promoted buffer when 900ed2d4c7SMaheshRavishankar /// no call back to do so is provided. The default is to allocate a 910ed2d4c7SMaheshRavishankar /// memref<..xi8> and return a view to get a memref type of shape 920ed2d4c7SMaheshRavishankar /// boundingSubViewSize. 9322426110SRamkumar Ramachandra static std::optional<Value> defaultAllocBufferCallBack( 9422426110SRamkumar Ramachandra const LinalgPromotionOptions &options, OpBuilder &builder, 9522426110SRamkumar Ramachandra memref::SubViewOp subView, ArrayRef<Value> boundingSubViewSize, 9622426110SRamkumar Ramachandra std::optional<unsigned> alignment, DataLayout &layout) { 970ed2d4c7SMaheshRavishankar ShapedType viewType = subView.getType(); 98ef33c6e3SNicolas Vasilache ImplicitLocOpBuilder b(subView.getLoc(), builder); 99b23c8225SMatthias Springer auto zero = b.create<arith::ConstantIndexOp>(0); 100b23c8225SMatthias Springer auto one = b.create<arith::ConstantIndexOp>(1); 1010ed2d4c7SMaheshRavishankar 102d6a2014eSAviad Cohen Attribute memorySpaceAttr; 103d6a2014eSAviad Cohen if (options.memorySpace.has_value()) 104d6a2014eSAviad Cohen memorySpaceAttr = *options.memorySpace; 105d6a2014eSAviad Cohen 1060ed2d4c7SMaheshRavishankar Value allocSize = one; 107e4853be2SMehdi Amini for (const auto &size : llvm::enumerate(boundingSubViewSize)) 108a54f4eaeSMogball allocSize = b.createOrFold<arith::MulIOp>(allocSize, size.value()); 109ef33c6e3SNicolas Vasilache Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize, 110ef33c6e3SNicolas Vasilache layout, alignment); 1110ed2d4c7SMaheshRavishankar SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(), 112399638f9SAliia Khasanova ShapedType::kDynamic); 113d6a2014eSAviad Cohen 114d6a2014eSAviad Cohen auto viewMemRefType = MemRefType::get(dynSizes, viewType.getElementType()); 115d6a2014eSAviad Cohen viewMemRefType = 116d6a2014eSAviad Cohen MemRefType::Builder(viewMemRefType).setMemorySpace(memorySpaceAttr); 117d6a2014eSAviad Cohen Value view = b.createOrFold<memref::ViewOp>(viewMemRefType, buffer, zero, 118ef33c6e3SNicolas Vasilache boundingSubViewSize); 1190ed2d4c7SMaheshRavishankar return view; 1200ed2d4c7SMaheshRavishankar } 1210ed2d4c7SMaheshRavishankar 1220ed2d4c7SMaheshRavishankar /// Default implementation of deallocation of the buffer use for promotion. It 1230ed2d4c7SMaheshRavishankar /// expects to get the same value that the default allocation method returned, 1240ed2d4c7SMaheshRavishankar /// i.e. result of a ViewOp. 1257d9518c8SNicolas Vasilache static LogicalResult 1267d9518c8SNicolas Vasilache defaultDeallocBufferCallBack(const LinalgPromotionOptions &options, 1277d9518c8SNicolas Vasilache OpBuilder &b, Value fullLocalView) { 1284519ca3dSNicolas Vasilache if (!options.useAlloca) { 1294519ca3dSNicolas Vasilache auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp()); 130136d746eSJacques Pienaar b.create<memref::DeallocOp>(viewOp.getSource().getLoc(), 131136d746eSJacques Pienaar viewOp.getSource()); 1324519ca3dSNicolas Vasilache } 1330ed2d4c7SMaheshRavishankar return success(); 1340ed2d4c7SMaheshRavishankar } 1350ed2d4c7SMaheshRavishankar 1360ed2d4c7SMaheshRavishankar namespace { 1370ed2d4c7SMaheshRavishankar 1380ed2d4c7SMaheshRavishankar /// Helper struct that captures the information required to apply the 1390ed2d4c7SMaheshRavishankar /// transformation on each op. This bridges the abstraction gap with the 1400ed2d4c7SMaheshRavishankar /// user-facing API which exposes positional arguments to control which operands 1410ed2d4c7SMaheshRavishankar /// are promoted. 1420ed2d4c7SMaheshRavishankar struct LinalgOpInstancePromotionOptions { 1430ed2d4c7SMaheshRavishankar LinalgOpInstancePromotionOptions(LinalgOp op, 1440ed2d4c7SMaheshRavishankar const LinalgPromotionOptions &options); 1450ed2d4c7SMaheshRavishankar /// SubViews to promote. 1464519ca3dSNicolas Vasilache MapVector<int64_t, Value> subViews; 14770604222SAviad Cohen /// Subviews operand numbers to copy in using copyInFn. 14870604222SAviad Cohen llvm::SmallSet<int64_t, 4> operandsNumbersToCopyIn; 1490ed2d4c7SMaheshRavishankar /// True if the full view should be used for the promoted buffer. 1500ed2d4c7SMaheshRavishankar DenseMap<Value, bool> useFullTileBuffers; 1510ed2d4c7SMaheshRavishankar 1520ed2d4c7SMaheshRavishankar /// Callback functions for allocation and deallocation of promoted buffers, as 1530ed2d4c7SMaheshRavishankar /// well as to copy the data into and out of these buffers. 1540ed2d4c7SMaheshRavishankar AllocBufferCallbackFn allocationFn; 1550ed2d4c7SMaheshRavishankar DeallocBufferCallbackFn deallocationFn; 1560ed2d4c7SMaheshRavishankar CopyCallbackFn copyInFn; 1570ed2d4c7SMaheshRavishankar CopyCallbackFn copyOutFn; 1580ed2d4c7SMaheshRavishankar 1590ed2d4c7SMaheshRavishankar /// Alignment of promoted buffer. 16022426110SRamkumar Ramachandra std::optional<unsigned> alignment; 1610ed2d4c7SMaheshRavishankar }; 1620ed2d4c7SMaheshRavishankar } // namespace 1630ed2d4c7SMaheshRavishankar 1640ed2d4c7SMaheshRavishankar LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( 1650ed2d4c7SMaheshRavishankar LinalgOp linalgOp, const LinalgPromotionOptions &options) 1665a001136SNicolas Vasilache : subViews(), alignment(options.alignment) { 1670a8e3dd4SMatthias Springer assert(linalgOp.hasPureBufferSemantics() && 1680a8e3dd4SMatthias Springer "revisit usage of shaped operand"); 1690ed2d4c7SMaheshRavishankar auto vUseFullTileBuffers = 17030c67587SKazu Hirata options.useFullTileBuffers.value_or(llvm::SmallBitVector()); 171a7cccb9cSAlexander Belyaev vUseFullTileBuffers.resize(linalgOp->getNumOperands(), 172e70d2c8eSTobias Gysi options.useFullTileBuffersDefault); 1730ed2d4c7SMaheshRavishankar 174a7cccb9cSAlexander Belyaev for (OpOperand &opOperand : linalgOp->getOpOperands()) { 175a7cccb9cSAlexander Belyaev int64_t operandNumber = opOperand.getOperandNumber(); 176e70d2c8eSTobias Gysi if (options.operandsToPromote && 177e70d2c8eSTobias Gysi !options.operandsToPromote->count(operandNumber)) 1780ed2d4c7SMaheshRavishankar continue; 179a7cccb9cSAlexander Belyaev Operation *op = opOperand.get().getDefiningOp(); 180e2310704SJulian Gross if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) { 181e70d2c8eSTobias Gysi subViews[operandNumber] = sv; 18270604222SAviad Cohen // In case of linalg generic, copy in only if subview is used in linalg 18370604222SAviad Cohen // payload. 18470604222SAviad Cohen if (!isa<linalg::GenericOp>(linalgOp) || 18570604222SAviad Cohen linalgOp.payloadUsesValueFromOperand(&opOperand)) 18670604222SAviad Cohen operandsNumbersToCopyIn.insert(operandNumber); 187e70d2c8eSTobias Gysi useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber]; 1880ed2d4c7SMaheshRavishankar } 1890ed2d4c7SMaheshRavishankar } 1900ed2d4c7SMaheshRavishankar 1914519ca3dSNicolas Vasilache if (options.allocationFn) { 1924519ca3dSNicolas Vasilache allocationFn = *options.allocationFn; 1934519ca3dSNicolas Vasilache } else { 1944519ca3dSNicolas Vasilache allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp, 195ef33c6e3SNicolas Vasilache ArrayRef<Value> boundingSubViewSize, 19622426110SRamkumar Ramachandra DataLayout &layout) -> std::optional<Value> { 1974519ca3dSNicolas Vasilache return defaultAllocBufferCallBack(options, b, subViewOp, 1984519ca3dSNicolas Vasilache boundingSubViewSize, alignment, layout); 1994519ca3dSNicolas Vasilache }; 2004519ca3dSNicolas Vasilache } 2014519ca3dSNicolas Vasilache 2024519ca3dSNicolas Vasilache if (options.deallocationFn) { 2034519ca3dSNicolas Vasilache deallocationFn = *options.deallocationFn; 2044519ca3dSNicolas Vasilache } else { 2054519ca3dSNicolas Vasilache deallocationFn = [&](OpBuilder &b, Value buffer) { 2067d9518c8SNicolas Vasilache return defaultDeallocBufferCallBack(options, b, buffer); 2074519ca3dSNicolas Vasilache }; 2084519ca3dSNicolas Vasilache } 2094519ca3dSNicolas Vasilache 2104519ca3dSNicolas Vasilache // Save the loc because `linalgOp` goes out of scope. 2114519ca3dSNicolas Vasilache Location loc = linalgOp.getLoc(); 2124519ca3dSNicolas Vasilache auto defaultCopyCallBack = [loc](OpBuilder &b, Value src, 2130ed2d4c7SMaheshRavishankar Value dst) -> LogicalResult { 214a7d6039fSAviad Cohen b.create<linalg::CopyOp>(loc, src, dst); 2150ed2d4c7SMaheshRavishankar return success(); 2160ed2d4c7SMaheshRavishankar }; 2170ed2d4c7SMaheshRavishankar copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack); 2180ed2d4c7SMaheshRavishankar copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack); 2190ed2d4c7SMaheshRavishankar } 2200ed2d4c7SMaheshRavishankar 2215b03e692SNicolas Vasilache // Performs promotion of a `subView` into a local buffer of the size of the 2225b03e692SNicolas Vasilache // *ranges* of the `subView`. This produces a buffer whose size may be bigger 2235b03e692SNicolas Vasilache // than the actual size of the `subView` at the boundaries. 2245b03e692SNicolas Vasilache // This is related to the full/partial tile problem. 2255b03e692SNicolas Vasilache // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and 2265b03e692SNicolas Vasilache // `partialLocalView` such that: 2275b03e692SNicolas Vasilache // * `buffer` is always the size of the full tile. 2285b03e692SNicolas Vasilache // * `fullLocalView` is a dense contiguous view into that buffer. 2295b03e692SNicolas Vasilache // * `partialLocalView` is a dense non-contiguous slice of `fullLocalView` 2305b03e692SNicolas Vasilache // that corresponds to the size of `subView` and accounting for boundary 2315b03e692SNicolas Vasilache // effects. 2325b03e692SNicolas Vasilache // The point of the full tile buffer is that constant static tile sizes are 2335b03e692SNicolas Vasilache // folded and result in a buffer type with statically known size and alignment 2345b03e692SNicolas Vasilache // properties. 2355b03e692SNicolas Vasilache // To account for general boundary effects, padding must be performed on the 2365b03e692SNicolas Vasilache // boundary tiles. For now this is done with an unconditional `fill` op followed 2375b03e692SNicolas Vasilache // by a partial `copy` op. 238489fec27SNicolas Vasilache FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer( 239e2310704SJulian Gross OpBuilder &b, Location loc, memref::SubViewOp subView, 2401fc096afSMehdi Amini const AllocBufferCallbackFn &allocationFn, DataLayout &layout) { 2410bd6390bSNicolas Vasilache auto viewType = subView.getType(); 2425b03e692SNicolas Vasilache auto rank = viewType.getRank(); 24305d5125dSNicolas Vasilache SmallVector<Value, 4> fullSizes; 24405d5125dSNicolas Vasilache SmallVector<OpFoldResult> partialSizes; 2453cb1f35dSNicolas Vasilache fullSizes.reserve(rank); 2463cb1f35dSNicolas Vasilache partialSizes.reserve(rank); 24799069ab2SChristopher Bate llvm::SmallBitVector droppedDims = subView.getDroppedDims(); 24899069ab2SChristopher Bate int64_t resultDimIdx = 0; 249e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) { 25099069ab2SChristopher Bate if (droppedDims[en.index()]) 25199069ab2SChristopher Bate continue; 2525b03e692SNicolas Vasilache auto rangeValue = en.value(); 25370e99f38SAlex Zinenko // Try to extract a tight constant. If the size is known statically, no need 25470e99f38SAlex Zinenko // to look for the bound. 2558dbbb223SNicolas Vasilache LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n"); 25670e99f38SAlex Zinenko Value size; 25768f58812STres Popp if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) { 2584bf84e43SAlexander Belyaev size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); 25970e99f38SAlex Zinenko } else { 26053da8600STobias Gysi FailureOr<int64_t> upperBound = 261eabb6ccdSMatthias Springer ValueBoundsConstraintSet::computeConstantBound( 26240dd3aa9SMatthias Springer presburger::BoundType::UB, rangeValue.size, 263eabb6ccdSMatthias Springer /*stopCondition=*/nullptr, /*closedUB=*/true); 26470e99f38SAlex Zinenko size = failed(upperBound) 26540dd3aa9SMatthias Springer ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size) 266cbb09813SFangrui Song : b.create<arith::ConstantIndexOp>(loc, *upperBound); 26770e99f38SAlex Zinenko } 2688dbbb223SNicolas Vasilache LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); 2693cb1f35dSNicolas Vasilache fullSizes.push_back(size); 2704519ca3dSNicolas Vasilache partialSizes.push_back( 27199069ab2SChristopher Bate b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++)); 2725b03e692SNicolas Vasilache } 273399638f9SAliia Khasanova SmallVector<int64_t, 4> dynSizes(fullSizes.size(), ShapedType::kDynamic); 2740ed2d4c7SMaheshRavishankar // If a callback is not specified, then use the default implementation for 2750ed2d4c7SMaheshRavishankar // allocating the promoted buffer. 27622426110SRamkumar Ramachandra std::optional<Value> fullLocalView = 27722426110SRamkumar Ramachandra allocationFn(b, subView, fullSizes, layout); 2780ed2d4c7SMaheshRavishankar if (!fullLocalView) 279489fec27SNicolas Vasilache return failure(); 28005d5125dSNicolas Vasilache SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0)); 28105d5125dSNicolas Vasilache SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1)); 282ef33c6e3SNicolas Vasilache auto partialLocalView = b.createOrFold<memref::SubViewOp>( 283ef33c6e3SNicolas Vasilache loc, *fullLocalView, zeros, partialSizes, ones); 2840ed2d4c7SMaheshRavishankar return PromotionInfo{*fullLocalView, partialLocalView}; 2855b03e692SNicolas Vasilache } 2865b03e692SNicolas Vasilache 287489fec27SNicolas Vasilache static FailureOr<MapVector<int64_t, PromotionInfo>> 2884519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b, 289ef33c6e3SNicolas Vasilache LinalgOpInstancePromotionOptions options, DataLayout &layout) { 2908dbbb223SNicolas Vasilache if (options.subViews.empty()) 291489fec27SNicolas Vasilache return failure(); 2925b03e692SNicolas Vasilache 2934519ca3dSNicolas Vasilache MapVector<int64_t, PromotionInfo> promotionInfoMap; 2945b03e692SNicolas Vasilache 2958dbbb223SNicolas Vasilache for (auto v : options.subViews) { 296e2310704SJulian Gross memref::SubViewOp subView = 297e2310704SJulian Gross cast<memref::SubViewOp>(v.second.getDefiningOp()); 298489fec27SNicolas Vasilache auto promotionInfo = promoteSubviewAsNewBuffer( 2994519ca3dSNicolas Vasilache b, b.getLoc(), subView, options.allocationFn, layout); 300489fec27SNicolas Vasilache if (failed(promotionInfo)) 301489fec27SNicolas Vasilache return failure(); 3020ed2d4c7SMaheshRavishankar promotionInfoMap[v.first] = *promotionInfo; 3030ed2d4c7SMaheshRavishankar 304d1866f89SPierre Oechsel // Only fill the buffer if the full local view is used 3050ed2d4c7SMaheshRavishankar if (!options.useFullTileBuffers[v.second]) 306d1866f89SPierre Oechsel continue; 3074519ca3dSNicolas Vasilache Type subviewEltType = subView.getType().getElementType(); 3084519ca3dSNicolas Vasilache Value fillVal = 3094519ca3dSNicolas Vasilache llvm::TypeSwitch<Type, Value>(subviewEltType) 3104519ca3dSNicolas Vasilache .Case([&](FloatType t) { 311a54f4eaeSMogball return b.create<arith::ConstantOp>(FloatAttr::get(t, 0.0)); 3124519ca3dSNicolas Vasilache }) 3134519ca3dSNicolas Vasilache .Case([&](IntegerType t) { 314a54f4eaeSMogball return b.create<arith::ConstantOp>(IntegerAttr::get(t, 0)); 3154519ca3dSNicolas Vasilache }) 3164519ca3dSNicolas Vasilache .Case([&](ComplexType t) { 3174519ca3dSNicolas Vasilache Value tmp; 3185550c821STres Popp if (auto et = dyn_cast<FloatType>(t.getElementType())) 319a54f4eaeSMogball tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0)); 3205550c821STres Popp else if (auto et = cast<IntegerType>(t.getElementType())) 321a54f4eaeSMogball tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0)); 3224519ca3dSNicolas Vasilache return b.create<complex::CreateOp>(t, tmp, tmp); 3234519ca3dSNicolas Vasilache }) 3244519ca3dSNicolas Vasilache .Default([](auto) { return Value(); }); 3254519ca3dSNicolas Vasilache if (!fillVal) 326489fec27SNicolas Vasilache return failure(); 3277cef24eeSTobias Gysi b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView); 3285b03e692SNicolas Vasilache } 3295b03e692SNicolas Vasilache 3300ed2d4c7SMaheshRavishankar // Copy data into the promoted buffers. Use callback if provided. 3318dbbb223SNicolas Vasilache for (auto v : options.subViews) { 332b1d4265aSMehdi Amini auto *info = promotionInfoMap.find(v.first); 3335b03e692SNicolas Vasilache if (info == promotionInfoMap.end()) 3345b03e692SNicolas Vasilache continue; 33570604222SAviad Cohen if (options.operandsNumbersToCopyIn.count(v.first) == 0) 33670604222SAviad Cohen continue; 337e2310704SJulian Gross if (failed(options.copyInFn( 338e2310704SJulian Gross b, cast<memref::SubViewOp>(v.second.getDefiningOp()), 3390ed2d4c7SMaheshRavishankar info->second.partialLocalView))) 340489fec27SNicolas Vasilache return failure(); 3415b03e692SNicolas Vasilache } 3420ed2d4c7SMaheshRavishankar return promotionInfoMap; 3435b03e692SNicolas Vasilache } 3445b03e692SNicolas Vasilache 345489fec27SNicolas Vasilache static FailureOr<LinalgOp> 3464519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op, 347ef33c6e3SNicolas Vasilache LinalgOpInstancePromotionOptions options, DataLayout &layout) { 3480a8e3dd4SMatthias Springer assert(op.hasPureBufferSemantics() && 3490a8e3dd4SMatthias Springer "expected linalg op with buffer semantics"); 350f52d7173SNicolas Vasilache 3515b03e692SNicolas Vasilache // 1. Promote the specified views and use them in the new op. 3524519ca3dSNicolas Vasilache auto promotedBuffersAndViews = promoteSubViews(b, options, layout); 353489fec27SNicolas Vasilache if (failed(promotedBuffersAndViews) || 3540ed2d4c7SMaheshRavishankar promotedBuffersAndViews->size() != options.subViews.size()) 355489fec27SNicolas Vasilache return failure(); 3565b03e692SNicolas Vasilache 3575b03e692SNicolas Vasilache // 2. Append all other operands as they appear, this enforces that such 3585b03e692SNicolas Vasilache // operands are not views. This is to support cases such as FillOp taking 3590ed2d4c7SMaheshRavishankar // extra scalars etc. Keep a reference to output buffers; 3600ed2d4c7SMaheshRavishankar SmallVector<Value, 8> opViews; 361a7cccb9cSAlexander Belyaev opViews.reserve(op->getNumOperands()); 3620ed2d4c7SMaheshRavishankar SmallVector<std::pair<Value, Value>, 8> writebackViews; 3630ed2d4c7SMaheshRavishankar writebackViews.reserve(promotedBuffersAndViews->size()); 364a7cccb9cSAlexander Belyaev for (OpOperand &opOperand : op->getOpOperands()) { 365a7cccb9cSAlexander Belyaev int64_t operandNumber = opOperand.getOperandNumber(); 366e70d2c8eSTobias Gysi if (options.subViews.count(operandNumber) != 0) { 367a7cccb9cSAlexander Belyaev if (options.useFullTileBuffers[opOperand.get()]) 3680ed2d4c7SMaheshRavishankar opViews.push_back( 369e70d2c8eSTobias Gysi (*promotedBuffersAndViews)[operandNumber].fullLocalView); 3700ed2d4c7SMaheshRavishankar else 3710ed2d4c7SMaheshRavishankar opViews.push_back( 372e70d2c8eSTobias Gysi (*promotedBuffersAndViews)[operandNumber].partialLocalView); 373b4db15a9SAlexander Belyaev if (operandNumber >= op.getNumDpsInputs()) 3740ed2d4c7SMaheshRavishankar writebackViews.emplace_back(std::make_pair( 375a7cccb9cSAlexander Belyaev opOperand.get(), 376e70d2c8eSTobias Gysi (*promotedBuffersAndViews)[operandNumber].partialLocalView)); 3770ed2d4c7SMaheshRavishankar } else { 378a7cccb9cSAlexander Belyaev opViews.push_back(opOperand.get()); 3790ed2d4c7SMaheshRavishankar } 3800ed2d4c7SMaheshRavishankar } 381c4a04059SChristian Sigg op->setOperands(0, opViews.size(), opViews); 3825b03e692SNicolas Vasilache 3838dbbb223SNicolas Vasilache OpBuilder::InsertionGuard guard(b); 3848dbbb223SNicolas Vasilache b.setInsertionPointAfter(op); 3855b03e692SNicolas Vasilache // 3. Emit write-back for the promoted output views: copy the partial view. 3860ed2d4c7SMaheshRavishankar for (auto viewAndPartialLocalView : writebackViews) { 3870ed2d4c7SMaheshRavishankar if (failed(options.copyOutFn(b, viewAndPartialLocalView.second, 3880ed2d4c7SMaheshRavishankar viewAndPartialLocalView.first))) 389489fec27SNicolas Vasilache return failure(); 3900ed2d4c7SMaheshRavishankar } 3915b03e692SNicolas Vasilache 3928dbbb223SNicolas Vasilache // 4. Dealloc all local buffers. 3937d9518c8SNicolas Vasilache for (const auto &pi : *promotedBuffersAndViews) 394e21adfa3SRiver Riddle (void)options.deallocationFn(b, pi.second.fullLocalView); 3950ed2d4c7SMaheshRavishankar return op; 3965b03e692SNicolas Vasilache } 3975b03e692SNicolas Vasilache 3988dbbb223SNicolas Vasilache LogicalResult 3998dbbb223SNicolas Vasilache mlir::linalg::promoteSubviewsPrecondition(Operation *op, 4008dbbb223SNicolas Vasilache LinalgPromotionOptions options) { 401e70d2c8eSTobias Gysi LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 402307cfdf5SNicolas Vasilache // Transformation applies to buffers only. 4030a8e3dd4SMatthias Springer if (!linalgOp || !linalgOp.hasPureBufferSemantics()) 404307cfdf5SNicolas Vasilache return failure(); 4058dbbb223SNicolas Vasilache // Check that at least one of the requested operands is indeed a subview. 406a7cccb9cSAlexander Belyaev for (OpOperand &opOperand : linalgOp->getOpOperands()) { 407e70d2c8eSTobias Gysi auto sv = 408a7cccb9cSAlexander Belyaev isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp()); 4098dbbb223SNicolas Vasilache if (sv) { 410037f0995SKazu Hirata if (!options.operandsToPromote || 411a7cccb9cSAlexander Belyaev options.operandsToPromote->count(opOperand.getOperandNumber())) 412307cfdf5SNicolas Vasilache return success(); 413307cfdf5SNicolas Vasilache } 4148dbbb223SNicolas Vasilache } 4158dbbb223SNicolas Vasilache // TODO: Check all subviews requested are bound by a static constant. 4168dbbb223SNicolas Vasilache // TODO: Check that the total footprint fits within a given size. 417307cfdf5SNicolas Vasilache return failure(); 418307cfdf5SNicolas Vasilache } 419307cfdf5SNicolas Vasilache 420489fec27SNicolas Vasilache FailureOr<LinalgOp> 4214519ca3dSNicolas Vasilache mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp, 4221fc096afSMehdi Amini const LinalgPromotionOptions &options) { 4238dbbb223SNicolas Vasilache LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options); 42442e5f422STres Popp auto layout = DataLayout::closest(linalgOp); 4254519ca3dSNicolas Vasilache ImplicitLocOpBuilder b(linalgOp.getLoc(), builder); 426489fec27SNicolas Vasilache auto res = ::promoteSubViews(b, linalgOp, linalgOptions, layout); 427489fec27SNicolas Vasilache if (failed(res)) 428489fec27SNicolas Vasilache return failure(); 429489fec27SNicolas Vasilache return res; 4308dbbb223SNicolas Vasilache } 431115711c1SAmir Mohammad Tavakkoli 432115711c1SAmir Mohammad Tavakkoli /// Allocate the given subview to a memory address space in GPU by creating a 433115711c1SAmir Mohammad Tavakkoli /// allocation operation and setting the memref type address space to desired 434115711c1SAmir Mohammad Tavakkoli /// address space. 4352010269cSKazu Hirata static std::optional<Value> allocateSubviewGPUMemoryInAddressSpace( 436115711c1SAmir Mohammad Tavakkoli OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds, 437115711c1SAmir Mohammad Tavakkoli gpu::AddressSpace addressSpace) { 438115711c1SAmir Mohammad Tavakkoli OpBuilder::InsertionGuard guard(builder); 439115711c1SAmir Mohammad Tavakkoli 440115711c1SAmir Mohammad Tavakkoli func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>(); 441115711c1SAmir Mohammad Tavakkoli if (!funcOp) 442115711c1SAmir Mohammad Tavakkoli return std::nullopt; 443115711c1SAmir Mohammad Tavakkoli 444115711c1SAmir Mohammad Tavakkoli // The subview size bounds are expected to be constant; they specify the shape 445115711c1SAmir Mohammad Tavakkoli // of the allocation. 446115711c1SAmir Mohammad Tavakkoli SmallVector<int64_t> shape; 447115711c1SAmir Mohammad Tavakkoli for (Value bound : sizeBounds) { 448115711c1SAmir Mohammad Tavakkoli APInt value; 449115711c1SAmir Mohammad Tavakkoli if (!matchPattern(bound, m_ConstantInt(&value))) 450115711c1SAmir Mohammad Tavakkoli return std::nullopt; 451115711c1SAmir Mohammad Tavakkoli shape.push_back(value.getSExtValue()); 452115711c1SAmir Mohammad Tavakkoli } 453115711c1SAmir Mohammad Tavakkoli 454*b613a540SMatthias Springer builder.setInsertionPointToStart(&funcOp.front()); 455115711c1SAmir Mohammad Tavakkoli auto type = MemRefType::get( 456115711c1SAmir Mohammad Tavakkoli shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{}, 457115711c1SAmir Mohammad Tavakkoli gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace)); 458115711c1SAmir Mohammad Tavakkoli Value buffer; 459115711c1SAmir Mohammad Tavakkoli if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) { 460115711c1SAmir Mohammad Tavakkoli buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type); 461115711c1SAmir Mohammad Tavakkoli } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) { 462115711c1SAmir Mohammad Tavakkoli buffer = builder.create<memref::AllocaOp>(funcOp.getLoc(), type); 463115711c1SAmir Mohammad Tavakkoli } else { 464115711c1SAmir Mohammad Tavakkoli return std::nullopt; 465115711c1SAmir Mohammad Tavakkoli } 466115711c1SAmir Mohammad Tavakkoli return buffer; 467115711c1SAmir Mohammad Tavakkoli } 468115711c1SAmir Mohammad Tavakkoli 469115711c1SAmir Mohammad Tavakkoli /// Allocate the subview in the GPU workgroup memory. 4702010269cSKazu Hirata std::optional<Value> mlir::linalg::allocateWorkgroupMemory( 471115711c1SAmir Mohammad Tavakkoli OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds, 472115711c1SAmir Mohammad Tavakkoli DataLayout &) { 473115711c1SAmir Mohammad Tavakkoli return allocateSubviewGPUMemoryInAddressSpace( 474115711c1SAmir Mohammad Tavakkoli builder, subview, sizeBounds, 475115711c1SAmir Mohammad Tavakkoli gpu::GPUDialect::getWorkgroupAddressSpace()); 476115711c1SAmir Mohammad Tavakkoli } 477115711c1SAmir Mohammad Tavakkoli 478115711c1SAmir Mohammad Tavakkoli /// In case of GPU group memory there is no need to deallocate. 479115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::deallocateWorkgroupMemory(OpBuilder &, 480115711c1SAmir Mohammad Tavakkoli Value /*buffer*/) { 481115711c1SAmir Mohammad Tavakkoli return success(); 482115711c1SAmir Mohammad Tavakkoli } 483115711c1SAmir Mohammad Tavakkoli 484115711c1SAmir Mohammad Tavakkoli /// Create Memref copy operations and add gpu barrier guards before and after 485115711c1SAmir Mohammad Tavakkoli /// the copy operation to ensure data integrity. 486115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src, 487115711c1SAmir Mohammad Tavakkoli Value dst) { 488115711c1SAmir Mohammad Tavakkoli b.create<gpu::BarrierOp>(src.getLoc()); 489115711c1SAmir Mohammad Tavakkoli Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst); 490115711c1SAmir Mohammad Tavakkoli b.create<gpu::BarrierOp>(copyOp->getLoc()); 491115711c1SAmir Mohammad Tavakkoli return success(); 492115711c1SAmir Mohammad Tavakkoli } 493115711c1SAmir Mohammad Tavakkoli 494115711c1SAmir Mohammad Tavakkoli /// Allocate the subview in the GPU private memory. 4952010269cSKazu Hirata std::optional<Value> mlir::linalg::allocateGPUPrivateMemory( 496115711c1SAmir Mohammad Tavakkoli OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds, 497115711c1SAmir Mohammad Tavakkoli DataLayout &) { 498115711c1SAmir Mohammad Tavakkoli return allocateSubviewGPUMemoryInAddressSpace( 499115711c1SAmir Mohammad Tavakkoli builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace()); 500115711c1SAmir Mohammad Tavakkoli } 501115711c1SAmir Mohammad Tavakkoli 502115711c1SAmir Mohammad Tavakkoli /// Normal copy to between src and dst. 503115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src, 504115711c1SAmir Mohammad Tavakkoli Value dst) { 505779d54fdSHaojian Wu b.create<memref::CopyOp>(src.getLoc(), src, dst); 506115711c1SAmir Mohammad Tavakkoli return success(); 507115711c1SAmir Mohammad Tavakkoli } 508115711c1SAmir Mohammad Tavakkoli 509115711c1SAmir Mohammad Tavakkoli /// In case of GPU private memory there is no need to deallocate since the 510115711c1SAmir Mohammad Tavakkoli /// memory is freed when going outside of the scope. 511115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::deallocateGPUPrivateMemory(OpBuilder &, 512115711c1SAmir Mohammad Tavakkoli Value /*buffer*/) { 513115711c1SAmir Mohammad Tavakkoli return success(); 514115711c1SAmir Mohammad Tavakkoli } 515