xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp (revision 74ed79f7f123788d95f1552800e1af9ceaee4a08)
12291705dSMahesh Ravishankar //===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
22291705dSMahesh Ravishankar //
32291705dSMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42291705dSMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
52291705dSMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62291705dSMahesh Ravishankar //
72291705dSMahesh Ravishankar //===----------------------------------------------------------------------===//
82291705dSMahesh Ravishankar //
92291705dSMahesh Ravishankar // This file implements constant folding on Linalg operations.
102291705dSMahesh Ravishankar //
112291705dSMahesh Ravishankar //===----------------------------------------------------------------------===//
122291705dSMahesh Ravishankar 
132291705dSMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
142291705dSMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/Linalg.h"
152291705dSMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
162291705dSMahesh Ravishankar #include "mlir/IR/Matchers.h"
172291705dSMahesh Ravishankar #include "mlir/IR/PatternMatch.h"
182291705dSMahesh Ravishankar #include "mlir/Support/LLVM.h"
192291705dSMahesh Ravishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20a1fe1f5fSKazu Hirata #include <optional>
212291705dSMahesh Ravishankar 
222291705dSMahesh Ravishankar using namespace mlir;
232291705dSMahesh Ravishankar using namespace mlir::linalg;
242291705dSMahesh Ravishankar 
252291705dSMahesh Ravishankar namespace {
26*74ed79f7SRyan Holt /// Base class for constant folding linalg structured ops with N inputs, 1
27*74ed79f7SRyan Holt /// output, and permutation indexing maps.
282291705dSMahesh Ravishankar ///
292291705dSMahesh Ravishankar /// `ConcreteType` should provide methods with signatures
302291705dSMahesh Ravishankar ///
312291705dSMahesh Ravishankar /// ```c++
32*74ed79f7SRyan Holt ///   bool matchIndexingMaps(LinalgOp linalgOp) const;
33*74ed79f7SRyan Holt ///   RegionComputationFn getRegionComputeFn(LinalgOp) const;
342291705dSMahesh Ravishankar /// ```
352291705dSMahesh Ravishankar ///
362291705dSMahesh Ravishankar /// The latter inspects the region and returns the computation inside as a
372291705dSMahesh Ravishankar /// functor. The functor will be invoked with constant elements for all inputs
382291705dSMahesh Ravishankar /// and should return the corresponding computed constant element for output.
392291705dSMahesh Ravishankar template <typename ConcreteType>
40*74ed79f7SRyan Holt class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
412291705dSMahesh Ravishankar public:
422291705dSMahesh Ravishankar   struct APIntOrFloat {
430a81ace0SKazu Hirata     std::optional<APInt> apInt;
440a81ace0SKazu Hirata     std::optional<APFloat> apFloat;
452291705dSMahesh Ravishankar   };
462291705dSMahesh Ravishankar   struct APIntOrFloatArray {
472291705dSMahesh Ravishankar     SmallVector<APInt> apInts;
482291705dSMahesh Ravishankar     SmallVector<APFloat> apFloats;
492291705dSMahesh Ravishankar   };
502291705dSMahesh Ravishankar   using RegionComputationFn =
512291705dSMahesh Ravishankar       std::function<APIntOrFloat(const APIntOrFloatArray &)>;
522291705dSMahesh Ravishankar 
FoldConstantBase(MLIRContext * context,const ControlFusionFn & controlFn,PatternBenefit benefit=1)532291705dSMahesh Ravishankar   FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
542291705dSMahesh Ravishankar                    PatternBenefit benefit = 1)
55*74ed79f7SRyan Holt       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
56*74ed79f7SRyan Holt         controlFn(controlFn) {}
572291705dSMahesh Ravishankar 
matchAndRewrite(LinalgOp linalgOp,PatternRewriter & rewriter) const58*74ed79f7SRyan Holt   LogicalResult matchAndRewrite(LinalgOp linalgOp,
592291705dSMahesh Ravishankar                                 PatternRewriter &rewriter) const override {
60e3f75c1cSIvan Butygin     // Mixed and buffer sematics aren't supported.
61*74ed79f7SRyan Holt     if (!linalgOp.hasPureTensorSemantics())
622291705dSMahesh Ravishankar       return failure();
632291705dSMahesh Ravishankar 
642291705dSMahesh Ravishankar     // Only support ops generating one output for now.
65*74ed79f7SRyan Holt     if (linalgOp.getNumDpsInits() != 1)
662291705dSMahesh Ravishankar       return failure();
672291705dSMahesh Ravishankar 
68*74ed79f7SRyan Holt     auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
692291705dSMahesh Ravishankar     // Require the output types to be static given that we are generating
702291705dSMahesh Ravishankar     // constants.
712291705dSMahesh Ravishankar     if (!outputType || !outputType.hasStaticShape())
722291705dSMahesh Ravishankar       return failure();
732291705dSMahesh Ravishankar 
74*74ed79f7SRyan Holt     if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
755550c821STres Popp           return isa<ShapedType>(input.getType());
762291705dSMahesh Ravishankar         }))
772291705dSMahesh Ravishankar       return failure();
782291705dSMahesh Ravishankar 
792291705dSMahesh Ravishankar     // Make sure all element types are the same.
80a7cccb9cSAlexander Belyaev     auto getOperandElementType = [](Value value) {
815550c821STres Popp       return cast<ShapedType>(value.getType()).getElementType();
822291705dSMahesh Ravishankar     };
83a7cccb9cSAlexander Belyaev     if (!llvm::all_equal(
84*74ed79f7SRyan Holt             llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
852291705dSMahesh Ravishankar       return failure();
862291705dSMahesh Ravishankar 
872291705dSMahesh Ravishankar     // We can only handle the case where we have int/float elements.
882291705dSMahesh Ravishankar     auto elementType = outputType.getElementType();
892291705dSMahesh Ravishankar     if (!elementType.isIntOrFloat())
902291705dSMahesh Ravishankar       return failure();
912291705dSMahesh Ravishankar 
922291705dSMahesh Ravishankar     // Require all indexing maps to be permutations for now. This is common and
932291705dSMahesh Ravishankar     // it simplifies input/output access greatly: we can do the data shuffling
942291705dSMahesh Ravishankar     // entirely in the compiler, without needing to turn all indices into
952291705dSMahesh Ravishankar     // Values, and then do affine apply on them, and then match back the
962291705dSMahesh Ravishankar     // constant again.
97*74ed79f7SRyan Holt     if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
982291705dSMahesh Ravishankar                       [](AffineMap map) { return map.isPermutation(); }))
992291705dSMahesh Ravishankar       return failure();
1002291705dSMahesh Ravishankar 
101*74ed79f7SRyan Holt     for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
102*74ed79f7SRyan Holt       if (linalgOp.payloadUsesValueFromOperand(&operand))
1032291705dSMahesh Ravishankar         return failure();
1042291705dSMahesh Ravishankar     }
1052291705dSMahesh Ravishankar 
1062291705dSMahesh Ravishankar     // Further check the indexing maps are okay for the ConcreteType.
107*74ed79f7SRyan Holt     if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
1082291705dSMahesh Ravishankar       return failure();
1092291705dSMahesh Ravishankar 
1102291705dSMahesh Ravishankar     // Defer to the concrete type to check the region and discover the
1112291705dSMahesh Ravishankar     // computation inside.
1122291705dSMahesh Ravishankar     RegionComputationFn computeFn =
113*74ed79f7SRyan Holt         static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
1142291705dSMahesh Ravishankar     if (!computeFn)
1152291705dSMahesh Ravishankar       return failure();
1162291705dSMahesh Ravishankar 
1172291705dSMahesh Ravishankar     // All inputs should be constants.
118*74ed79f7SRyan Holt     int numInputs = linalgOp.getNumDpsInputs();
1192291705dSMahesh Ravishankar     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
120*74ed79f7SRyan Holt     for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
121a7cccb9cSAlexander Belyaev       if (!matchPattern(en.value()->get(),
122a7cccb9cSAlexander Belyaev                         m_Constant(&inputValues[en.index()])))
1232291705dSMahesh Ravishankar         return failure();
1242291705dSMahesh Ravishankar     }
1252291705dSMahesh Ravishankar 
1262291705dSMahesh Ravishankar     // Identified this as a potential candidate for folding. Now check the
1272291705dSMahesh Ravishankar     // policy to see whether we are allowed to proceed.
128*74ed79f7SRyan Holt     for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
129a7bfdc23SMahesh Ravishankar       if (!controlFn(operand))
1302291705dSMahesh Ravishankar         return failure();
1312291705dSMahesh Ravishankar     }
1322291705dSMahesh Ravishankar 
1332291705dSMahesh Ravishankar     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
1342291705dSMahesh Ravishankar     int64_t numElements = outputType.getNumElements();
1352291705dSMahesh Ravishankar 
1362291705dSMahesh Ravishankar     // Use APInt/APFloat instead of Attribute here for constructing the output.
1372291705dSMahesh Ravishankar     // This helps to avoid blowing up compiler memory usage: Attributes would
1382291705dSMahesh Ravishankar     // unify the following cases but they have lifetime as the MLIRContext.
1392291705dSMahesh Ravishankar     SmallVector<APInt> intOutputValues;
1402291705dSMahesh Ravishankar     SmallVector<APFloat> fpOutputValues;
1415550c821STres Popp     if (isa<FloatType>(elementType))
1422291705dSMahesh Ravishankar       fpOutputValues.resize(numElements, APFloat(0.f));
1432291705dSMahesh Ravishankar     else
1442291705dSMahesh Ravishankar       intOutputValues.resize(numElements);
1452291705dSMahesh Ravishankar 
1462291705dSMahesh Ravishankar     // Return the constant dim positions from the given permutation map.
1472291705dSMahesh Ravishankar     auto getDimPositions = [](AffineMap map) {
1482291705dSMahesh Ravishankar       SmallVector<unsigned> dims;
1492291705dSMahesh Ravishankar       dims.reserve(map.getNumResults());
1502291705dSMahesh Ravishankar       for (AffineExpr result : map.getResults()) {
1511609f1c2Slong.chen         dims.push_back(cast<AffineDimExpr>(result).getPosition());
1522291705dSMahesh Ravishankar       }
1532291705dSMahesh Ravishankar       return dims;
1542291705dSMahesh Ravishankar     };
1552291705dSMahesh Ravishankar 
1562291705dSMahesh Ravishankar     SmallVector<SmallVector<unsigned>> inputDims;
1572291705dSMahesh Ravishankar     for (int i = 0; i < numInputs; ++i)
158*74ed79f7SRyan Holt       inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
159*74ed79f7SRyan Holt     auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
1602291705dSMahesh Ravishankar     auto outputShape = outputType.getShape();
1612291705dSMahesh Ravishankar 
1622291705dSMahesh Ravishankar     // Allocate small vectors for index delinearization. Initial values do not
1632291705dSMahesh Ravishankar     // matter here as they will be overwritten later.
1642291705dSMahesh Ravishankar     SmallVector<uint64_t> indices(loopBounds.size(), 0);
1652291705dSMahesh Ravishankar     SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
1662291705dSMahesh Ravishankar     SmallVector<SmallVector<uint64_t>> srcIndices(
1672291705dSMahesh Ravishankar         numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
1682291705dSMahesh Ravishankar     SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
1692291705dSMahesh Ravishankar     uint64_t dstLinearIndex = 0;
1702291705dSMahesh Ravishankar 
1712291705dSMahesh Ravishankar     // Allocate spaces for compute function inputs. Initial values do not matter
1722291705dSMahesh Ravishankar     // here as they will be overwritten later.
1732291705dSMahesh Ravishankar     APIntOrFloatArray computeFnInputs;
1742291705dSMahesh Ravishankar 
1752291705dSMahesh Ravishankar     auto inputShapes = llvm::to_vector<4>(
176*74ed79f7SRyan Holt         llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
1775550c821STres Popp           return cast<ShapedType>(value.getType()).getShape();
1782291705dSMahesh Ravishankar         }));
1792291705dSMahesh Ravishankar 
1802291705dSMahesh Ravishankar     // Given a `linearIndex`, remap it to a linear index to access linalg op
1812291705dSMahesh Ravishankar     // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
1822291705dSMahesh Ravishankar     // `srcLinearIndices`, `dstLinearIndex` in place.
1832291705dSMahesh Ravishankar     auto computeRemappedLinearIndex = [&](int linearIndex) {
1842291705dSMahesh Ravishankar       int totalCount = linearIndex;
1852291705dSMahesh Ravishankar       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1862291705dSMahesh Ravishankar         indices[dim] = totalCount % loopBounds[dim];
1872291705dSMahesh Ravishankar         totalCount /= loopBounds[dim];
1882291705dSMahesh Ravishankar       }
1892291705dSMahesh Ravishankar 
1902291705dSMahesh Ravishankar       for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
1912291705dSMahesh Ravishankar         for (int i = 0; i < numInputs; ++i)
1922291705dSMahesh Ravishankar           srcIndices[i][dim] = indices[inputDims[i][dim]];
1932291705dSMahesh Ravishankar         dstIndices[dim] = indices[outputDims[dim]];
1942291705dSMahesh Ravishankar       }
1952291705dSMahesh Ravishankar 
1962291705dSMahesh Ravishankar       dstLinearIndex = dstIndices.front();
1972291705dSMahesh Ravishankar       for (int i = 0; i < numInputs; ++i)
1982291705dSMahesh Ravishankar         srcLinearIndices[i] = srcIndices[i].front();
1992291705dSMahesh Ravishankar 
2002291705dSMahesh Ravishankar       for (int dim = 1; dim < outputType.getRank(); ++dim) {
2012291705dSMahesh Ravishankar         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
2022291705dSMahesh Ravishankar         for (int i = 0; i < numInputs; ++i)
2032291705dSMahesh Ravishankar           srcLinearIndices[i] =
2042291705dSMahesh Ravishankar               srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
2052291705dSMahesh Ravishankar       }
2062291705dSMahesh Ravishankar     };
2072291705dSMahesh Ravishankar 
2085550c821STres Popp     bool isFloat = isa<FloatType>(elementType);
2092291705dSMahesh Ravishankar     if (isFloat) {
2102291705dSMahesh Ravishankar       SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
2112291705dSMahesh Ravishankar       for (int i = 0; i < numInputs; ++i)
2122291705dSMahesh Ravishankar         inFpRanges.push_back(inputValues[i].getValues<APFloat>());
2132291705dSMahesh Ravishankar 
2142291705dSMahesh Ravishankar       computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
2152291705dSMahesh Ravishankar 
2162291705dSMahesh Ravishankar       // Transpose the input constant. Because we don't know its rank in
2172291705dSMahesh Ravishankar       // advance, we need to loop over the range [0, element count) and
2182291705dSMahesh Ravishankar       // delinearize the index.
2192291705dSMahesh Ravishankar       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2202291705dSMahesh Ravishankar         computeRemappedLinearIndex(linearIndex);
2212291705dSMahesh Ravishankar 
2222291705dSMahesh Ravishankar         // Collect constant elements for all inputs at this loop iteration.
2232291705dSMahesh Ravishankar         for (int i = 0; i < numInputs; ++i)
2242291705dSMahesh Ravishankar           computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
2252291705dSMahesh Ravishankar 
2262291705dSMahesh Ravishankar         // Invoke the computation to get the corresponding constant output
2272291705dSMahesh Ravishankar         // element.
2282291705dSMahesh Ravishankar         fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
2292291705dSMahesh Ravishankar       }
2302291705dSMahesh Ravishankar     } else {
2312291705dSMahesh Ravishankar       SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
2322291705dSMahesh Ravishankar       for (int i = 0; i < numInputs; ++i)
2332291705dSMahesh Ravishankar         inIntRanges.push_back(inputValues[i].getValues<APInt>());
2342291705dSMahesh Ravishankar 
2352291705dSMahesh Ravishankar       computeFnInputs.apInts.resize(numInputs);
2362291705dSMahesh Ravishankar 
2372291705dSMahesh Ravishankar       // Transpose the input constant. Because we don't know its rank in
2382291705dSMahesh Ravishankar       // advance, we need to loop over the range [0, element count) and
2392291705dSMahesh Ravishankar       // delinearize the index.
2402291705dSMahesh Ravishankar       for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
2412291705dSMahesh Ravishankar         computeRemappedLinearIndex(linearIndex);
2422291705dSMahesh Ravishankar 
2432291705dSMahesh Ravishankar         // Collect constant elements for all inputs at this loop iteration.
2442291705dSMahesh Ravishankar         for (int i = 0; i < numInputs; ++i)
2452291705dSMahesh Ravishankar           computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
2462291705dSMahesh Ravishankar 
2472291705dSMahesh Ravishankar         // Invoke the computation to get the corresponding constant output
2482291705dSMahesh Ravishankar         // element.
2492291705dSMahesh Ravishankar         intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
2502291705dSMahesh Ravishankar       }
2512291705dSMahesh Ravishankar     }
2522291705dSMahesh Ravishankar 
2532291705dSMahesh Ravishankar     DenseElementsAttr outputAttr =
2542291705dSMahesh Ravishankar         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
2552291705dSMahesh Ravishankar                 : DenseElementsAttr::get(outputType, intOutputValues);
2562291705dSMahesh Ravishankar 
257*74ed79f7SRyan Holt     rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
2582291705dSMahesh Ravishankar     return success();
2592291705dSMahesh Ravishankar   }
2602291705dSMahesh Ravishankar 
2612291705dSMahesh Ravishankar private:
2622291705dSMahesh Ravishankar   ControlFusionFn controlFn;
2632291705dSMahesh Ravishankar };
2642291705dSMahesh Ravishankar 
265*74ed79f7SRyan Holt // Folds linalg.transpose (and linalg.generic ops that are actually transposes)
266*74ed79f7SRyan Holt // on constant values.
2672291705dSMahesh Ravishankar struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
268*74ed79f7SRyan Holt 
2692291705dSMahesh Ravishankar   using FoldConstantBase::FoldConstantBase;
2702291705dSMahesh Ravishankar 
matchIndexingMaps__anond55d2d340111::FoldConstantTranspose271*74ed79f7SRyan Holt   bool matchIndexingMaps(LinalgOp linalgOp) const {
2722291705dSMahesh Ravishankar     // We should have one input and one output.
273*74ed79f7SRyan Holt     return linalgOp.getIndexingMapsArray().size() == 2;
2742291705dSMahesh Ravishankar   }
2752291705dSMahesh Ravishankar 
getRegionComputeFn__anond55d2d340111::FoldConstantTranspose276*74ed79f7SRyan Holt   RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
2772291705dSMahesh Ravishankar     // Make sure the region only contains a yield op.
278*74ed79f7SRyan Holt     Block &body = linalgOp->getRegion(0).front();
2792291705dSMahesh Ravishankar     if (!llvm::hasSingleElement(body))
2802291705dSMahesh Ravishankar       return nullptr;
2812291705dSMahesh Ravishankar     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
2822291705dSMahesh Ravishankar     if (!yieldOp)
2832291705dSMahesh Ravishankar       return nullptr;
2842291705dSMahesh Ravishankar 
2852291705dSMahesh Ravishankar     // The yield op should return the block argument corresponds to the input.
286d3b3f765SJacques Pienaar     for (Value yieldVal : yieldOp.getValues()) {
2875550c821STres Popp       auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
2882291705dSMahesh Ravishankar       if (!yieldArg || yieldArg.getOwner() != &body)
2892291705dSMahesh Ravishankar         return nullptr;
2902291705dSMahesh Ravishankar       if (yieldArg.getArgNumber() != 0)
2912291705dSMahesh Ravishankar         return nullptr;
2922291705dSMahesh Ravishankar     }
2932291705dSMahesh Ravishankar 
2942291705dSMahesh Ravishankar     // No computation; just return the orginal value.
2952291705dSMahesh Ravishankar     return [](const APIntOrFloatArray &inputs) {
2962291705dSMahesh Ravishankar       if (inputs.apFloats.empty())
2971a36588eSKazu Hirata         return APIntOrFloat{inputs.apInts.front(), std::nullopt};
2981a36588eSKazu Hirata       return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
2992291705dSMahesh Ravishankar     };
3002291705dSMahesh Ravishankar   }
3012291705dSMahesh Ravishankar 
3022291705dSMahesh Ravishankar   ControlFusionFn controlFn;
3032291705dSMahesh Ravishankar };
3042291705dSMahesh Ravishankar } // namespace
3052291705dSMahesh Ravishankar 
populateConstantFoldLinalgOperations(RewritePatternSet & patterns,const ControlFusionFn & controlFn)3062291705dSMahesh Ravishankar void mlir::linalg::populateConstantFoldLinalgOperations(
3072291705dSMahesh Ravishankar     RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
3082291705dSMahesh Ravishankar   MLIRContext *context = patterns.getContext();
3092291705dSMahesh Ravishankar   patterns.insert<FoldConstantTranspose>(context, controlFn);
3102291705dSMahesh Ravishankar }
311