//===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements constant folding on Linalg operations.
//
//===----------------------------------------------------------------------===//
122291705dSMahesh Ravishankar
132291705dSMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
142291705dSMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/Linalg.h"
152291705dSMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
162291705dSMahesh Ravishankar #include "mlir/IR/Matchers.h"
172291705dSMahesh Ravishankar #include "mlir/IR/PatternMatch.h"
182291705dSMahesh Ravishankar #include "mlir/Support/LLVM.h"
192291705dSMahesh Ravishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20a1fe1f5fSKazu Hirata #include <optional>
212291705dSMahesh Ravishankar
222291705dSMahesh Ravishankar using namespace mlir;
232291705dSMahesh Ravishankar using namespace mlir::linalg;
242291705dSMahesh Ravishankar
252291705dSMahesh Ravishankar namespace {
26*74ed79f7SRyan Holt /// Base class for constant folding linalg structured ops with N inputs, 1
27*74ed79f7SRyan Holt /// output, and permutation indexing maps.
282291705dSMahesh Ravishankar ///
292291705dSMahesh Ravishankar /// `ConcreteType` should provide methods with signatures
302291705dSMahesh Ravishankar ///
312291705dSMahesh Ravishankar /// ```c++
32*74ed79f7SRyan Holt /// bool matchIndexingMaps(LinalgOp linalgOp) const;
33*74ed79f7SRyan Holt /// RegionComputationFn getRegionComputeFn(LinalgOp) const;
342291705dSMahesh Ravishankar /// ```
352291705dSMahesh Ravishankar ///
362291705dSMahesh Ravishankar /// The latter inspects the region and returns the computation inside as a
372291705dSMahesh Ravishankar /// functor. The functor will be invoked with constant elements for all inputs
382291705dSMahesh Ravishankar /// and should return the corresponding computed constant element for output.
392291705dSMahesh Ravishankar template <typename ConcreteType>
40*74ed79f7SRyan Holt class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
412291705dSMahesh Ravishankar public:
422291705dSMahesh Ravishankar struct APIntOrFloat {
430a81ace0SKazu Hirata std::optional<APInt> apInt;
440a81ace0SKazu Hirata std::optional<APFloat> apFloat;
452291705dSMahesh Ravishankar };
462291705dSMahesh Ravishankar struct APIntOrFloatArray {
472291705dSMahesh Ravishankar SmallVector<APInt> apInts;
482291705dSMahesh Ravishankar SmallVector<APFloat> apFloats;
492291705dSMahesh Ravishankar };
502291705dSMahesh Ravishankar using RegionComputationFn =
512291705dSMahesh Ravishankar std::function<APIntOrFloat(const APIntOrFloatArray &)>;
522291705dSMahesh Ravishankar
FoldConstantBase(MLIRContext * context,const ControlFusionFn & controlFn,PatternBenefit benefit=1)532291705dSMahesh Ravishankar FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
542291705dSMahesh Ravishankar PatternBenefit benefit = 1)
55*74ed79f7SRyan Holt : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
56*74ed79f7SRyan Holt controlFn(controlFn) {}
572291705dSMahesh Ravishankar
matchAndRewrite(LinalgOp linalgOp,PatternRewriter & rewriter) const58*74ed79f7SRyan Holt LogicalResult matchAndRewrite(LinalgOp linalgOp,
592291705dSMahesh Ravishankar PatternRewriter &rewriter) const override {
60e3f75c1cSIvan Butygin // Mixed and buffer sematics aren't supported.
61*74ed79f7SRyan Holt if (!linalgOp.hasPureTensorSemantics())
622291705dSMahesh Ravishankar return failure();
632291705dSMahesh Ravishankar
642291705dSMahesh Ravishankar // Only support ops generating one output for now.
65*74ed79f7SRyan Holt if (linalgOp.getNumDpsInits() != 1)
662291705dSMahesh Ravishankar return failure();
672291705dSMahesh Ravishankar
68*74ed79f7SRyan Holt auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
692291705dSMahesh Ravishankar // Require the output types to be static given that we are generating
702291705dSMahesh Ravishankar // constants.
712291705dSMahesh Ravishankar if (!outputType || !outputType.hasStaticShape())
722291705dSMahesh Ravishankar return failure();
732291705dSMahesh Ravishankar
74*74ed79f7SRyan Holt if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
755550c821STres Popp return isa<ShapedType>(input.getType());
762291705dSMahesh Ravishankar }))
772291705dSMahesh Ravishankar return failure();
782291705dSMahesh Ravishankar
792291705dSMahesh Ravishankar // Make sure all element types are the same.
80a7cccb9cSAlexander Belyaev auto getOperandElementType = [](Value value) {
815550c821STres Popp return cast<ShapedType>(value.getType()).getElementType();
822291705dSMahesh Ravishankar };
83a7cccb9cSAlexander Belyaev if (!llvm::all_equal(
84*74ed79f7SRyan Holt llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
852291705dSMahesh Ravishankar return failure();
862291705dSMahesh Ravishankar
872291705dSMahesh Ravishankar // We can only handle the case where we have int/float elements.
882291705dSMahesh Ravishankar auto elementType = outputType.getElementType();
892291705dSMahesh Ravishankar if (!elementType.isIntOrFloat())
902291705dSMahesh Ravishankar return failure();
912291705dSMahesh Ravishankar
922291705dSMahesh Ravishankar // Require all indexing maps to be permutations for now. This is common and
932291705dSMahesh Ravishankar // it simplifies input/output access greatly: we can do the data shuffling
942291705dSMahesh Ravishankar // entirely in the compiler, without needing to turn all indices into
952291705dSMahesh Ravishankar // Values, and then do affine apply on them, and then match back the
962291705dSMahesh Ravishankar // constant again.
97*74ed79f7SRyan Holt if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
982291705dSMahesh Ravishankar [](AffineMap map) { return map.isPermutation(); }))
992291705dSMahesh Ravishankar return failure();
1002291705dSMahesh Ravishankar
101*74ed79f7SRyan Holt for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
102*74ed79f7SRyan Holt if (linalgOp.payloadUsesValueFromOperand(&operand))
1032291705dSMahesh Ravishankar return failure();
1042291705dSMahesh Ravishankar }
1052291705dSMahesh Ravishankar
1062291705dSMahesh Ravishankar // Further check the indexing maps are okay for the ConcreteType.
107*74ed79f7SRyan Holt if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
1082291705dSMahesh Ravishankar return failure();
1092291705dSMahesh Ravishankar
1102291705dSMahesh Ravishankar // Defer to the concrete type to check the region and discover the
1112291705dSMahesh Ravishankar // computation inside.
1122291705dSMahesh Ravishankar RegionComputationFn computeFn =
113*74ed79f7SRyan Holt static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
1142291705dSMahesh Ravishankar if (!computeFn)
1152291705dSMahesh Ravishankar return failure();
1162291705dSMahesh Ravishankar
1172291705dSMahesh Ravishankar // All inputs should be constants.
118*74ed79f7SRyan Holt int numInputs = linalgOp.getNumDpsInputs();
1192291705dSMahesh Ravishankar SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
120*74ed79f7SRyan Holt for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
121a7cccb9cSAlexander Belyaev if (!matchPattern(en.value()->get(),
122a7cccb9cSAlexander Belyaev m_Constant(&inputValues[en.index()])))
1232291705dSMahesh Ravishankar return failure();
1242291705dSMahesh Ravishankar }
1252291705dSMahesh Ravishankar
1262291705dSMahesh Ravishankar // Identified this as a potential candidate for folding. Now check the
1272291705dSMahesh Ravishankar // policy to see whether we are allowed to proceed.
128*74ed79f7SRyan Holt for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
129a7bfdc23SMahesh Ravishankar if (!controlFn(operand))
1302291705dSMahesh Ravishankar return failure();
1312291705dSMahesh Ravishankar }
1322291705dSMahesh Ravishankar
1332291705dSMahesh Ravishankar SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
1342291705dSMahesh Ravishankar int64_t numElements = outputType.getNumElements();
1352291705dSMahesh Ravishankar
1362291705dSMahesh Ravishankar // Use APInt/APFloat instead of Attribute here for constructing the output.
1372291705dSMahesh Ravishankar // This helps to avoid blowing up compiler memory usage: Attributes would
1382291705dSMahesh Ravishankar // unify the following cases but they have lifetime as the MLIRContext.
1392291705dSMahesh Ravishankar SmallVector<APInt> intOutputValues;
1402291705dSMahesh Ravishankar SmallVector<APFloat> fpOutputValues;
1415550c821STres Popp if (isa<FloatType>(elementType))
1422291705dSMahesh Ravishankar fpOutputValues.resize(numElements, APFloat(0.f));
1432291705dSMahesh Ravishankar else
1442291705dSMahesh Ravishankar intOutputValues.resize(numElements);
1452291705dSMahesh Ravishankar
1462291705dSMahesh Ravishankar // Return the constant dim positions from the given permutation map.
1472291705dSMahesh Ravishankar auto getDimPositions = [](AffineMap map) {
1482291705dSMahesh Ravishankar SmallVector<unsigned> dims;
1492291705dSMahesh Ravishankar dims.reserve(map.getNumResults());
1502291705dSMahesh Ravishankar for (AffineExpr result : map.getResults()) {
1511609f1c2Slong.chen dims.push_back(cast<AffineDimExpr>(result).getPosition());
1522291705dSMahesh Ravishankar }
1532291705dSMahesh Ravishankar return dims;
1542291705dSMahesh Ravishankar };
1552291705dSMahesh Ravishankar
1562291705dSMahesh Ravishankar SmallVector<SmallVector<unsigned>> inputDims;
1572291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
158*74ed79f7SRyan Holt inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
159*74ed79f7SRyan Holt auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
1602291705dSMahesh Ravishankar auto outputShape = outputType.getShape();
1612291705dSMahesh Ravishankar
1622291705dSMahesh Ravishankar // Allocate small vectors for index delinearization. Initial values do not
1632291705dSMahesh Ravishankar // matter here as they will be overwritten later.
1642291705dSMahesh Ravishankar SmallVector<uint64_t> indices(loopBounds.size(), 0);
1652291705dSMahesh Ravishankar SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
1662291705dSMahesh Ravishankar SmallVector<SmallVector<uint64_t>> srcIndices(
1672291705dSMahesh Ravishankar numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
1682291705dSMahesh Ravishankar SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
1692291705dSMahesh Ravishankar uint64_t dstLinearIndex = 0;
1702291705dSMahesh Ravishankar
1712291705dSMahesh Ravishankar // Allocate spaces for compute function inputs. Initial values do not matter
1722291705dSMahesh Ravishankar // here as they will be overwritten later.
1732291705dSMahesh Ravishankar APIntOrFloatArray computeFnInputs;
1742291705dSMahesh Ravishankar
1752291705dSMahesh Ravishankar auto inputShapes = llvm::to_vector<4>(
176*74ed79f7SRyan Holt llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
1775550c821STres Popp return cast<ShapedType>(value.getType()).getShape();
1782291705dSMahesh Ravishankar }));
1792291705dSMahesh Ravishankar
1802291705dSMahesh Ravishankar // Given a `linearIndex`, remap it to a linear index to access linalg op
1812291705dSMahesh Ravishankar // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
1822291705dSMahesh Ravishankar // `srcLinearIndices`, `dstLinearIndex` in place.
1832291705dSMahesh Ravishankar auto computeRemappedLinearIndex = [&](int linearIndex) {
1842291705dSMahesh Ravishankar int totalCount = linearIndex;
1852291705dSMahesh Ravishankar for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1862291705dSMahesh Ravishankar indices[dim] = totalCount % loopBounds[dim];
1872291705dSMahesh Ravishankar totalCount /= loopBounds[dim];
1882291705dSMahesh Ravishankar }
1892291705dSMahesh Ravishankar
1902291705dSMahesh Ravishankar for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1912291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
1922291705dSMahesh Ravishankar srcIndices[i][dim] = indices[inputDims[i][dim]];
1932291705dSMahesh Ravishankar dstIndices[dim] = indices[outputDims[dim]];
1942291705dSMahesh Ravishankar }
1952291705dSMahesh Ravishankar
1962291705dSMahesh Ravishankar dstLinearIndex = dstIndices.front();
1972291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
1982291705dSMahesh Ravishankar srcLinearIndices[i] = srcIndices[i].front();
1992291705dSMahesh Ravishankar
2002291705dSMahesh Ravishankar for (int dim = 1; dim < outputType.getRank(); ++dim) {
2012291705dSMahesh Ravishankar dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
2022291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
2032291705dSMahesh Ravishankar srcLinearIndices[i] =
2042291705dSMahesh Ravishankar srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
2052291705dSMahesh Ravishankar }
2062291705dSMahesh Ravishankar };
2072291705dSMahesh Ravishankar
2085550c821STres Popp bool isFloat = isa<FloatType>(elementType);
2092291705dSMahesh Ravishankar if (isFloat) {
2102291705dSMahesh Ravishankar SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
2112291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
2122291705dSMahesh Ravishankar inFpRanges.push_back(inputValues[i].getValues<APFloat>());
2132291705dSMahesh Ravishankar
2142291705dSMahesh Ravishankar computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
2152291705dSMahesh Ravishankar
2162291705dSMahesh Ravishankar // Transpose the input constant. Because we don't know its rank in
2172291705dSMahesh Ravishankar // advance, we need to loop over the range [0, element count) and
2182291705dSMahesh Ravishankar // delinearize the index.
2192291705dSMahesh Ravishankar for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2202291705dSMahesh Ravishankar computeRemappedLinearIndex(linearIndex);
2212291705dSMahesh Ravishankar
2222291705dSMahesh Ravishankar // Collect constant elements for all inputs at this loop iteration.
2232291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
2242291705dSMahesh Ravishankar computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
2252291705dSMahesh Ravishankar
2262291705dSMahesh Ravishankar // Invoke the computation to get the corresponding constant output
2272291705dSMahesh Ravishankar // element.
2282291705dSMahesh Ravishankar fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
2292291705dSMahesh Ravishankar }
2302291705dSMahesh Ravishankar } else {
2312291705dSMahesh Ravishankar SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
2322291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
2332291705dSMahesh Ravishankar inIntRanges.push_back(inputValues[i].getValues<APInt>());
2342291705dSMahesh Ravishankar
2352291705dSMahesh Ravishankar computeFnInputs.apInts.resize(numInputs);
2362291705dSMahesh Ravishankar
2372291705dSMahesh Ravishankar // Transpose the input constant. Because we don't know its rank in
2382291705dSMahesh Ravishankar // advance, we need to loop over the range [0, element count) and
2392291705dSMahesh Ravishankar // delinearize the index.
2402291705dSMahesh Ravishankar for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2412291705dSMahesh Ravishankar computeRemappedLinearIndex(linearIndex);
2422291705dSMahesh Ravishankar
2432291705dSMahesh Ravishankar // Collect constant elements for all inputs at this loop iteration.
2442291705dSMahesh Ravishankar for (int i = 0; i < numInputs; ++i)
2452291705dSMahesh Ravishankar computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
2462291705dSMahesh Ravishankar
2472291705dSMahesh Ravishankar // Invoke the computation to get the corresponding constant output
2482291705dSMahesh Ravishankar // element.
2492291705dSMahesh Ravishankar intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
2502291705dSMahesh Ravishankar }
2512291705dSMahesh Ravishankar }
2522291705dSMahesh Ravishankar
2532291705dSMahesh Ravishankar DenseElementsAttr outputAttr =
2542291705dSMahesh Ravishankar isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
2552291705dSMahesh Ravishankar : DenseElementsAttr::get(outputType, intOutputValues);
2562291705dSMahesh Ravishankar
257*74ed79f7SRyan Holt rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
2582291705dSMahesh Ravishankar return success();
2592291705dSMahesh Ravishankar }
2602291705dSMahesh Ravishankar
2612291705dSMahesh Ravishankar private:
2622291705dSMahesh Ravishankar ControlFusionFn controlFn;
2632291705dSMahesh Ravishankar };
2642291705dSMahesh Ravishankar
265*74ed79f7SRyan Holt // Folds linalg.transpose (and linalg.generic ops that are actually transposes)
266*74ed79f7SRyan Holt // on constant values.
2672291705dSMahesh Ravishankar struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
268*74ed79f7SRyan Holt
2692291705dSMahesh Ravishankar using FoldConstantBase::FoldConstantBase;
2702291705dSMahesh Ravishankar
matchIndexingMaps__anond55d2d340111::FoldConstantTranspose271*74ed79f7SRyan Holt bool matchIndexingMaps(LinalgOp linalgOp) const {
2722291705dSMahesh Ravishankar // We should have one input and one output.
273*74ed79f7SRyan Holt return linalgOp.getIndexingMapsArray().size() == 2;
2742291705dSMahesh Ravishankar }
2752291705dSMahesh Ravishankar
getRegionComputeFn__anond55d2d340111::FoldConstantTranspose276*74ed79f7SRyan Holt RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
2772291705dSMahesh Ravishankar // Make sure the region only contains a yield op.
278*74ed79f7SRyan Holt Block &body = linalgOp->getRegion(0).front();
2792291705dSMahesh Ravishankar if (!llvm::hasSingleElement(body))
2802291705dSMahesh Ravishankar return nullptr;
2812291705dSMahesh Ravishankar auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
2822291705dSMahesh Ravishankar if (!yieldOp)
2832291705dSMahesh Ravishankar return nullptr;
2842291705dSMahesh Ravishankar
2852291705dSMahesh Ravishankar // The yield op should return the block argument corresponds to the input.
286d3b3f765SJacques Pienaar for (Value yieldVal : yieldOp.getValues()) {
2875550c821STres Popp auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
2882291705dSMahesh Ravishankar if (!yieldArg || yieldArg.getOwner() != &body)
2892291705dSMahesh Ravishankar return nullptr;
2902291705dSMahesh Ravishankar if (yieldArg.getArgNumber() != 0)
2912291705dSMahesh Ravishankar return nullptr;
2922291705dSMahesh Ravishankar }
2932291705dSMahesh Ravishankar
2942291705dSMahesh Ravishankar // No computation; just return the orginal value.
2952291705dSMahesh Ravishankar return [](const APIntOrFloatArray &inputs) {
2962291705dSMahesh Ravishankar if (inputs.apFloats.empty())
2971a36588eSKazu Hirata return APIntOrFloat{inputs.apInts.front(), std::nullopt};
2981a36588eSKazu Hirata return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
2992291705dSMahesh Ravishankar };
3002291705dSMahesh Ravishankar }
3012291705dSMahesh Ravishankar
3022291705dSMahesh Ravishankar ControlFusionFn controlFn;
3032291705dSMahesh Ravishankar };
3042291705dSMahesh Ravishankar } // namespace
3052291705dSMahesh Ravishankar
populateConstantFoldLinalgOperations(RewritePatternSet & patterns,const ControlFusionFn & controlFn)3062291705dSMahesh Ravishankar void mlir::linalg::populateConstantFoldLinalgOperations(
3072291705dSMahesh Ravishankar RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
3082291705dSMahesh Ravishankar MLIRContext *context = patterns.getContext();
3092291705dSMahesh Ravishankar patterns.insert<FoldConstantTranspose>(context, controlFn);
3102291705dSMahesh Ravishankar }
311