//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;
namespace {

/// Apply the given transformation \p toApply to every element of the tensor
/// to be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of type \p TargetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected output type.
  auto inType = toTransform.getType();
  auto outTy = inType.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);
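
// As an illustration, the reciprocal folding further down uses this helper
// roughly as follows:
//   applyElementWise<APFloat, APFloat, FloatType>(
//       inputValues, &ReciprocalOp::calcOneElement, floatElementType);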

/// Check whether the element type of \p toCheck is a float type; if not,
/// notify the rewriter of the match failure.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Check whether \p toCheck is a dense TOSA constant tensor; if not, notify
/// the rewriter of the match failure.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO We currently ensure the tensor is dense by using the correct type for
  // the bind_value, however we do not actually need this value. It would be
  // nicer to only have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "Folding is only applied to operands "
                                     "produced by a TOSA constant");
}

/// Check whether \p toCheck is a dense TOSA constant float tensor; if not,
/// notify the rewriter of the match failure.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can lead to increased memory usage
/// whenever the input cannot be replaced but a new constant is inserted.
/// Hence, this currently only suggests folding when the memory impact is
/// negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}

template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
  // the contribution of a given dimension to the destination linear index in
  // an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);
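
  // As a concrete example, transposing a 2x3 tensor with the permutation
  // [1, 0] yields a 3x2 tensor. The source element at linear index 5 has
  // position [1, 2]; it moves to output position [2, 1], i.e. destination
  // linear index 2 * outputStrides[0] + 1 * outputStrides[1] = 2 * 2 + 1 = 5.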

  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using the
      // permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(outputValues));
}

// A type-specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  return nullptr;
}

struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types, which are not handled here; bail out on
    // anything that is not an int, index, or float element type.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::map_to_vector(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding.
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};

/// Compute the position along each axis of the element located at the given
/// linear \p index in a tensor of shape \p tensorShape.
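/// For example, in a tensor of shape [2, 3], the element at linear index 5
/// is at position [1, 2].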
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

/// Compute the linear index of the element located at the given \p position
/// in a tensor of shape \p tensorShape.
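/// For example, in a tensor of shape [2, 3], the element at position [1, 2]
/// has linear index 1 * 3 + 2 = 5.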
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}

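/// Compute a single element of the reduction result: combine all elements of
/// \p oldTensorAttr that lie along \p reductionAxis at the output position
/// corresponding to the linear \p reductionIndex, using
/// OperationType::calcOneElement as the accumulator.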
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Compute the output position corresponding to the output linear index.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Start from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  // The stride between successive elements along the reduction axis is the
  // product of the dimensions after it; it is loop-invariant, so compute it
  // once. Accumulate in int64_t to avoid truncation for large shapes.
  int64_t stride =
      std::accumulate(oldShape.begin() + reductionAxis + 1, oldShape.end(),
                      static_cast<int64_t>(1), std::multiplies<int64_t>());
  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}

template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be a const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input is currently only supported for integer types");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    auto newNumOfElements =
        std::accumulate(newShape.begin(), newShape.end(),
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {

      // Reduce all the elements along the reduction axis that contribute to
      // this output element.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }
  const bool aggressiveReduceConstant;
};

} // namespace

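// The populate* hooks below are how a pass opts in to these foldings. As an
// illustrative sketch (assuming the usual greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h), a pass body might do:
//   RewritePatternSet patterns(&getContext());
//   mlir::tosa::populateTosaConstantReduction(
//       &getContext(), patterns, /*aggressiveReduceConstant=*/false);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));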
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}