xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (revision b613a54075c6e704dcaa15a676bf732955eb4352)
15b03e692SNicolas Vasilache //===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
25b03e692SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65b03e692SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85b03e692SNicolas Vasilache //
95b03e692SNicolas Vasilache // This file implements the linalg dialect Promotion pass.
105b03e692SNicolas Vasilache //
115b03e692SNicolas Vasilache //===----------------------------------------------------------------------===//
125b03e692SNicolas Vasilache 
13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
1542e5f422STres Popp #include "mlir/Dialect/Complex/IR/Complex.h"
16115711c1SAmir Mohammad Tavakkoli #include "mlir/Dialect/Func/IR/FuncOps.h"
17115711c1SAmir Mohammad Tavakkoli #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
195b03e692SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
20307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
218b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
225b03e692SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
235b03e692SNicolas Vasilache #include "mlir/IR/AffineExprVisitor.h"
245b03e692SNicolas Vasilache #include "mlir/IR/AffineMap.h"
25ef33c6e3SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
26eabb6ccdSMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
275b03e692SNicolas Vasilache #include "mlir/Support/LLVM.h"
285b03e692SNicolas Vasilache #include "mlir/Transforms/FoldUtils.h"
290ed2d4c7SMaheshRavishankar #include "llvm/ADT/MapVector.h"
3099069ab2SChristopher Bate #include "llvm/ADT/SmallBitVector.h"
3170604222SAviad Cohen #include "llvm/ADT/SmallSet.h"
324519ca3dSNicolas Vasilache #include "llvm/ADT/TypeSwitch.h"
335b03e692SNicolas Vasilache #include "llvm/Support/CommandLine.h"
34f89bb3c0SAlexander Belyaev #include "llvm/Support/Debug.h"
355b03e692SNicolas Vasilache 
365b03e692SNicolas Vasilache using namespace mlir;
375b03e692SNicolas Vasilache using namespace mlir::linalg;
38c25b20c0SAlex Zinenko using namespace mlir::scf;
395b03e692SNicolas Vasilache 
400ed2d4c7SMaheshRavishankar using llvm::MapVector;
415b03e692SNicolas Vasilache 
425b03e692SNicolas Vasilache #define DEBUG_TYPE "linalg-promotion"
435b03e692SNicolas Vasilache 
44ef33c6e3SNicolas Vasilache /// Alloc a new buffer of `size` * `width` i8; where `width` is given by the
45ef33c6e3SNicolas Vasilache /// data `layout` for `elementType`.
46ef33c6e3SNicolas Vasilache /// Use AllocOp or AllocaOp depending on `options`.
47ef33c6e3SNicolas Vasilache /// Take an optional alignment.
48ef33c6e3SNicolas Vasilache static Value allocBuffer(ImplicitLocOpBuilder &b,
49ef33c6e3SNicolas Vasilache                          const LinalgPromotionOptions &options,
50ef33c6e3SNicolas Vasilache                          Type elementType, Value allocSize, DataLayout &layout,
5122426110SRamkumar Ramachandra                          std::optional<unsigned> alignment = std::nullopt) {
528134a8fcSOleksandr "Alex" Zinenko   llvm::TypeSize width = layout.getTypeSize(elementType);
538134a8fcSOleksandr "Alex" Zinenko   assert(!width.isScalable() && "cannot allocate buffer for a scalable vector");
54ef33c6e3SNicolas Vasilache 
55ef33c6e3SNicolas Vasilache   IntegerAttr alignmentAttr;
56491d2701SKazu Hirata   if (alignment.has_value())
57c27d8152SKazu Hirata     alignmentAttr = b.getI64IntegerAttr(alignment.value());
58ef33c6e3SNicolas Vasilache 
59d6a2014eSAviad Cohen   Attribute memorySpaceAttr;
60d6a2014eSAviad Cohen   if (options.memorySpace.has_value())
61d6a2014eSAviad Cohen     memorySpaceAttr = *options.memorySpace;
62d6a2014eSAviad Cohen 
63ef33c6e3SNicolas Vasilache   // Static buffer.
64cb7bda2aSMatthias Springer   if (std::optional<int64_t> cst = getConstantIntValue(allocSize)) {
658134a8fcSOleksandr "Alex" Zinenko     auto staticBufferType = MemRefType::get(width.getFixedValue() * cst.value(),
668134a8fcSOleksandr "Alex" Zinenko                                             b.getIntegerType(8));
67d6a2014eSAviad Cohen     staticBufferType =
68d6a2014eSAviad Cohen         MemRefType::Builder(staticBufferType).setMemorySpace(memorySpaceAttr);
69ef33c6e3SNicolas Vasilache     if (options.useAlloca) {
70b23c8225SMatthias Springer       return b.create<memref::AllocaOp>(staticBufferType, ValueRange{},
71ef33c6e3SNicolas Vasilache                                         alignmentAttr);
72ef33c6e3SNicolas Vasilache     }
73b23c8225SMatthias Springer     return b.create<memref::AllocOp>(staticBufferType, ValueRange{},
74ef33c6e3SNicolas Vasilache                                      alignmentAttr);
75ef33c6e3SNicolas Vasilache   }
76ef33c6e3SNicolas Vasilache 
77ef33c6e3SNicolas Vasilache   // Fallback dynamic buffer.
78fb4cedccSAliia Khasanova   auto dynamicBufferType =
79399638f9SAliia Khasanova       MemRefType::get(ShapedType::kDynamic, b.getIntegerType(8));
80d6a2014eSAviad Cohen   dynamicBufferType =
81d6a2014eSAviad Cohen       MemRefType::Builder(dynamicBufferType).setMemorySpace(memorySpaceAttr);
82a54f4eaeSMogball   Value mul = b.createOrFold<arith::MulIOp>(
83a54f4eaeSMogball       b.create<arith::ConstantIndexOp>(width), allocSize);
84ef33c6e3SNicolas Vasilache   if (options.useAlloca)
85ef33c6e3SNicolas Vasilache     return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr);
86ef33c6e3SNicolas Vasilache   return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr);
875b03e692SNicolas Vasilache }
885b03e692SNicolas Vasilache 
890ed2d4c7SMaheshRavishankar /// Default allocation callback function. This allocates a promoted buffer when
900ed2d4c7SMaheshRavishankar /// no call back to do so is provided. The default is to allocate a
910ed2d4c7SMaheshRavishankar /// memref<..xi8> and return a view to get a memref type of shape
920ed2d4c7SMaheshRavishankar /// boundingSubViewSize.
9322426110SRamkumar Ramachandra static std::optional<Value> defaultAllocBufferCallBack(
9422426110SRamkumar Ramachandra     const LinalgPromotionOptions &options, OpBuilder &builder,
9522426110SRamkumar Ramachandra     memref::SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
9622426110SRamkumar Ramachandra     std::optional<unsigned> alignment, DataLayout &layout) {
970ed2d4c7SMaheshRavishankar   ShapedType viewType = subView.getType();
98ef33c6e3SNicolas Vasilache   ImplicitLocOpBuilder b(subView.getLoc(), builder);
99b23c8225SMatthias Springer   auto zero = b.create<arith::ConstantIndexOp>(0);
100b23c8225SMatthias Springer   auto one = b.create<arith::ConstantIndexOp>(1);
1010ed2d4c7SMaheshRavishankar 
102d6a2014eSAviad Cohen   Attribute memorySpaceAttr;
103d6a2014eSAviad Cohen   if (options.memorySpace.has_value())
104d6a2014eSAviad Cohen     memorySpaceAttr = *options.memorySpace;
105d6a2014eSAviad Cohen 
1060ed2d4c7SMaheshRavishankar   Value allocSize = one;
107e4853be2SMehdi Amini   for (const auto &size : llvm::enumerate(boundingSubViewSize))
108a54f4eaeSMogball     allocSize = b.createOrFold<arith::MulIOp>(allocSize, size.value());
109ef33c6e3SNicolas Vasilache   Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize,
110ef33c6e3SNicolas Vasilache                              layout, alignment);
1110ed2d4c7SMaheshRavishankar   SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
112399638f9SAliia Khasanova                                    ShapedType::kDynamic);
113d6a2014eSAviad Cohen 
114d6a2014eSAviad Cohen   auto viewMemRefType = MemRefType::get(dynSizes, viewType.getElementType());
115d6a2014eSAviad Cohen   viewMemRefType =
116d6a2014eSAviad Cohen       MemRefType::Builder(viewMemRefType).setMemorySpace(memorySpaceAttr);
117d6a2014eSAviad Cohen   Value view = b.createOrFold<memref::ViewOp>(viewMemRefType, buffer, zero,
118ef33c6e3SNicolas Vasilache                                               boundingSubViewSize);
1190ed2d4c7SMaheshRavishankar   return view;
1200ed2d4c7SMaheshRavishankar }
1210ed2d4c7SMaheshRavishankar 
1220ed2d4c7SMaheshRavishankar /// Default implementation of deallocation of the buffer use for promotion. It
1230ed2d4c7SMaheshRavishankar /// expects to get the same value that the default allocation method returned,
1240ed2d4c7SMaheshRavishankar /// i.e. result of a ViewOp.
1257d9518c8SNicolas Vasilache static LogicalResult
1267d9518c8SNicolas Vasilache defaultDeallocBufferCallBack(const LinalgPromotionOptions &options,
1277d9518c8SNicolas Vasilache                              OpBuilder &b, Value fullLocalView) {
1284519ca3dSNicolas Vasilache   if (!options.useAlloca) {
1294519ca3dSNicolas Vasilache     auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp());
130136d746eSJacques Pienaar     b.create<memref::DeallocOp>(viewOp.getSource().getLoc(),
131136d746eSJacques Pienaar                                 viewOp.getSource());
1324519ca3dSNicolas Vasilache   }
1330ed2d4c7SMaheshRavishankar   return success();
1340ed2d4c7SMaheshRavishankar }
1350ed2d4c7SMaheshRavishankar 
1360ed2d4c7SMaheshRavishankar namespace {
1370ed2d4c7SMaheshRavishankar 
1380ed2d4c7SMaheshRavishankar /// Helper struct that captures the information required to apply the
1390ed2d4c7SMaheshRavishankar /// transformation on each op. This bridges the abstraction gap with the
1400ed2d4c7SMaheshRavishankar /// user-facing API which exposes positional arguments to control which operands
1410ed2d4c7SMaheshRavishankar /// are promoted.
1420ed2d4c7SMaheshRavishankar struct LinalgOpInstancePromotionOptions {
1430ed2d4c7SMaheshRavishankar   LinalgOpInstancePromotionOptions(LinalgOp op,
1440ed2d4c7SMaheshRavishankar                                    const LinalgPromotionOptions &options);
1450ed2d4c7SMaheshRavishankar   /// SubViews to promote.
1464519ca3dSNicolas Vasilache   MapVector<int64_t, Value> subViews;
14770604222SAviad Cohen   /// Subviews operand numbers to copy in using copyInFn.
14870604222SAviad Cohen   llvm::SmallSet<int64_t, 4> operandsNumbersToCopyIn;
1490ed2d4c7SMaheshRavishankar   /// True if the full view should be used for the promoted buffer.
1500ed2d4c7SMaheshRavishankar   DenseMap<Value, bool> useFullTileBuffers;
1510ed2d4c7SMaheshRavishankar 
1520ed2d4c7SMaheshRavishankar   /// Callback functions for allocation and deallocation of promoted buffers, as
1530ed2d4c7SMaheshRavishankar   /// well as to copy the data into and out of these buffers.
1540ed2d4c7SMaheshRavishankar   AllocBufferCallbackFn allocationFn;
1550ed2d4c7SMaheshRavishankar   DeallocBufferCallbackFn deallocationFn;
1560ed2d4c7SMaheshRavishankar   CopyCallbackFn copyInFn;
1570ed2d4c7SMaheshRavishankar   CopyCallbackFn copyOutFn;
1580ed2d4c7SMaheshRavishankar 
1590ed2d4c7SMaheshRavishankar   /// Alignment of promoted buffer.
16022426110SRamkumar Ramachandra   std::optional<unsigned> alignment;
1610ed2d4c7SMaheshRavishankar };
1620ed2d4c7SMaheshRavishankar } // namespace
1630ed2d4c7SMaheshRavishankar 
1640ed2d4c7SMaheshRavishankar LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
1650ed2d4c7SMaheshRavishankar     LinalgOp linalgOp, const LinalgPromotionOptions &options)
1665a001136SNicolas Vasilache     : subViews(), alignment(options.alignment) {
1670a8e3dd4SMatthias Springer   assert(linalgOp.hasPureBufferSemantics() &&
1680a8e3dd4SMatthias Springer          "revisit usage of shaped operand");
1690ed2d4c7SMaheshRavishankar   auto vUseFullTileBuffers =
17030c67587SKazu Hirata       options.useFullTileBuffers.value_or(llvm::SmallBitVector());
171a7cccb9cSAlexander Belyaev   vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
172e70d2c8eSTobias Gysi                              options.useFullTileBuffersDefault);
1730ed2d4c7SMaheshRavishankar 
174a7cccb9cSAlexander Belyaev   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
175a7cccb9cSAlexander Belyaev     int64_t operandNumber = opOperand.getOperandNumber();
176e70d2c8eSTobias Gysi     if (options.operandsToPromote &&
177e70d2c8eSTobias Gysi         !options.operandsToPromote->count(operandNumber))
1780ed2d4c7SMaheshRavishankar       continue;
179a7cccb9cSAlexander Belyaev     Operation *op = opOperand.get().getDefiningOp();
180e2310704SJulian Gross     if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
181e70d2c8eSTobias Gysi       subViews[operandNumber] = sv;
18270604222SAviad Cohen       // In case of linalg generic, copy in only if subview is used in linalg
18370604222SAviad Cohen       // payload.
18470604222SAviad Cohen       if (!isa<linalg::GenericOp>(linalgOp) ||
18570604222SAviad Cohen           linalgOp.payloadUsesValueFromOperand(&opOperand))
18670604222SAviad Cohen         operandsNumbersToCopyIn.insert(operandNumber);
187e70d2c8eSTobias Gysi       useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
1880ed2d4c7SMaheshRavishankar     }
1890ed2d4c7SMaheshRavishankar   }
1900ed2d4c7SMaheshRavishankar 
1914519ca3dSNicolas Vasilache   if (options.allocationFn) {
1924519ca3dSNicolas Vasilache     allocationFn = *options.allocationFn;
1934519ca3dSNicolas Vasilache   } else {
1944519ca3dSNicolas Vasilache     allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp,
195ef33c6e3SNicolas Vasilache                        ArrayRef<Value> boundingSubViewSize,
19622426110SRamkumar Ramachandra                        DataLayout &layout) -> std::optional<Value> {
1974519ca3dSNicolas Vasilache       return defaultAllocBufferCallBack(options, b, subViewOp,
1984519ca3dSNicolas Vasilache                                         boundingSubViewSize, alignment, layout);
1994519ca3dSNicolas Vasilache     };
2004519ca3dSNicolas Vasilache   }
2014519ca3dSNicolas Vasilache 
2024519ca3dSNicolas Vasilache   if (options.deallocationFn) {
2034519ca3dSNicolas Vasilache     deallocationFn = *options.deallocationFn;
2044519ca3dSNicolas Vasilache   } else {
2054519ca3dSNicolas Vasilache     deallocationFn = [&](OpBuilder &b, Value buffer) {
2067d9518c8SNicolas Vasilache       return defaultDeallocBufferCallBack(options, b, buffer);
2074519ca3dSNicolas Vasilache     };
2084519ca3dSNicolas Vasilache   }
2094519ca3dSNicolas Vasilache 
2104519ca3dSNicolas Vasilache   // Save the loc because `linalgOp` goes out of scope.
2114519ca3dSNicolas Vasilache   Location loc = linalgOp.getLoc();
2124519ca3dSNicolas Vasilache   auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
2130ed2d4c7SMaheshRavishankar                                    Value dst) -> LogicalResult {
214a7d6039fSAviad Cohen     b.create<linalg::CopyOp>(loc, src, dst);
2150ed2d4c7SMaheshRavishankar     return success();
2160ed2d4c7SMaheshRavishankar   };
2170ed2d4c7SMaheshRavishankar   copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
2180ed2d4c7SMaheshRavishankar   copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack);
2190ed2d4c7SMaheshRavishankar }
2200ed2d4c7SMaheshRavishankar 
2215b03e692SNicolas Vasilache // Performs promotion of a `subView` into a local buffer of the size of the
2225b03e692SNicolas Vasilache // *ranges* of the `subView`. This produces a buffer whose size may be bigger
2235b03e692SNicolas Vasilache // than the actual size of the `subView` at the boundaries.
2245b03e692SNicolas Vasilache // This is related to the full/partial tile problem.
2255b03e692SNicolas Vasilache // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
2265b03e692SNicolas Vasilache // `partialLocalView` such that:
2275b03e692SNicolas Vasilache //   * `buffer` is always the size of the full tile.
2285b03e692SNicolas Vasilache //   * `fullLocalView` is a dense contiguous view into that buffer.
2295b03e692SNicolas Vasilache //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
2305b03e692SNicolas Vasilache //     that corresponds to the size of `subView` and accounting for boundary
2315b03e692SNicolas Vasilache //     effects.
2325b03e692SNicolas Vasilache // The point of the full tile buffer is that constant static tile sizes are
2335b03e692SNicolas Vasilache // folded and result in a buffer type with statically known size and alignment
2345b03e692SNicolas Vasilache // properties.
2355b03e692SNicolas Vasilache // To account for general boundary effects, padding must be performed on the
2365b03e692SNicolas Vasilache // boundary tiles. For now this is done with an unconditional `fill` op followed
2375b03e692SNicolas Vasilache // by a partial `copy` op.
238489fec27SNicolas Vasilache FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
239e2310704SJulian Gross     OpBuilder &b, Location loc, memref::SubViewOp subView,
2401fc096afSMehdi Amini     const AllocBufferCallbackFn &allocationFn, DataLayout &layout) {
2410bd6390bSNicolas Vasilache   auto viewType = subView.getType();
2425b03e692SNicolas Vasilache   auto rank = viewType.getRank();
24305d5125dSNicolas Vasilache   SmallVector<Value, 4> fullSizes;
24405d5125dSNicolas Vasilache   SmallVector<OpFoldResult> partialSizes;
2453cb1f35dSNicolas Vasilache   fullSizes.reserve(rank);
2463cb1f35dSNicolas Vasilache   partialSizes.reserve(rank);
24799069ab2SChristopher Bate   llvm::SmallBitVector droppedDims = subView.getDroppedDims();
24899069ab2SChristopher Bate   int64_t resultDimIdx = 0;
249e4853be2SMehdi Amini   for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
25099069ab2SChristopher Bate     if (droppedDims[en.index()])
25199069ab2SChristopher Bate       continue;
2525b03e692SNicolas Vasilache     auto rangeValue = en.value();
25370e99f38SAlex Zinenko     // Try to extract a tight constant. If the size is known statically, no need
25470e99f38SAlex Zinenko     // to look for the bound.
2558dbbb223SNicolas Vasilache     LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
25670e99f38SAlex Zinenko     Value size;
25768f58812STres Popp     if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
2584bf84e43SAlexander Belyaev       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
25970e99f38SAlex Zinenko     } else {
26053da8600STobias Gysi       FailureOr<int64_t> upperBound =
261eabb6ccdSMatthias Springer           ValueBoundsConstraintSet::computeConstantBound(
26240dd3aa9SMatthias Springer               presburger::BoundType::UB, rangeValue.size,
263eabb6ccdSMatthias Springer               /*stopCondition=*/nullptr, /*closedUB=*/true);
26470e99f38SAlex Zinenko       size = failed(upperBound)
26540dd3aa9SMatthias Springer                  ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
266cbb09813SFangrui Song                  : b.create<arith::ConstantIndexOp>(loc, *upperBound);
26770e99f38SAlex Zinenko     }
2688dbbb223SNicolas Vasilache     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
2693cb1f35dSNicolas Vasilache     fullSizes.push_back(size);
2704519ca3dSNicolas Vasilache     partialSizes.push_back(
27199069ab2SChristopher Bate         b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
2725b03e692SNicolas Vasilache   }
273399638f9SAliia Khasanova   SmallVector<int64_t, 4> dynSizes(fullSizes.size(), ShapedType::kDynamic);
2740ed2d4c7SMaheshRavishankar   // If a callback is not specified, then use the default implementation for
2750ed2d4c7SMaheshRavishankar   // allocating the promoted buffer.
27622426110SRamkumar Ramachandra   std::optional<Value> fullLocalView =
27722426110SRamkumar Ramachandra       allocationFn(b, subView, fullSizes, layout);
2780ed2d4c7SMaheshRavishankar   if (!fullLocalView)
279489fec27SNicolas Vasilache     return failure();
28005d5125dSNicolas Vasilache   SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0));
28105d5125dSNicolas Vasilache   SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1));
282ef33c6e3SNicolas Vasilache   auto partialLocalView = b.createOrFold<memref::SubViewOp>(
283ef33c6e3SNicolas Vasilache       loc, *fullLocalView, zeros, partialSizes, ones);
2840ed2d4c7SMaheshRavishankar   return PromotionInfo{*fullLocalView, partialLocalView};
2855b03e692SNicolas Vasilache }
2865b03e692SNicolas Vasilache 
287489fec27SNicolas Vasilache static FailureOr<MapVector<int64_t, PromotionInfo>>
2884519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b,
289ef33c6e3SNicolas Vasilache                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
2908dbbb223SNicolas Vasilache   if (options.subViews.empty())
291489fec27SNicolas Vasilache     return failure();
2925b03e692SNicolas Vasilache 
2934519ca3dSNicolas Vasilache   MapVector<int64_t, PromotionInfo> promotionInfoMap;
2945b03e692SNicolas Vasilache 
2958dbbb223SNicolas Vasilache   for (auto v : options.subViews) {
296e2310704SJulian Gross     memref::SubViewOp subView =
297e2310704SJulian Gross         cast<memref::SubViewOp>(v.second.getDefiningOp());
298489fec27SNicolas Vasilache     auto promotionInfo = promoteSubviewAsNewBuffer(
2994519ca3dSNicolas Vasilache         b, b.getLoc(), subView, options.allocationFn, layout);
300489fec27SNicolas Vasilache     if (failed(promotionInfo))
301489fec27SNicolas Vasilache       return failure();
3020ed2d4c7SMaheshRavishankar     promotionInfoMap[v.first] = *promotionInfo;
3030ed2d4c7SMaheshRavishankar 
304d1866f89SPierre Oechsel     // Only fill the buffer if the full local view is used
3050ed2d4c7SMaheshRavishankar     if (!options.useFullTileBuffers[v.second])
306d1866f89SPierre Oechsel       continue;
3074519ca3dSNicolas Vasilache     Type subviewEltType = subView.getType().getElementType();
3084519ca3dSNicolas Vasilache     Value fillVal =
3094519ca3dSNicolas Vasilache         llvm::TypeSwitch<Type, Value>(subviewEltType)
3104519ca3dSNicolas Vasilache             .Case([&](FloatType t) {
311a54f4eaeSMogball               return b.create<arith::ConstantOp>(FloatAttr::get(t, 0.0));
3124519ca3dSNicolas Vasilache             })
3134519ca3dSNicolas Vasilache             .Case([&](IntegerType t) {
314a54f4eaeSMogball               return b.create<arith::ConstantOp>(IntegerAttr::get(t, 0));
3154519ca3dSNicolas Vasilache             })
3164519ca3dSNicolas Vasilache             .Case([&](ComplexType t) {
3174519ca3dSNicolas Vasilache               Value tmp;
3185550c821STres Popp               if (auto et = dyn_cast<FloatType>(t.getElementType()))
319a54f4eaeSMogball                 tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
3205550c821STres Popp               else if (auto et = cast<IntegerType>(t.getElementType()))
321a54f4eaeSMogball                 tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
3224519ca3dSNicolas Vasilache               return b.create<complex::CreateOp>(t, tmp, tmp);
3234519ca3dSNicolas Vasilache             })
3244519ca3dSNicolas Vasilache             .Default([](auto) { return Value(); });
3254519ca3dSNicolas Vasilache     if (!fillVal)
326489fec27SNicolas Vasilache       return failure();
3277cef24eeSTobias Gysi     b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
3285b03e692SNicolas Vasilache   }
3295b03e692SNicolas Vasilache 
3300ed2d4c7SMaheshRavishankar   // Copy data into the promoted buffers. Use callback if provided.
3318dbbb223SNicolas Vasilache   for (auto v : options.subViews) {
332b1d4265aSMehdi Amini     auto *info = promotionInfoMap.find(v.first);
3335b03e692SNicolas Vasilache     if (info == promotionInfoMap.end())
3345b03e692SNicolas Vasilache       continue;
33570604222SAviad Cohen     if (options.operandsNumbersToCopyIn.count(v.first) == 0)
33670604222SAviad Cohen       continue;
337e2310704SJulian Gross     if (failed(options.copyInFn(
338e2310704SJulian Gross             b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
3390ed2d4c7SMaheshRavishankar             info->second.partialLocalView)))
340489fec27SNicolas Vasilache       return failure();
3415b03e692SNicolas Vasilache   }
3420ed2d4c7SMaheshRavishankar   return promotionInfoMap;
3435b03e692SNicolas Vasilache }
3445b03e692SNicolas Vasilache 
345489fec27SNicolas Vasilache static FailureOr<LinalgOp>
3464519ca3dSNicolas Vasilache promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
347ef33c6e3SNicolas Vasilache                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
3480a8e3dd4SMatthias Springer   assert(op.hasPureBufferSemantics() &&
3490a8e3dd4SMatthias Springer          "expected linalg op with buffer semantics");
350f52d7173SNicolas Vasilache 
3515b03e692SNicolas Vasilache   // 1. Promote the specified views and use them in the new op.
3524519ca3dSNicolas Vasilache   auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
353489fec27SNicolas Vasilache   if (failed(promotedBuffersAndViews) ||
3540ed2d4c7SMaheshRavishankar       promotedBuffersAndViews->size() != options.subViews.size())
355489fec27SNicolas Vasilache     return failure();
3565b03e692SNicolas Vasilache 
3575b03e692SNicolas Vasilache   // 2. Append all other operands as they appear, this enforces that such
3585b03e692SNicolas Vasilache   // operands are not views. This is to support cases such as FillOp taking
3590ed2d4c7SMaheshRavishankar   // extra scalars etc.  Keep a reference to output buffers;
3600ed2d4c7SMaheshRavishankar   SmallVector<Value, 8> opViews;
361a7cccb9cSAlexander Belyaev   opViews.reserve(op->getNumOperands());
3620ed2d4c7SMaheshRavishankar   SmallVector<std::pair<Value, Value>, 8> writebackViews;
3630ed2d4c7SMaheshRavishankar   writebackViews.reserve(promotedBuffersAndViews->size());
364a7cccb9cSAlexander Belyaev   for (OpOperand &opOperand : op->getOpOperands()) {
365a7cccb9cSAlexander Belyaev     int64_t operandNumber = opOperand.getOperandNumber();
366e70d2c8eSTobias Gysi     if (options.subViews.count(operandNumber) != 0) {
367a7cccb9cSAlexander Belyaev       if (options.useFullTileBuffers[opOperand.get()])
3680ed2d4c7SMaheshRavishankar         opViews.push_back(
369e70d2c8eSTobias Gysi             (*promotedBuffersAndViews)[operandNumber].fullLocalView);
3700ed2d4c7SMaheshRavishankar       else
3710ed2d4c7SMaheshRavishankar         opViews.push_back(
372e70d2c8eSTobias Gysi             (*promotedBuffersAndViews)[operandNumber].partialLocalView);
373b4db15a9SAlexander Belyaev       if (operandNumber >= op.getNumDpsInputs())
3740ed2d4c7SMaheshRavishankar         writebackViews.emplace_back(std::make_pair(
375a7cccb9cSAlexander Belyaev             opOperand.get(),
376e70d2c8eSTobias Gysi             (*promotedBuffersAndViews)[operandNumber].partialLocalView));
3770ed2d4c7SMaheshRavishankar     } else {
378a7cccb9cSAlexander Belyaev       opViews.push_back(opOperand.get());
3790ed2d4c7SMaheshRavishankar     }
3800ed2d4c7SMaheshRavishankar   }
381c4a04059SChristian Sigg   op->setOperands(0, opViews.size(), opViews);
3825b03e692SNicolas Vasilache 
3838dbbb223SNicolas Vasilache   OpBuilder::InsertionGuard guard(b);
3848dbbb223SNicolas Vasilache   b.setInsertionPointAfter(op);
3855b03e692SNicolas Vasilache   // 3. Emit write-back for the promoted output views: copy the partial view.
3860ed2d4c7SMaheshRavishankar   for (auto viewAndPartialLocalView : writebackViews) {
3870ed2d4c7SMaheshRavishankar     if (failed(options.copyOutFn(b, viewAndPartialLocalView.second,
3880ed2d4c7SMaheshRavishankar                                  viewAndPartialLocalView.first)))
389489fec27SNicolas Vasilache       return failure();
3900ed2d4c7SMaheshRavishankar   }
3915b03e692SNicolas Vasilache 
3928dbbb223SNicolas Vasilache   // 4. Dealloc all local buffers.
3937d9518c8SNicolas Vasilache   for (const auto &pi : *promotedBuffersAndViews)
394e21adfa3SRiver Riddle     (void)options.deallocationFn(b, pi.second.fullLocalView);
3950ed2d4c7SMaheshRavishankar   return op;
3965b03e692SNicolas Vasilache }
3975b03e692SNicolas Vasilache 
3988dbbb223SNicolas Vasilache LogicalResult
3998dbbb223SNicolas Vasilache mlir::linalg::promoteSubviewsPrecondition(Operation *op,
4008dbbb223SNicolas Vasilache                                           LinalgPromotionOptions options) {
401e70d2c8eSTobias Gysi   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
402307cfdf5SNicolas Vasilache   // Transformation applies to buffers only.
4030a8e3dd4SMatthias Springer   if (!linalgOp || !linalgOp.hasPureBufferSemantics())
404307cfdf5SNicolas Vasilache     return failure();
4058dbbb223SNicolas Vasilache   // Check that at least one of the requested operands is indeed a subview.
406a7cccb9cSAlexander Belyaev   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
407e70d2c8eSTobias Gysi     auto sv =
408a7cccb9cSAlexander Belyaev         isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp());
4098dbbb223SNicolas Vasilache     if (sv) {
410037f0995SKazu Hirata       if (!options.operandsToPromote ||
411a7cccb9cSAlexander Belyaev           options.operandsToPromote->count(opOperand.getOperandNumber()))
412307cfdf5SNicolas Vasilache         return success();
413307cfdf5SNicolas Vasilache     }
4148dbbb223SNicolas Vasilache   }
4158dbbb223SNicolas Vasilache   // TODO: Check all subviews requested are bound by a static constant.
4168dbbb223SNicolas Vasilache   // TODO: Check that the total footprint fits within a given size.
417307cfdf5SNicolas Vasilache   return failure();
418307cfdf5SNicolas Vasilache }
419307cfdf5SNicolas Vasilache 
420489fec27SNicolas Vasilache FailureOr<LinalgOp>
4214519ca3dSNicolas Vasilache mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp,
4221fc096afSMehdi Amini                               const LinalgPromotionOptions &options) {
4238dbbb223SNicolas Vasilache   LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
42442e5f422STres Popp   auto layout = DataLayout::closest(linalgOp);
4254519ca3dSNicolas Vasilache   ImplicitLocOpBuilder b(linalgOp.getLoc(), builder);
426489fec27SNicolas Vasilache   auto res = ::promoteSubViews(b, linalgOp, linalgOptions, layout);
427489fec27SNicolas Vasilache   if (failed(res))
428489fec27SNicolas Vasilache     return failure();
429489fec27SNicolas Vasilache   return res;
4308dbbb223SNicolas Vasilache }
431115711c1SAmir Mohammad Tavakkoli 
432115711c1SAmir Mohammad Tavakkoli /// Allocate the given subview to a memory address space in GPU by creating a
433115711c1SAmir Mohammad Tavakkoli /// allocation operation and setting the memref type address space to desired
434115711c1SAmir Mohammad Tavakkoli /// address space.
4352010269cSKazu Hirata static std::optional<Value> allocateSubviewGPUMemoryInAddressSpace(
436115711c1SAmir Mohammad Tavakkoli     OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
437115711c1SAmir Mohammad Tavakkoli     gpu::AddressSpace addressSpace) {
438115711c1SAmir Mohammad Tavakkoli   OpBuilder::InsertionGuard guard(builder);
439115711c1SAmir Mohammad Tavakkoli 
440115711c1SAmir Mohammad Tavakkoli   func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
441115711c1SAmir Mohammad Tavakkoli   if (!funcOp)
442115711c1SAmir Mohammad Tavakkoli     return std::nullopt;
443115711c1SAmir Mohammad Tavakkoli 
444115711c1SAmir Mohammad Tavakkoli   // The subview size bounds are expected to be constant; they specify the shape
445115711c1SAmir Mohammad Tavakkoli   // of the allocation.
446115711c1SAmir Mohammad Tavakkoli   SmallVector<int64_t> shape;
447115711c1SAmir Mohammad Tavakkoli   for (Value bound : sizeBounds) {
448115711c1SAmir Mohammad Tavakkoli     APInt value;
449115711c1SAmir Mohammad Tavakkoli     if (!matchPattern(bound, m_ConstantInt(&value)))
450115711c1SAmir Mohammad Tavakkoli       return std::nullopt;
451115711c1SAmir Mohammad Tavakkoli     shape.push_back(value.getSExtValue());
452115711c1SAmir Mohammad Tavakkoli   }
453115711c1SAmir Mohammad Tavakkoli 
454*b613a540SMatthias Springer   builder.setInsertionPointToStart(&funcOp.front());
455115711c1SAmir Mohammad Tavakkoli   auto type = MemRefType::get(
456115711c1SAmir Mohammad Tavakkoli       shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{},
457115711c1SAmir Mohammad Tavakkoli       gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace));
458115711c1SAmir Mohammad Tavakkoli   Value buffer;
459115711c1SAmir Mohammad Tavakkoli   if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
460115711c1SAmir Mohammad Tavakkoli     buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
461115711c1SAmir Mohammad Tavakkoli   } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
462115711c1SAmir Mohammad Tavakkoli     buffer = builder.create<memref::AllocaOp>(funcOp.getLoc(), type);
463115711c1SAmir Mohammad Tavakkoli   } else {
464115711c1SAmir Mohammad Tavakkoli     return std::nullopt;
465115711c1SAmir Mohammad Tavakkoli   }
466115711c1SAmir Mohammad Tavakkoli   return buffer;
467115711c1SAmir Mohammad Tavakkoli }
468115711c1SAmir Mohammad Tavakkoli 
469115711c1SAmir Mohammad Tavakkoli /// Allocate the subview in the GPU workgroup memory.
4702010269cSKazu Hirata std::optional<Value> mlir::linalg::allocateWorkgroupMemory(
471115711c1SAmir Mohammad Tavakkoli     OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
472115711c1SAmir Mohammad Tavakkoli     DataLayout &) {
473115711c1SAmir Mohammad Tavakkoli   return allocateSubviewGPUMemoryInAddressSpace(
474115711c1SAmir Mohammad Tavakkoli       builder, subview, sizeBounds,
475115711c1SAmir Mohammad Tavakkoli       gpu::GPUDialect::getWorkgroupAddressSpace());
476115711c1SAmir Mohammad Tavakkoli }
477115711c1SAmir Mohammad Tavakkoli 
478115711c1SAmir Mohammad Tavakkoli /// In case of GPU group memory there is no need to deallocate.
479115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::deallocateWorkgroupMemory(OpBuilder &,
480115711c1SAmir Mohammad Tavakkoli                                                       Value /*buffer*/) {
481115711c1SAmir Mohammad Tavakkoli   return success();
482115711c1SAmir Mohammad Tavakkoli }
483115711c1SAmir Mohammad Tavakkoli 
484115711c1SAmir Mohammad Tavakkoli /// Create Memref copy operations and add gpu barrier guards before and after
485115711c1SAmir Mohammad Tavakkoli /// the copy operation to ensure data integrity.
486115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src,
487115711c1SAmir Mohammad Tavakkoli                                                   Value dst) {
488115711c1SAmir Mohammad Tavakkoli   b.create<gpu::BarrierOp>(src.getLoc());
489115711c1SAmir Mohammad Tavakkoli   Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
490115711c1SAmir Mohammad Tavakkoli   b.create<gpu::BarrierOp>(copyOp->getLoc());
491115711c1SAmir Mohammad Tavakkoli   return success();
492115711c1SAmir Mohammad Tavakkoli }
493115711c1SAmir Mohammad Tavakkoli 
494115711c1SAmir Mohammad Tavakkoli /// Allocate the subview in the GPU private memory.
4952010269cSKazu Hirata std::optional<Value> mlir::linalg::allocateGPUPrivateMemory(
496115711c1SAmir Mohammad Tavakkoli     OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
497115711c1SAmir Mohammad Tavakkoli     DataLayout &) {
498115711c1SAmir Mohammad Tavakkoli   return allocateSubviewGPUMemoryInAddressSpace(
499115711c1SAmir Mohammad Tavakkoli       builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace());
500115711c1SAmir Mohammad Tavakkoli }
501115711c1SAmir Mohammad Tavakkoli 
502115711c1SAmir Mohammad Tavakkoli /// Normal copy to between src and dst.
503115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src,
504115711c1SAmir Mohammad Tavakkoli                                                    Value dst) {
505779d54fdSHaojian Wu   b.create<memref::CopyOp>(src.getLoc(), src, dst);
506115711c1SAmir Mohammad Tavakkoli   return success();
507115711c1SAmir Mohammad Tavakkoli }
508115711c1SAmir Mohammad Tavakkoli 
509115711c1SAmir Mohammad Tavakkoli /// In case of GPU private memory there is no need to deallocate since the
510115711c1SAmir Mohammad Tavakkoli /// memory is freed when going outside of the scope.
511115711c1SAmir Mohammad Tavakkoli LogicalResult mlir::linalg::deallocateGPUPrivateMemory(OpBuilder &,
512115711c1SAmir Mohammad Tavakkoli                                                        Value /*buffer*/) {
513115711c1SAmir Mohammad Tavakkoli   return success();
514115711c1SAmir Mohammad Tavakkoli }
515