//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Apply the given transformation \p toApply to every element of the tensor
/// to be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of element type \p targetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected output type.
  auto inShape = toTransform.getType();
  auto outTy = inShape.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);

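// Illustrative usage of applyElementWise (a minimal sketch, not part of the
// upstream code): negating every element of a constant float tensor, using
// llvm::neg from APFloat.h. The variable `input` is a hypothetical
// DenseElementsAttr with a float element type.
//
//   DenseElementsAttr negated = applyElementWise<APFloat, APFloat, FloatType>(
//       input, [](const APFloat &v) { return llvm::neg(v); },
//       cast<FloatType>(input.getElementType()));
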
/// Function that checks if the element type contained in \p toCheck is a
/// float type.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO: We currently ensure the tensor is dense by using the correct type
  // for the bind_value, but we do not actually need this value. It would be
  // nicer to only have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "The reciprocal can only be folded if "
                                     "it operates on a TOSA constant");
}

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can lead to increased memory usage
/// whenever the input cannot be replaced but a new constant is inserted.
/// Hence, this currently only suggests folding when the memory impact is
/// negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}
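
// For example (illustrative): folding a reciprocal whose splat input feeds
// several other ops is still cheap, because a splat constant is stored as a
// single value, whereas folding a multi-user dense tensor would duplicate the
// entire tensor data in the IR.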

template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
  // the contribution of a given dimension to the destination linear index in
  // an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);

  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

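  // Worked example (illustrative, not from the original source): for
  // inputShape = [2, 3] and permValues = [1, 0], the output shape is [3, 2],
  // outputStrides = [2, 1], and invertedPermValues = [1, 0]. The source
  // element at linear index 1 (position [0, 1]) contributes
  // outputStrides[0] * 1 + outputStrides[1] * 0 = 2, so it lands at output
  // position [1, 0].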
  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using the
      // permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(outputValues));
}

// A type specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  return nullptr;
}

struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types, which this folding does not handle.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::map_to_vector(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};
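
// For example (illustrative), this pattern rewrites a tosa.transpose of a
// single-use constant tensor<2x3xf32> with perms = [1, 0] into a tosa.const
// of type tensor<3x2xf32> holding the transposed values.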

struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};
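
// For example (illustrative; the printed IR form is approximate), this
// pattern rewrites
//   %0 = "tosa.const"() {value = dense<2.0> : tensor<f32>}
//   %1 = "tosa.reciprocal"(%0)
// into
//   %1 = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>}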
284d84d418eSTina Jung
285f5f7e2a3SAmir Bishara /// Getting the axes position of the element which is located
286f5f7e2a3SAmir Bishara /// in the tensor at the counter index
287f5f7e2a3SAmir Bishara
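/// For example (illustrative), index 5 in a tensor of shape [2, 3]
/// corresponds to position [1, 2].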
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

/// Compute the linear index of the element located at the given axes
/// \p position in a tensor of shape \p tensorShape; the inverse of
/// getPositionFromIndex.
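/// For example (illustrative), position [1, 2] in a tensor of shape [2, 3]
/// corresponds to index 5.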
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}

/// Reduce all elements of \p oldTensorAttr along \p reductionAxis into the
/// single value whose linear index in the reduced tensor is
/// \p reductionIndex, combining elements with OperationType::calcOneElement.
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Calculate the position corresponding to the index in the reduced tensor.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Start from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  // The distance between consecutive elements along the reduction axis is the
  // product of all trailing dimensions; it is loop-invariant, so compute it
  // once.
  const int64_t stride =
      std::accumulate(oldShape.begin() + reductionAxis + 1, oldShape.end(),
                      int64_t{1}, std::multiplies<int64_t>());
  for (int64_t reductionAxisVal = 1;
       reductionAxisVal < oldShape[reductionAxis]; ++reductionAxisVal) {
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}
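
// Worked example (illustrative, not from the original source): reducing a
// tensor of shape [2, 3] along axis 0 with OperationType = ReduceSumOp and
// reductionIndex = 2 yields position [0, 2], start index 2, and stride 3, so
// the elements at linear indices 2 and 5 are combined.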

template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input currently supported with integer type");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    const int64_t newNumOfElements =
        std::accumulate(newShape.begin(), newShape.end(), int64_t{1},
                        std::multiplies<int64_t>());
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {

      // Reduce all the elements along this reduction axis.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }
  const bool aggressiveReduceConstant;
};
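
// For example (illustrative), tosa.reduce_sum along axis 1 of a constant
// tensor<2x3xi32> holding [[1, 2, 3], [4, 5, 6]] folds to a tosa.const of
// type tensor<2x1xi32> holding [[6], [15]].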

} // namespace

void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}
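
// Illustrative usage (a minimal sketch, not part of this file): a pass would
// typically collect these patterns and run the greedy rewrite driver, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
//   mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));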