1 //===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 #include "mlir/Dialect/Tensor/IR/Tensor.h"
10 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14
15 #include "llvm/ADT/TypeSwitch.h"
16
17 using namespace mlir;
18 using namespace mlir::tensor;
19
20 namespace {
21
22 /// Rewrite tensor.generate with arith.constant if the yielded value is a
23 /// constant and the tensor type is static.
24 struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
25 using OpRewritePattern<GenerateOp>::OpRewritePattern;
26
matchAndRewrite__anon2f530e890111::GenerateToConstant27 LogicalResult matchAndRewrite(GenerateOp generateOp,
28 PatternRewriter &rewriter) const override {
29 auto tensorType =
30 llvm::cast<RankedTensorType>(generateOp.getResult().getType());
31 if (!tensorType.hasStaticShape())
32 return failure();
33 auto terminatorOp =
34 cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
35 Attribute attr;
36 if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
37 return failure();
38 Operation *constantOp =
39 rewriter.getContext()
40 ->getLoadedDialect<TensorDialect>()
41 ->materializeConstant(rewriter,
42 DenseElementsAttr::get(tensorType, attr),
43 tensorType, generateOp->getLoc());
44 if (!constantOp)
45 return failure();
46 rewriter.replaceOp(generateOp, constantOp->getResults());
47 return success();
48 }
49 };
50
51 /// Transform a linear index from one indexing space to another given:
52 ///
53 /// - the shape of the source indexing space,
54 /// - the strides of the target indexing space,
55 /// - a linear index into the source indexing space.
56 ///
57 /// This function is logically a sequence of linearize/delinearize over
58 /// different bases but avoids allocating intermediate SmallVectors.
transformIndexSpace(ArrayRef<int64_t> inputShape,ArrayRef<int64_t> outputStrides,int64_t srcLinearIndex)59 int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
60 ArrayRef<int64_t> outputStrides,
61 int64_t srcLinearIndex) {
62 assert(inputShape.size() == outputStrides.size());
63
64 int64_t dstLinearIndex = 0;
65
66 for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
67 // Compute the index into the current dimension of the source tensor.
68 // `quotient` is the remaining linear index after accounting for the
69 // current dimension.
70 //
71 // `remainder` is the index into the source tensor for the current
72 // dimension.
73 auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
74
75 srcLinearIndex = quotient;
76
77 // Add the contribution of the current dimension to the output using the
78 // permutation map.
79 dstLinearIndex += outputStrides[dim] * remainder;
80 }
81
82 return dstLinearIndex;
83 }
84
85 template <typename ElemType, typename AttrType>
constantFoldPadOp(PatternRewriter & rewriter,Location loc,DenseElementsAttr input,AttrType padValue,ArrayRef<int64_t> padLow,ArrayRef<int64_t> padHigh)86 Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
87 DenseElementsAttr input, AttrType padValue,
88 ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
89 auto inputValues = input.tryGetValues<ElemType>();
90 if (failed(inputValues))
91 return nullptr;
92
93 auto oldShape = input.getType().getShape();
94
95 // Compute the output shape of the new value.
96 auto newShape =
97 llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98 [](std::tuple<int64_t, int64_t, int64_t> pack) {
99 auto [old, low, high] = pack;
100 return old + low + high;
101 });
102
103 int64_t outputSize = computeProduct(newShape);
104
105 // Fully initialize the vector with the padding value.
106 // The non-padded area will then be copied.
107 SmallVector<ElemType> values(outputSize, padValue.getValue());
108
109 // Strides for input and output are used to transform between the indexing
110 // space of the input and output tensors.
111 SmallVector<int64_t> outputStrides = computeStrides(newShape);
112
113 // The contribution of the low padding to the offset in the output tensor.
114 // This is the starting position of the source tensor within the padding
115 // tensor.
116 int64_t startingOffset = linearize(padLow, outputStrides);
117
118 // Copy values from the input tensor to the corresponding sub-region
119 // of the output tensor.
120 for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121 auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122 values[outputIndex + startingOffset] = inputValue;
123 }
124
125 // Create an attribute for the folded value.
126 auto newType = input.getType().clone(newShape);
127 auto newAttr = DenseElementsAttr::get(newType, values);
128
129 Operation *constantOp =
130 rewriter.getContext()
131 ->getLoadedDialect<TensorDialect>()
132 ->materializeConstant(rewriter, newAttr, newType, loc);
133
134 return constantOp ? constantOp->getResult(0) : nullptr;
135 }
136
137 struct PadOpToConstant final : public OpRewritePattern<PadOp> {
138
PadOpToConstant__anon2f530e890111::PadOpToConstant139 PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
140 PatternBenefit benefit = 1)
141 : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}
142
matchAndRewrite__anon2f530e890111::PadOpToConstant143 LogicalResult matchAndRewrite(PadOp padTensorOp,
144 PatternRewriter &rewriter) const override {
145 if (padTensorOp.getNofold())
146 return rewriter.notifyMatchFailure(
147 padTensorOp, "refusing to fold nofold pad operation");
148
149 TypedValue<RankedTensorType> input = padTensorOp.getSource();
150 RankedTensorType resultType = padTensorOp.getResult().getType();
151
152 DenseElementsAttr inputAttr = nullptr;
153 if (!matchPattern(input, m_Constant(&inputAttr)))
154 return failure();
155
156 Value paddingValue = padTensorOp.getConstantPaddingValue();
157
158 // Extract the constant value used for padding or bail out.
159 Attribute paddingAttr = nullptr;
160 if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
161 return rewriter.notifyMatchFailure(padTensorOp,
162 "unable to get constant value");
163
164 // Try to extract the constant values of the low and high padding.
165 auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
166 auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
167
168 // If the padding cannot be extracted, bail out.
169 if (!lowPad || !highPad)
170 return rewriter.notifyMatchFailure(padTensorOp,
171 "unable to extract constant padding");
172
173 // We have a potential candidate, consult the control function to
174 // determine if the op should fold.
175 if (!controlFn(&padTensorOp.getSourceMutable()))
176 return rewriter.notifyMatchFailure(padTensorOp,
177 "not folding due to cost function");
178
179 Location loc = padTensorOp.getLoc();
180
181 // Try constant folding the supported cases of integer and float values.
182 Value newOp =
183 llvm::TypeSwitch<Attribute, Value>(paddingAttr)
184 .Case([&](FloatAttr floatAttr) {
185 return constantFoldPadOp<llvm::APFloat>(
186 rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
187 })
188 .Case([&](IntegerAttr integerAttr) {
189 return constantFoldPadOp<llvm::APInt>(
190 rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
191 })
192 .Default(Value());
193
194 if (!newOp)
195 return rewriter.notifyMatchFailure(padTensorOp,
196 "tensor type not supported");
197
198 if (newOp.getType() != resultType)
199 newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);
200
201 rewriter.replaceOp(padTensorOp, newOp);
202 return success();
203 }
204
205 private:
206 ControlFoldFn controlFn;
207 };
208
209 } // namespace
210
populateRewriteAsConstantPatterns(RewritePatternSet & patterns,const ControlFoldFn & controlFn)211 void mlir::tensor::populateRewriteAsConstantPatterns(
212 RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
213 patterns.add<GenerateToConstant>(patterns.getContext());
214
215 patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
216 }
217