//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of type \p TargetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected output type.
  auto inTy = toTransform.getType();
  auto outTy = inTy.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);
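
// Example (a sketch, not used elsewhere in this file): the instantiation
// above enables folds over APFloat tensors. A hypothetical element-wise
// negation fold could be written as
//
//   auto negated = applyElementWise<APFloat, APFloat, FloatType>(
//       inputValues, [](const APFloat &val) { return llvm::neg(val); },
//       cast<FloatType>(inputValues.getElementType()));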

/// Check whether the element type of \p toCheck is a float type.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Check whether \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO: We currently ensure the tensor is dense by using the matching type
  // for the bound value; however, we do not actually need the value itself.
  // It would be nicer to only have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "The reciprocal can only be folded if "
                                     "it operates on a TOSA constant");
}

/// Check whether \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with the
/// folded value.
/// Folding operations on constants can increase memory usage whenever the input
/// cannot be replaced but a new constant is inserted. Hence, this currently only
/// suggests folding when the memory impact is negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}

template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
  // the contribution of a given dimension to the destination linear index in
  // an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);

  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using the
      // permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(outputValues));
}
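
// Worked example (for illustration): with inputShape = [2, 3] and
// permValues = [1, 0], the output shape is [3, 2], outputStrides is [2, 1],
// and the inverted permutation is [1, 0]. The source element at linear index
// 1, i.e. position (0, 1), contributes
// 1 * outputStrides[invertedPermValues[1]] = 1 * 2, so it lands at
// destination linear index 2, i.e. position (1, 0).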

// A type-specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  return nullptr;
}

struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types, which are not folded here.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::map_to_vector(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};
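
// For example (a sketch of the resulting rewrite): a transpose of a constant,
// roughly
//
//   %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>}
//   %in = "tosa.const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>}
//   %0 = tosa.transpose %in, %perms : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
//
// is replaced by a single tosa.const holding dense<[[1, 3], [2, 4]]>.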

struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding.
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};
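
// For example (a sketch of the resulting rewrite): a reciprocal of a constant
// splat, roughly
//
//   %in = "tosa.const"() {value = dense<2.0> : tensor<4xf32>}
//   %0 = tosa.reciprocal %in : (tensor<4xf32>) -> tensor<4xf32>
//
// is replaced by a tosa.const holding dense<5.000000e-01> : tensor<4xf32>.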

/// Compute the position along each axis of the element located at the given
/// linear index of the tensor.
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

/// Compute the linear index of the element located at the given position along
/// each axis of the tensor.
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}
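
// Worked example (for illustration): for tensorShape = [2, 3, 4], linear index
// 17 decomposes as 17 = 1 * 12 + 1 * 4 + 1, so getPositionFromIndex(17, shape)
// returns [1, 1, 1], and getIndexFromPosition([1, 1, 1], shape) yields 17
// again; the two functions are inverses for in-range indices.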

template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Calculate the position corresponding to the reduction index.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Start from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  // The stride between consecutive elements along the reduction axis.
  const int64_t stride =
      std::accumulate(oldShape.begin() + reductionAxis + 1, oldShape.end(),
                      int64_t{1}, std::multiplies<int64_t>());
  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}
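
// Worked example (for illustration): reducing a [2, 3] tensor along axis 0
// with reductionIndex = 1 gives the reduced shape [1, 3] and position (0, 1),
// so indexAtOldTensor is 1 and the stride along axis 0 is 3. The result is
// OperationType::calcOneElement applied to the elements at linear indices 1
// and 4, i.e. source positions (0, 1) and (1, 1).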

template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be a const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input is currently only supported for integer types");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    auto newNumOfElements =
        std::accumulate(newShape.begin(), newShape.end(), int64_t{1},
                        std::multiplies<int64_t>());
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {
      // Reduce all the elements along the reduction axis into one value.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }
  const bool aggressiveReduceConstant;
};

} // namespace

void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}
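
// Usage sketch (assuming a pass that exposes an `aggressiveReduceConstant`
// option): collect the patterns above and apply them greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaConstantReduction(
//       ctx, patterns, /*aggressiveReduceConstant=*/false);
//   mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
//   mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));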