//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
993409967SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1093409967SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11*a9205c5cSSpenser Bauman #include "mlir/Dialect/Utils/IndexingUtils.h"
1293409967SMatthias Springer #include "mlir/IR/Matchers.h"
1393409967SMatthias Springer #include "mlir/IR/PatternMatch.h"
1493409967SMatthias Springer
15*a9205c5cSSpenser Bauman #include "llvm/ADT/TypeSwitch.h"
16*a9205c5cSSpenser Bauman
1793409967SMatthias Springer using namespace mlir;
1893409967SMatthias Springer using namespace mlir::tensor;
1993409967SMatthias Springer
2093409967SMatthias Springer namespace {
2193409967SMatthias Springer
2293409967SMatthias Springer /// Rewrite tensor.generate with arith.constant if the yielded value is a
2393409967SMatthias Springer /// constant and the tensor type is static.
2493409967SMatthias Springer struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
2593409967SMatthias Springer using OpRewritePattern<GenerateOp>::OpRewritePattern;
2693409967SMatthias Springer
matchAndRewrite__anon2f530e890111::GenerateToConstant2793409967SMatthias Springer LogicalResult matchAndRewrite(GenerateOp generateOp,
2893409967SMatthias Springer PatternRewriter &rewriter) const override {
2993409967SMatthias Springer auto tensorType =
3093409967SMatthias Springer llvm::cast<RankedTensorType>(generateOp.getResult().getType());
3193409967SMatthias Springer if (!tensorType.hasStaticShape())
3293409967SMatthias Springer return failure();
3393409967SMatthias Springer auto terminatorOp =
3493409967SMatthias Springer cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
3593409967SMatthias Springer Attribute attr;
3693409967SMatthias Springer if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
3793409967SMatthias Springer return failure();
3893409967SMatthias Springer Operation *constantOp =
3993409967SMatthias Springer rewriter.getContext()
4093409967SMatthias Springer ->getLoadedDialect<TensorDialect>()
4193409967SMatthias Springer ->materializeConstant(rewriter,
4293409967SMatthias Springer DenseElementsAttr::get(tensorType, attr),
4393409967SMatthias Springer tensorType, generateOp->getLoc());
4493409967SMatthias Springer if (!constantOp)
4593409967SMatthias Springer return failure();
4693409967SMatthias Springer rewriter.replaceOp(generateOp, constantOp->getResults());
4793409967SMatthias Springer return success();
4893409967SMatthias Springer }
4993409967SMatthias Springer };
5093409967SMatthias Springer
51*a9205c5cSSpenser Bauman /// Transform a linear index from one indexing space to another given:
52*a9205c5cSSpenser Bauman ///
53*a9205c5cSSpenser Bauman /// - the shape of the source indexing space,
54*a9205c5cSSpenser Bauman /// - the strides of the target indexing space,
55*a9205c5cSSpenser Bauman /// - a linear index into the source indexing space.
56*a9205c5cSSpenser Bauman ///
57*a9205c5cSSpenser Bauman /// This function is logically a sequence of linearize/delinearize over
58*a9205c5cSSpenser Bauman /// different bases but avoids allocating intermediate SmallVectors.
transformIndexSpace(ArrayRef<int64_t> inputShape,ArrayRef<int64_t> outputStrides,int64_t srcLinearIndex)59*a9205c5cSSpenser Bauman int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
60*a9205c5cSSpenser Bauman ArrayRef<int64_t> outputStrides,
61*a9205c5cSSpenser Bauman int64_t srcLinearIndex) {
62*a9205c5cSSpenser Bauman assert(inputShape.size() == outputStrides.size());
63*a9205c5cSSpenser Bauman
64*a9205c5cSSpenser Bauman int64_t dstLinearIndex = 0;
65*a9205c5cSSpenser Bauman
66*a9205c5cSSpenser Bauman for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
67*a9205c5cSSpenser Bauman // Compute the index into the current dimension of the source tensor.
68*a9205c5cSSpenser Bauman // `quotient` is the remaining linear index after accounting for the
69*a9205c5cSSpenser Bauman // current dimension.
70*a9205c5cSSpenser Bauman //
71*a9205c5cSSpenser Bauman // `remainder` is the index into the source tensor for the current
72*a9205c5cSSpenser Bauman // dimension.
73*a9205c5cSSpenser Bauman auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
74*a9205c5cSSpenser Bauman
75*a9205c5cSSpenser Bauman srcLinearIndex = quotient;
76*a9205c5cSSpenser Bauman
77*a9205c5cSSpenser Bauman // Add the contribution of the current dimension to the output using the
78*a9205c5cSSpenser Bauman // permutation map.
79*a9205c5cSSpenser Bauman dstLinearIndex += outputStrides[dim] * remainder;
80*a9205c5cSSpenser Bauman }
81*a9205c5cSSpenser Bauman
82*a9205c5cSSpenser Bauman return dstLinearIndex;
83*a9205c5cSSpenser Bauman }
84*a9205c5cSSpenser Bauman
85*a9205c5cSSpenser Bauman template <typename ElemType, typename AttrType>
constantFoldPadOp(PatternRewriter & rewriter,Location loc,DenseElementsAttr input,AttrType padValue,ArrayRef<int64_t> padLow,ArrayRef<int64_t> padHigh)86*a9205c5cSSpenser Bauman Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
87*a9205c5cSSpenser Bauman DenseElementsAttr input, AttrType padValue,
88*a9205c5cSSpenser Bauman ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
89*a9205c5cSSpenser Bauman auto inputValues = input.tryGetValues<ElemType>();
90*a9205c5cSSpenser Bauman if (failed(inputValues))
91*a9205c5cSSpenser Bauman return nullptr;
92*a9205c5cSSpenser Bauman
93*a9205c5cSSpenser Bauman auto oldShape = input.getType().getShape();
94*a9205c5cSSpenser Bauman
95*a9205c5cSSpenser Bauman // Compute the output shape of the new value.
96*a9205c5cSSpenser Bauman auto newShape =
97*a9205c5cSSpenser Bauman llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98*a9205c5cSSpenser Bauman [](std::tuple<int64_t, int64_t, int64_t> pack) {
99*a9205c5cSSpenser Bauman auto [old, low, high] = pack;
100*a9205c5cSSpenser Bauman return old + low + high;
101*a9205c5cSSpenser Bauman });
102*a9205c5cSSpenser Bauman
103*a9205c5cSSpenser Bauman int64_t outputSize = computeProduct(newShape);
104*a9205c5cSSpenser Bauman
105*a9205c5cSSpenser Bauman // Fully initialize the vector with the padding value.
106*a9205c5cSSpenser Bauman // The non-padded area will then be copied.
107*a9205c5cSSpenser Bauman SmallVector<ElemType> values(outputSize, padValue.getValue());
108*a9205c5cSSpenser Bauman
109*a9205c5cSSpenser Bauman // Strides for input and output are used to transform between the indexing
110*a9205c5cSSpenser Bauman // space of the input and output tensors.
111*a9205c5cSSpenser Bauman SmallVector<int64_t> outputStrides = computeStrides(newShape);
112*a9205c5cSSpenser Bauman
113*a9205c5cSSpenser Bauman // The contribution of the low padding to the offset in the output tensor.
114*a9205c5cSSpenser Bauman // This is the starting position of the source tensor within the padding
115*a9205c5cSSpenser Bauman // tensor.
116*a9205c5cSSpenser Bauman int64_t startingOffset = linearize(padLow, outputStrides);
117*a9205c5cSSpenser Bauman
118*a9205c5cSSpenser Bauman // Copy values from the input tensor to the corresponding sub-region
119*a9205c5cSSpenser Bauman // of the output tensor.
120*a9205c5cSSpenser Bauman for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121*a9205c5cSSpenser Bauman auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122*a9205c5cSSpenser Bauman values[outputIndex + startingOffset] = inputValue;
123*a9205c5cSSpenser Bauman }
124*a9205c5cSSpenser Bauman
125*a9205c5cSSpenser Bauman // Create an attribute for the folded value.
126*a9205c5cSSpenser Bauman auto newType = input.getType().clone(newShape);
127*a9205c5cSSpenser Bauman auto newAttr = DenseElementsAttr::get(newType, values);
128*a9205c5cSSpenser Bauman
129*a9205c5cSSpenser Bauman Operation *constantOp =
130*a9205c5cSSpenser Bauman rewriter.getContext()
131*a9205c5cSSpenser Bauman ->getLoadedDialect<TensorDialect>()
132*a9205c5cSSpenser Bauman ->materializeConstant(rewriter, newAttr, newType, loc);
133*a9205c5cSSpenser Bauman
134*a9205c5cSSpenser Bauman return constantOp ? constantOp->getResult(0) : nullptr;
135*a9205c5cSSpenser Bauman }
136*a9205c5cSSpenser Bauman
137*a9205c5cSSpenser Bauman struct PadOpToConstant final : public OpRewritePattern<PadOp> {
138*a9205c5cSSpenser Bauman
PadOpToConstant__anon2f530e890111::PadOpToConstant139*a9205c5cSSpenser Bauman PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
140*a9205c5cSSpenser Bauman PatternBenefit benefit = 1)
141*a9205c5cSSpenser Bauman : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}
142*a9205c5cSSpenser Bauman
matchAndRewrite__anon2f530e890111::PadOpToConstant143*a9205c5cSSpenser Bauman LogicalResult matchAndRewrite(PadOp padTensorOp,
144*a9205c5cSSpenser Bauman PatternRewriter &rewriter) const override {
145*a9205c5cSSpenser Bauman if (padTensorOp.getNofold())
146*a9205c5cSSpenser Bauman return rewriter.notifyMatchFailure(
147*a9205c5cSSpenser Bauman padTensorOp, "refusing to fold nofold pad operation");
148*a9205c5cSSpenser Bauman
149*a9205c5cSSpenser Bauman TypedValue<RankedTensorType> input = padTensorOp.getSource();
150*a9205c5cSSpenser Bauman RankedTensorType resultType = padTensorOp.getResult().getType();
151*a9205c5cSSpenser Bauman
152*a9205c5cSSpenser Bauman DenseElementsAttr inputAttr = nullptr;
153*a9205c5cSSpenser Bauman if (!matchPattern(input, m_Constant(&inputAttr)))
154*a9205c5cSSpenser Bauman return failure();
155*a9205c5cSSpenser Bauman
156*a9205c5cSSpenser Bauman Value paddingValue = padTensorOp.getConstantPaddingValue();
157*a9205c5cSSpenser Bauman
158*a9205c5cSSpenser Bauman // Extract the constant value used for padding or bail out.
159*a9205c5cSSpenser Bauman Attribute paddingAttr = nullptr;
160*a9205c5cSSpenser Bauman if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
161*a9205c5cSSpenser Bauman return rewriter.notifyMatchFailure(padTensorOp,
162*a9205c5cSSpenser Bauman "unable to get constant value");
163*a9205c5cSSpenser Bauman
164*a9205c5cSSpenser Bauman // Try to extract the constant values of the low and high padding.
165*a9205c5cSSpenser Bauman auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
166*a9205c5cSSpenser Bauman auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
167*a9205c5cSSpenser Bauman
168*a9205c5cSSpenser Bauman // If the padding cannot be extracted, bail out.
169*a9205c5cSSpenser Bauman if (!lowPad || !highPad)
170*a9205c5cSSpenser Bauman return rewriter.notifyMatchFailure(padTensorOp,
171*a9205c5cSSpenser Bauman "unable to extract constant padding");
172*a9205c5cSSpenser Bauman
173*a9205c5cSSpenser Bauman // We have a potential candidate, consult the control function to
174*a9205c5cSSpenser Bauman // determine if the op should fold.
175*a9205c5cSSpenser Bauman if (!controlFn(&padTensorOp.getSourceMutable()))
176*a9205c5cSSpenser Bauman return rewriter.notifyMatchFailure(padTensorOp,
177*a9205c5cSSpenser Bauman "not folding due to cost function");
178*a9205c5cSSpenser Bauman
179*a9205c5cSSpenser Bauman Location loc = padTensorOp.getLoc();
180*a9205c5cSSpenser Bauman
181*a9205c5cSSpenser Bauman // Try constant folding the supported cases of integer and float values.
182*a9205c5cSSpenser Bauman Value newOp =
183*a9205c5cSSpenser Bauman llvm::TypeSwitch<Attribute, Value>(paddingAttr)
184*a9205c5cSSpenser Bauman .Case([&](FloatAttr floatAttr) {
185*a9205c5cSSpenser Bauman return constantFoldPadOp<llvm::APFloat>(
186*a9205c5cSSpenser Bauman rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
187*a9205c5cSSpenser Bauman })
188*a9205c5cSSpenser Bauman .Case([&](IntegerAttr integerAttr) {
189*a9205c5cSSpenser Bauman return constantFoldPadOp<llvm::APInt>(
190*a9205c5cSSpenser Bauman rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
191*a9205c5cSSpenser Bauman })
192*a9205c5cSSpenser Bauman .Default(Value());
193*a9205c5cSSpenser Bauman
194*a9205c5cSSpenser Bauman if (!newOp)
195*a9205c5cSSpenser Bauman return rewriter.notifyMatchFailure(padTensorOp,
196*a9205c5cSSpenser Bauman "tensor type not supported");
197*a9205c5cSSpenser Bauman
198*a9205c5cSSpenser Bauman if (newOp.getType() != resultType)
199*a9205c5cSSpenser Bauman newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);
200*a9205c5cSSpenser Bauman
201*a9205c5cSSpenser Bauman rewriter.replaceOp(padTensorOp, newOp);
202*a9205c5cSSpenser Bauman return success();
203*a9205c5cSSpenser Bauman }
204*a9205c5cSSpenser Bauman
205*a9205c5cSSpenser Bauman private:
206*a9205c5cSSpenser Bauman ControlFoldFn controlFn;
207*a9205c5cSSpenser Bauman };
208*a9205c5cSSpenser Bauman
20993409967SMatthias Springer } // namespace
21093409967SMatthias Springer
populateRewriteAsConstantPatterns(RewritePatternSet & patterns,const ControlFoldFn & controlFn)21193409967SMatthias Springer void mlir::tensor::populateRewriteAsConstantPatterns(
212*a9205c5cSSpenser Bauman RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
21393409967SMatthias Springer patterns.add<GenerateToConstant>(patterns.getContext());
214*a9205c5cSSpenser Bauman
215*a9205c5cSSpenser Bauman patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
21693409967SMatthias Springer }
217