xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp (revision a9205c5c9d5aeadbb97ed7283a35515df4ba49da)
//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
9 #include "mlir/Dialect/Tensor/IR/Tensor.h"
10 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 namespace {
21 
22 /// Rewrite tensor.generate with arith.constant if the yielded value is a
23 /// constant and the tensor type is static.
24 struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
25   using OpRewritePattern<GenerateOp>::OpRewritePattern;
26 
matchAndRewrite__anon2f530e890111::GenerateToConstant27   LogicalResult matchAndRewrite(GenerateOp generateOp,
28                                 PatternRewriter &rewriter) const override {
29     auto tensorType =
30         llvm::cast<RankedTensorType>(generateOp.getResult().getType());
31     if (!tensorType.hasStaticShape())
32       return failure();
33     auto terminatorOp =
34         cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
35     Attribute attr;
36     if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
37       return failure();
38     Operation *constantOp =
39         rewriter.getContext()
40             ->getLoadedDialect<TensorDialect>()
41             ->materializeConstant(rewriter,
42                                   DenseElementsAttr::get(tensorType, attr),
43                                   tensorType, generateOp->getLoc());
44     if (!constantOp)
45       return failure();
46     rewriter.replaceOp(generateOp, constantOp->getResults());
47     return success();
48   }
49 };
50 
51 /// Transform a linear index from one indexing space to another given:
52 ///
53 /// - the shape of the source indexing space,
54 /// - the strides of the target indexing space,
55 /// - a linear index into the source indexing space.
56 ///
57 /// This function is logically a sequence of linearize/delinearize over
58 /// different bases but avoids allocating intermediate SmallVectors.
transformIndexSpace(ArrayRef<int64_t> inputShape,ArrayRef<int64_t> outputStrides,int64_t srcLinearIndex)59 int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
60                             ArrayRef<int64_t> outputStrides,
61                             int64_t srcLinearIndex) {
62   assert(inputShape.size() == outputStrides.size());
63 
64   int64_t dstLinearIndex = 0;
65 
66   for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
67     // Compute the index into the current dimension of the source tensor.
68     // `quotient` is the remaining linear index after accounting for the
69     // current dimension.
70     //
71     // `remainder` is the index into the source tensor for the current
72     // dimension.
73     auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
74 
75     srcLinearIndex = quotient;
76 
77     // Add the contribution of the current dimension to the output using the
78     // permutation map.
79     dstLinearIndex += outputStrides[dim] * remainder;
80   }
81 
82   return dstLinearIndex;
83 }
84 
85 template <typename ElemType, typename AttrType>
constantFoldPadOp(PatternRewriter & rewriter,Location loc,DenseElementsAttr input,AttrType padValue,ArrayRef<int64_t> padLow,ArrayRef<int64_t> padHigh)86 Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
87                         DenseElementsAttr input, AttrType padValue,
88                         ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
89   auto inputValues = input.tryGetValues<ElemType>();
90   if (failed(inputValues))
91     return nullptr;
92 
93   auto oldShape = input.getType().getShape();
94 
95   // Compute the output shape of the new value.
96   auto newShape =
97       llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98                           [](std::tuple<int64_t, int64_t, int64_t> pack) {
99                             auto [old, low, high] = pack;
100                             return old + low + high;
101                           });
102 
103   int64_t outputSize = computeProduct(newShape);
104 
105   // Fully initialize the vector with the padding value.
106   // The non-padded area will then be copied.
107   SmallVector<ElemType> values(outputSize, padValue.getValue());
108 
109   // Strides for input and output are used to transform between the indexing
110   // space of the input and output tensors.
111   SmallVector<int64_t> outputStrides = computeStrides(newShape);
112 
113   // The contribution of the low padding to the offset in the output tensor.
114   // This is the starting position of the source tensor within the padding
115   // tensor.
116   int64_t startingOffset = linearize(padLow, outputStrides);
117 
118   // Copy values from the input tensor to the corresponding sub-region
119   // of the output tensor.
120   for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121     auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122     values[outputIndex + startingOffset] = inputValue;
123   }
124 
125   // Create an attribute for the folded value.
126   auto newType = input.getType().clone(newShape);
127   auto newAttr = DenseElementsAttr::get(newType, values);
128 
129   Operation *constantOp =
130       rewriter.getContext()
131           ->getLoadedDialect<TensorDialect>()
132           ->materializeConstant(rewriter, newAttr, newType, loc);
133 
134   return constantOp ? constantOp->getResult(0) : nullptr;
135 }
136 
137 struct PadOpToConstant final : public OpRewritePattern<PadOp> {
138 
PadOpToConstant__anon2f530e890111::PadOpToConstant139   PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
140                   PatternBenefit benefit = 1)
141       : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}
142 
matchAndRewrite__anon2f530e890111::PadOpToConstant143   LogicalResult matchAndRewrite(PadOp padTensorOp,
144                                 PatternRewriter &rewriter) const override {
145     if (padTensorOp.getNofold())
146       return rewriter.notifyMatchFailure(
147           padTensorOp, "refusing to fold nofold pad operation");
148 
149     TypedValue<RankedTensorType> input = padTensorOp.getSource();
150     RankedTensorType resultType = padTensorOp.getResult().getType();
151 
152     DenseElementsAttr inputAttr = nullptr;
153     if (!matchPattern(input, m_Constant(&inputAttr)))
154       return failure();
155 
156     Value paddingValue = padTensorOp.getConstantPaddingValue();
157 
158     // Extract the constant value used for padding or bail out.
159     Attribute paddingAttr = nullptr;
160     if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
161       return rewriter.notifyMatchFailure(padTensorOp,
162                                          "unable to get constant value");
163 
164     // Try to extract the constant values of the low and high padding.
165     auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
166     auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
167 
168     // If the padding cannot be extracted, bail out.
169     if (!lowPad || !highPad)
170       return rewriter.notifyMatchFailure(padTensorOp,
171                                          "unable to extract constant padding");
172 
173     // We have a potential candidate, consult the control function to
174     // determine if the op should fold.
175     if (!controlFn(&padTensorOp.getSourceMutable()))
176       return rewriter.notifyMatchFailure(padTensorOp,
177                                          "not folding due to cost function");
178 
179     Location loc = padTensorOp.getLoc();
180 
181     // Try constant folding the supported cases of integer and float values.
182     Value newOp =
183         llvm::TypeSwitch<Attribute, Value>(paddingAttr)
184             .Case([&](FloatAttr floatAttr) {
185               return constantFoldPadOp<llvm::APFloat>(
186                   rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
187             })
188             .Case([&](IntegerAttr integerAttr) {
189               return constantFoldPadOp<llvm::APInt>(
190                   rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
191             })
192             .Default(Value());
193 
194     if (!newOp)
195       return rewriter.notifyMatchFailure(padTensorOp,
196                                          "tensor type not supported");
197 
198     if (newOp.getType() != resultType)
199       newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);
200 
201     rewriter.replaceOp(padTensorOp, newOp);
202     return success();
203   }
204 
205 private:
206   ControlFoldFn controlFn;
207 };
208 
209 } // namespace
210 
populateRewriteAsConstantPatterns(RewritePatternSet & patterns,const ControlFoldFn & controlFn)211 void mlir::tensor::populateRewriteAsConstantPatterns(
212     RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
213   patterns.add<GenerateToConstant>(patterns.getContext());
214 
215   patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
216 }
217