xref: /llvm-project/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (revision d4fd20258f63d30be638b04f10eaa469707759f0)
1 //===- ConversionUtils.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
19 SmallVector<utils::IteratorType>
20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
21   return SmallVector<utils::IteratorType>(nParallelLoops,
22                                           utils::IteratorType::parallel);
23 }
24 
25 SmallVector<Value>
26 mlir::tosa::condenseValues(const SmallVector<Value> &values) {
27   SmallVector<Value> condensedValues;
28   for (auto value : values)
29     if (value)
30       condensedValues.push_back(value);
31   return condensedValues;
32 }
33 
34 Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
35                                    Value max, OpBuilder &rewriter) {
36   Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
37   return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
38 }
39 
40 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
41                                  OpBuilder &rewriter) {
42   auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
43   return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
44 }
45 
46 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
47   uint64_t bitwidth = ty.getIntOrFloatBitWidth();
48   if (ty.getSignedness() == IntegerType::Unsigned) {
49     uint64_t uvalue = value;
50     APInt intMin = APInt::getMinValue(bitwidth);
51     APInt intMax = APInt::getMaxValue(bitwidth);
52     return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
53   }
54 
55   APInt intMin = APInt::getSignedMinValue(bitwidth);
56   APInt intMax = APInt::getSignedMaxValue(bitwidth);
57   return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
58 }
59 
60 namespace {
61 // Given two tensors of high and low ranks, derive the output shape
62 // to reshape the lower rank to.
63 // Examples:
64 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
65 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
66 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
67 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
68 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
69 LogicalResult
70 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
71                      ArrayRef<int64_t> lowerRankShape,
72                      SmallVectorImpl<int64_t> &reshapeOutputShape) {
73   // Initialize new shapes with [1] * higherRank.
74   int64_t higherRank = higherRankShape.size();
75   int64_t lowerRank = lowerRankShape.size();
76 
77   reshapeOutputShape.assign(higherRank, 1);
78 
79   int64_t higherRankDim;
80   int64_t lowerRankDim;
81 
82   for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
83        i--, j--) {
84     higherRankDim = higherRankShape[i];
85     lowerRankDim = lowerRankShape[j];
86 
87     if (lowerRankDim == 1 && higherRankDim > 1)
88       reshapeOutputShape[i] = 1;
89     else if ((lowerRankDim > 1 && higherRankDim == 1) ||
90              (lowerRankDim == higherRankDim))
91       reshapeOutputShape[i] = lowerRankDim;
92     else if (higherRankDim != lowerRankDim)
93       return failure();
94   }
95   return success();
96 }
97 } // namespace
98 
99 LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
100                                         Value &input1, Value &input2) {
101   auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
102   auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
103 
104   if (!input1Ty || !input2Ty) {
105     return failure();
106   }
107 
108   int64_t input1Rank = input1Ty.getRank();
109   int64_t input2Rank = input2Ty.getRank();
110 
111   if (input1Rank == input2Rank)
112     return success();
113 
114   Value higherTensorValue, lowerTensorValue;
115   if (input1Rank > input2Rank) {
116     higherTensorValue = input1;
117     lowerTensorValue = input2;
118   } else {
119     higherTensorValue = input2;
120     lowerTensorValue = input1;
121   }
122 
123   ArrayRef<int64_t> higherRankShape =
124       llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
125   ArrayRef<int64_t> lowerRankShape =
126       llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
127 
128   SmallVector<int64_t, 4> reshapeOutputShape;
129 
130   if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
131           .failed())
132     return failure();
133 
134   auto reshapeInputType =
135       llvm::cast<RankedTensorType>(lowerTensorValue.getType());
136   auto reshapeOutputType = RankedTensorType::get(
137       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
138 
139   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
140       loc, reshapeOutputType, lowerTensorValue,
141       rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
142 
143   if (input1Rank > input2Rank) {
144     input1 = higherTensorValue;
145     input2 = reshapeLower.getResult();
146   } else {
147     input1 = reshapeLower.getResult();
148     input2 = higherTensorValue;
149   }
150 
151   return success();
152 }
153