102c66207SMatthias Springer //===- Padding.cpp - Padding of Linalg ops --------------------------------===// 202c66207SMatthias Springer // 302c66207SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 402c66207SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 502c66207SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 602c66207SMatthias Springer // 702c66207SMatthias Springer //===----------------------------------------------------------------------===// 802c66207SMatthias Springer 902c66207SMatthias Springer #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 1002c66207SMatthias Springer 11431c49d6SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12a5cee3e3SRobert Suderman #include "mlir/Dialect/Complex/IR/Complex.h" 1302c66207SMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h" 1402c66207SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 1502c66207SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 1602c66207SMatthias Springer 1702c66207SMatthias Springer #define DEBUG_TYPE "linalg-padding" 1802c66207SMatthias Springer 1902c66207SMatthias Springer using namespace mlir; 2002c66207SMatthias Springer using namespace mlir::linalg; 2102c66207SMatthias Springer 2202c66207SMatthias Springer #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 2302c66207SMatthias Springer #define DBGSNL() (llvm::dbgs() << "\n") 2402c66207SMatthias Springer 25369c5d6aSMatthias Springer /// Compute the padded shape of the given operand. The operand is padded to a 267c4c2746SJaved Absar /// static bounding box according to the specified padding options. 
27369c5d6aSMatthias Springer static LogicalResult computePaddedShape(linalg::LinalgOp opToPad, 28369c5d6aSMatthias Springer OpOperand *opOperand, 29369c5d6aSMatthias Springer const LinalgPaddingOptions &options, 30369c5d6aSMatthias Springer SmallVector<int64_t> &paddedShape, 31369c5d6aSMatthias Springer bool &alreadyHasRequestedShape) { 3202c66207SMatthias Springer AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand); 3302c66207SMatthias Springer ArrayRef<int64_t> shape = opToPad.getShape(opOperand); 3402c66207SMatthias Springer 35369c5d6aSMatthias Springer // Collect the shape dimensions that are a function of "paddingDimensions", 3602c66207SMatthias Springer // along with the multiple that they should be padded to ("1" if none). 37369c5d6aSMatthias Springer alreadyHasRequestedShape = true; 3802c66207SMatthias Springer DenseMap<int64_t, int64_t> shapeDimToMultiple; 39369c5d6aSMatthias Springer for (const auto &dimEn : enumerate(options.paddingDimensions)) { 4002c66207SMatthias Springer for (const auto &en : enumerate(indexingMap.getResults())) { 4102c66207SMatthias Springer if (en.value().isFunctionOfDim(dimEn.value())) { 4202c66207SMatthias Springer int64_t dimSize = shape[en.index()]; 43369c5d6aSMatthias Springer if (options.padToMultipleOf.has_value()) { 44369c5d6aSMatthias Springer shapeDimToMultiple[en.index()] = 45369c5d6aSMatthias Springer (*options.padToMultipleOf)[dimEn.index()]; 46369c5d6aSMatthias Springer } else { 47369c5d6aSMatthias Springer shapeDimToMultiple[en.index()] = 1; 48369c5d6aSMatthias Springer } 4902c66207SMatthias Springer if (ShapedType::isDynamic(dimSize)) { 5002c66207SMatthias Springer alreadyHasRequestedShape = false; 5102c66207SMatthias Springer } else if (dimSize % shapeDimToMultiple[en.index()] != 0) { 5202c66207SMatthias Springer alreadyHasRequestedShape = false; 5302c66207SMatthias Springer } 5402c66207SMatthias Springer } 5502c66207SMatthias Springer } 5602c66207SMatthias Springer } 5702c66207SMatthias Springer 
5802c66207SMatthias Springer // Helper function to round a number up to a given multiple. 5902c66207SMatthias Springer auto ceil = [](int64_t val, int64_t multiple) { 6002c66207SMatthias Springer return ((val + multiple - 1) / multiple) * multiple; 6102c66207SMatthias Springer }; 6202c66207SMatthias Springer 6302c66207SMatthias Springer // Upper bound the sizes to obtain a static bounding box. 64369c5d6aSMatthias Springer paddedShape.assign(shape.begin(), shape.end()); 6502c66207SMatthias Springer for (int64_t i = 0, e = shape.size(); i < e; ++i) { 6602c66207SMatthias Springer LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n"); 6702c66207SMatthias Springer // Skip dimensions that do not require padding. 6802c66207SMatthias Springer if (!shapeDimToMultiple.contains(i)) { 6902c66207SMatthias Springer LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n"); 7002c66207SMatthias Springer continue; 7102c66207SMatthias Springer } 7202c66207SMatthias Springer // Otherwise, try to compute a constant upper bound for the size value. 
7302c66207SMatthias Springer FailureOr<int64_t> upperBound = 7402c66207SMatthias Springer ValueBoundsConstraintSet::computeConstantBound( 7540dd3aa9SMatthias Springer presburger::BoundType::UB, 7640dd3aa9SMatthias Springer {opOperand->get(), 7740dd3aa9SMatthias Springer /*dim=*/i}, 7840dd3aa9SMatthias Springer /*stopCondition=*/nullptr, /*closedUB=*/true); 7902c66207SMatthias Springer if (failed(upperBound)) { 807c4c2746SJaved Absar LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding"); 81369c5d6aSMatthias Springer return failure(); 8202c66207SMatthias Springer } 8302c66207SMatthias Springer paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]); 8402c66207SMatthias Springer LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n"); 8502c66207SMatthias Springer } 8602c66207SMatthias Springer 87369c5d6aSMatthias Springer return success(); 88369c5d6aSMatthias Springer } 89369c5d6aSMatthias Springer 90369c5d6aSMatthias Springer /// Pad the `opOperand` in the "paddingDimensions" using the padding value and 91*a758bcdbSAndrzej Warzyński /// the nofold flag found in "paddingValues" and "nofoldFlags", respectively. 92369c5d6aSMatthias Springer /// 93369c5d6aSMatthias Springer /// Exit early and return the `opOperand` value if it already has the requested 947c4c2746SJaved Absar /// shape. i.e.: 95369c5d6aSMatthias Springer /// - static shape 96369c5d6aSMatthias Springer /// - nofold is not set 97369c5d6aSMatthias Springer /// - dim sizes are multiples of "padToMultipleOf" 98369c5d6aSMatthias Springer /// 99369c5d6aSMatthias Springer /// Otherwise, try to pad the shape dimensions that match the iterator 100369c5d6aSMatthias Springer /// dimensions "paddingDimensions" and return the tensor::PadOp result if 101369c5d6aSMatthias Springer /// padding succeeds or failure otherwise. 
102369c5d6aSMatthias Springer static FailureOr<Value> padOperandToSmallestStaticBoundingBox( 103369c5d6aSMatthias Springer RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, 104369c5d6aSMatthias Springer const LinalgPaddingOptions &options) { 1052049b2adSLorenzo Chelini assert( 1062049b2adSLorenzo Chelini (!options.padToMultipleOf.has_value() || 1072049b2adSLorenzo Chelini options.padToMultipleOf->size() == options.paddingDimensions.size()) && 108369c5d6aSMatthias Springer "invalid number of elements in padToMultipleOf"); 109369c5d6aSMatthias Springer 110369c5d6aSMatthias Springer // Compute padded shape. 111369c5d6aSMatthias Springer SmallVector<int64_t> paddedShape; 112369c5d6aSMatthias Springer bool alreadyHasRequestedShape = false; 113369c5d6aSMatthias Springer if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape, 114369c5d6aSMatthias Springer alreadyHasRequestedShape))) 115369c5d6aSMatthias Springer return rewriter.notifyMatchFailure(opToPad, 116369c5d6aSMatthias Springer "--failed to compute padded shape"); 117369c5d6aSMatthias Springer 118369c5d6aSMatthias Springer // Return the unpadded operand if padding to a static shape is not needed and 119369c5d6aSMatthias Springer // if the nofold flag is not set. 120*a758bcdbSAndrzej Warzyński bool nofold = opOperand->getOperandNumber() < options.nofoldFlags.size() 121*a758bcdbSAndrzej Warzyński ? bool(options.nofoldFlags[opOperand->getOperandNumber()]) 122369c5d6aSMatthias Springer : false; 123369c5d6aSMatthias Springer if (!nofold && alreadyHasRequestedShape) 124369c5d6aSMatthias Springer return opOperand->get(); 125369c5d6aSMatthias Springer 126369c5d6aSMatthias Springer // Fail if `paddingValues` specifies no padding value. 
127369c5d6aSMatthias Springer if (opOperand->getOperandNumber() >= options.paddingValues.size()) { 128369c5d6aSMatthias Springer return rewriter.notifyMatchFailure(opToPad, "--no padding value specified"); 129369c5d6aSMatthias Springer } 130369c5d6aSMatthias Springer Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()]; 131a5cee3e3SRobert Suderman 132a5cee3e3SRobert Suderman Value paddingValue; 133a5cee3e3SRobert Suderman if (auto complexTy = dyn_cast<ComplexType>( 134a5cee3e3SRobert Suderman getElementTypeOrSelf(opOperand->get().getType()))) { 135a5cee3e3SRobert Suderman auto complexAttr = cast<ArrayAttr>(paddingAttr); 136a5cee3e3SRobert Suderman paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(), 137a5cee3e3SRobert Suderman complexTy, complexAttr); 138a5cee3e3SRobert Suderman } else { 139a5cee3e3SRobert Suderman paddingValue = rewriter.create<arith::ConstantOp>( 140369c5d6aSMatthias Springer opToPad.getLoc(), cast<TypedAttr>(paddingAttr)); 141a5cee3e3SRobert Suderman } 142369c5d6aSMatthias Springer 14302c66207SMatthias Springer // Pad the operand to the bounding box defined by `paddedShape`. 
14402c66207SMatthias Springer auto paddedTensorType = RankedTensorType::get( 14502c66207SMatthias Springer paddedShape, getElementTypeOrSelf(opOperand->get())); 14602c66207SMatthias Springer LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: " 14702c66207SMatthias Springer << paddedTensorType); 14802c66207SMatthias Springer return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType, 14902c66207SMatthias Springer opOperand->get(), paddingValue, nofold); 15002c66207SMatthias Springer } 15102c66207SMatthias Springer 1520e06ec59SMatthias Springer LogicalResult 15302c66207SMatthias Springer linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, 1541e84e91eSNicolas Vasilache const LinalgPaddingOptions &constOptions, 1550e06ec59SMatthias Springer LinalgOp &paddedOp, SmallVector<Value> &replacements, 156977cb4fdSMatthias Springer SmallVector<tensor::PadOp> &padOps) { 15702c66207SMatthias Springer LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n"); 15802c66207SMatthias Springer Location loc = opToPad->getLoc(); 15902c66207SMatthias Springer 1601e84e91eSNicolas Vasilache LinalgPaddingOptions options(constOptions); 1611e84e91eSNicolas Vasilache // Allow inference of pad values if they are not explicitly specified. 1621e84e91eSNicolas Vasilache // TODO: be mindful about the value depending on the actual operation. 
  // No explicit padding values: default to the zero attribute of each
  // operand/result element type.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }

  // TODO: there are cases where we may still want to pad to larger sizes.
  if (!opToPad.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);

  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        rewriter, opToPad, &opOperand, options);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand)) {
      LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
                        << opOperand.get() << " -> FAIL\n");
      return rewriter.notifyMatchFailure(opToPad,
                                         "operand cannot be bound statically");
    }
    newOperands.push_back(*paddedOperand);
    // Remember the pad ops that were created so the caller can, e.g., hoist
    // them later.
    if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

  // The original (unpadded) result sizes are needed below to slice the padded
  // results back down.
  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes();
  // clone **should** properly notify the rewriter.
  paddedOp = clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    // Extract the leading sub-tensor of the original (reified) size from each
    // padded result: zero offsets, unit strides.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
        strides));
  }

  // Without a copy-back op, the extracted slices are the replacements.
  if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
    replacements = std::move(paddedSubtensorResults);
    return success();
  }

  // Copy back unpadded results to the original destination (i.e., inits of the
  // linalg op), so that the destination buffer of the computation does not
  // change. If the padding folds away, this will materialize as a memcpy
  // between two identical buffers, which will then also fold away.
  assert(static_cast<int64_t>(paddedSubtensorResults.size()) ==
             opToPad.getNumDpsInits() &&
         "expected matching number of results");
  for (auto it :
       llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
    if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
      replacements.push_back(rewriter
                                 .create<linalg::CopyOp>(loc, std::get<0>(it),
                                                         std::get<1>(it).get())
                                 .getResult(0));
    } else if (options.copyBackOp ==
               LinalgPaddingOptions::CopyBackOp::
                   BufferizationMaterializeInDestination) {
      replacements.push_back(
          rewriter
              .create<bufferization::MaterializeInDestinationOp>(
                  loc, std::get<0>(it), std::get<1>(it).get())
              ->getResult(0));
    } else {
      llvm_unreachable("unsupported copy back op");
    }
  }
  return success();
}

FailureOr<LinalgOp>
mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
                                  const LinalgPaddingOptions &options) {
  // This entry point replaces the op's results itself, so a copy-back op must
  // not have been requested.
  assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
         "invalid options");

  if (!linalgOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        linalgOp, "only applies to Linalg ops with tensor semantics");

  // Pad the operation.
  LinalgOp paddedOp;
  SmallVector<Value> newResults;
  SmallVector<tensor::PadOp> padOps;
  if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
                               newResults, padOps)))
    return rewriter.notifyMatchFailure(linalgOp,
                                       "failed to rewrite as a padded op");

  // Hoist the padding. "hoistPaddings" is indexed by operand number; each
  // entry gives the number of loops to hoist the corresponding pad op out of
  // (0 means "do not hoist").
  for (const auto &en : enumerate(options.hoistPaddings)) {
    // Extra entries beyond the operand count are ignored.
    if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
      break;
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0) {
      (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip");
      continue;
    }

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "non static padding shape -- skip");
      continue;
    }

    tensor::PadOp hoistedOp;
    SmallVector<TransposeOp> transposeOps;
    // Optional interchange to apply to the padded tensor while hoisting; an
    // empty vector means "no transpose".
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    // Hoisting failures are non-fatal: keep the (unhoisted) pad in place.
    if (failed(newResult)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "failed to apply hoistPadding");
      continue;
    }
    rewriter.replaceOp(padOp, *newResult);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, newResults);

  return paddedOp;
}