xref: /llvm-project/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (revision 68f58812e3e99e31d77c0c23b6298489444dc0be)
1 //===- ConversionUtils.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
19 SmallVector<utils::IteratorType>
20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
21   return SmallVector<utils::IteratorType>(nParallelLoops,
22                                           utils::IteratorType::parallel);
23 }
24 
25 SmallVector<Value>
26 mlir::tosa::condenseValues(const SmallVector<Value> &values) {
27   SmallVector<Value> condensedValues;
28   for (auto value : values)
29     if (value)
30       condensedValues.push_back(value);
31   return condensedValues;
32 }
33 
34 Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
35                                    Value max, OpBuilder &rewriter) {
36   Value minValue = rewriter.create<arith::MinFOp>(loc, arg, max);
37   return rewriter.create<arith::MaxFOp>(loc, minValue, min);
38 }
39 
40 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
41                                  OpBuilder &rewriter) {
42   auto smallerThanMin =
43       rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
44   auto minOrArg =
45       rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
46   auto largerThanMax =
47       rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
48   return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
49 }
50 
51 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
52   uint64_t bitwidth = ty.getIntOrFloatBitWidth();
53   if (ty.getSignedness() == IntegerType::Unsigned) {
54     uint64_t uvalue = value;
55     APInt intMin = APInt::getMinValue(bitwidth);
56     APInt intMax = APInt::getMaxValue(bitwidth);
57     return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
58   }
59 
60   APInt intMin = APInt::getSignedMinValue(bitwidth);
61   APInt intMax = APInt::getSignedMaxValue(bitwidth);
62   return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
63 }
64 
65 namespace {
66 // Given two tensors of high and low ranks, derive the output shape
67 // to reshape the lower rank to.
68 // Examples:
69 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
70 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
71 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
72 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
73 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
74 LogicalResult
75 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
76                      ArrayRef<int64_t> lowerRankShape,
77                      SmallVectorImpl<int64_t> &reshapeOutputShape) {
78   // Initialize new shapes with [1] * higherRank.
79   int64_t higherRank = higherRankShape.size();
80   int64_t lowerRank = lowerRankShape.size();
81 
82   reshapeOutputShape.assign(higherRank, 1);
83 
84   int64_t higherRankDim;
85   int64_t lowerRankDim;
86 
87   for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
88        i--, j--) {
89     higherRankDim = higherRankShape[i];
90     lowerRankDim = lowerRankShape[j];
91 
92     if (lowerRankDim == 1 && higherRankDim > 1)
93       reshapeOutputShape[i] = 1;
94     else if ((lowerRankDim > 1 && higherRankDim == 1) ||
95              (lowerRankDim == higherRankDim))
96       reshapeOutputShape[i] = lowerRankDim;
97     else if (higherRankDim != lowerRankDim)
98       return failure();
99   }
100   return success();
101 }
102 } // namespace
103 
104 LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
105                                         Value &input1, Value &input2) {
106   auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
107   auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
108 
109   if (!input1Ty || !input2Ty) {
110     return failure();
111   }
112 
113   int64_t input1Rank = input1Ty.getRank();
114   int64_t input2Rank = input2Ty.getRank();
115 
116   if (input1Rank == input2Rank)
117     return success();
118 
119   Value higherTensorValue, lowerTensorValue;
120   if (input1Rank > input2Rank) {
121     higherTensorValue = input1;
122     lowerTensorValue = input2;
123   } else {
124     higherTensorValue = input2;
125     lowerTensorValue = input1;
126   }
127 
128   ArrayRef<int64_t> higherRankShape =
129       llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
130   ArrayRef<int64_t> lowerRankShape =
131       llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
132 
133   SmallVector<int64_t, 4> reshapeOutputShape;
134 
135   if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
136           .failed())
137     return failure();
138 
139   auto reshapeInputType =
140       llvm::cast<RankedTensorType>(lowerTensorValue.getType());
141   auto reshapeOutputType = RankedTensorType::get(
142       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
143 
144   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
145       loc, reshapeOutputType, lowerTensorValue,
146       rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
147 
148   if (input1Rank > input2Rank) {
149     input1 = higherTensorValue;
150     input2 = reshapeLower.getResult();
151   } else {
152     input1 = reshapeLower.getResult();
153     input2 = higherTensorValue;
154   }
155 
156   return success();
157 }
158