xref: /llvm-project/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (revision 7e622b61320543b3706711609f1f32fd9ea3788d)
1 //===- ConversionUtils.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
19 SmallVector<utils::IteratorType>
20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
21   return SmallVector<utils::IteratorType>(nParallelLoops,
22                                           utils::IteratorType::parallel);
23 }
24 
25 SmallVector<Value>
26 mlir::tosa::condenseValues(const SmallVector<Value> &values) {
27   SmallVector<Value> condensedValues;
28   for (auto value : values)
29     if (value)
30       condensedValues.push_back(value);
31   return condensedValues;
32 }
33 
34 Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
35                                    Value max, OpBuilder &rewriter) {
36   Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
37   return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
38 }
39 
40 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
41                                  OpBuilder &rewriter, bool isUnsigned) {
42   if (isUnsigned) {
43     auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
44     return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
45   }
46   auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
47   return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
48 }
49 
50 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
51   uint64_t bitwidth = ty.getIntOrFloatBitWidth();
52   if (ty.getSignedness() == IntegerType::Unsigned) {
53     uint64_t uvalue = value;
54     APInt intMin = APInt::getMinValue(bitwidth);
55     APInt intMax = APInt::getMaxValue(bitwidth);
56     return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
57   }
58 
59   APInt intMin = APInt::getSignedMinValue(bitwidth);
60   APInt intMax = APInt::getSignedMaxValue(bitwidth);
61   return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
62 }
63 
64 namespace {
65 // Given two tensors of high and low ranks, derive the output shape
66 // to reshape the lower rank to.
67 // Examples:
68 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
69 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
70 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
71 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
72 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
73 LogicalResult
74 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
75                      ArrayRef<int64_t> lowerRankShape,
76                      SmallVectorImpl<int64_t> &reshapeOutputShape) {
77   // Initialize new shapes with [1] * higherRank.
78   int64_t higherRank = higherRankShape.size();
79   int64_t lowerRank = lowerRankShape.size();
80 
81   reshapeOutputShape.assign(higherRank, 1);
82 
83   int64_t higherRankDim;
84   int64_t lowerRankDim;
85 
86   for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
87        i--, j--) {
88     higherRankDim = higherRankShape[i];
89     lowerRankDim = lowerRankShape[j];
90 
91     if (lowerRankDim == 1 && higherRankDim > 1)
92       reshapeOutputShape[i] = 1;
93     else if ((lowerRankDim > 1 && higherRankDim == 1) ||
94              (lowerRankDim == higherRankDim))
95       reshapeOutputShape[i] = lowerRankDim;
96     else if (higherRankDim != lowerRankDim)
97       return failure();
98   }
99   return success();
100 }
101 } // namespace
102 
103 LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
104                                         Value &input1, Value &input2) {
105   ImplicitLocOpBuilder builder(loc, rewriter);
106   return EqualizeRanks(builder, input1, input2);
107 }
108 
109 LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
110                                         Value &input1, Value &input2) {
111   auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
112   auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
113 
114   if (!input1Ty || !input2Ty) {
115     return failure();
116   }
117 
118   int64_t input1Rank = input1Ty.getRank();
119   int64_t input2Rank = input2Ty.getRank();
120 
121   if (input1Rank == input2Rank)
122     return success();
123 
124   Value higherTensorValue, lowerTensorValue;
125   if (input1Rank > input2Rank) {
126     higherTensorValue = input1;
127     lowerTensorValue = input2;
128   } else {
129     higherTensorValue = input2;
130     lowerTensorValue = input1;
131   }
132 
133   ArrayRef<int64_t> higherRankShape =
134       llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
135   ArrayRef<int64_t> lowerRankShape =
136       llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
137 
138   SmallVector<int64_t, 4> reshapeOutputShape;
139 
140   if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
141           .failed())
142     return failure();
143 
144   auto reshapeInputType =
145       llvm::cast<RankedTensorType>(lowerTensorValue.getType());
146   auto reshapeOutputType = RankedTensorType::get(
147       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
148 
149   auto reshapeLower = builder.create<tosa::ReshapeOp>(
150       reshapeOutputType, lowerTensorValue,
151       builder.getDenseI64ArrayAttr(reshapeOutputShape));
152 
153   if (input1Rank > input2Rank) {
154     input1 = higherTensorValue;
155     input2 = reshapeLower.getResult();
156   } else {
157     input1 = reshapeLower.getResult();
158     input2 = higherTensorValue;
159   }
160 
161   return success();
162 }
163 
164 Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
165                                     llvm::ArrayRef<int64_t> shape) {
166   auto attr = rewriter.getIndexTensorAttr(shape);
167   auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
168   mlir::Operation *mlir_op =
169       rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
170   return mlir_op->getResult(0);
171 }
172 
173 SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
174   return to_vector(llvm::map_range(shape, [](int64_t dim) {
175     return ShapedType::isDynamic(dim) ? -1 : dim;
176   }));
177 }
178 
179 bool mlir::tosa::getConstShapeValue(Operation *op,
180                                     llvm::SmallVector<int64_t> &result_shape) {
181   if (!op) {
182     return false;
183   }
184   if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
185     Attribute constOpAttr = constOp->getAttr("value");
186     DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
187     for (int i = 0; i < elementsAttr.size(); i++) {
188       int64_t val = elementsAttr.getValues<int64_t>()[i];
189       result_shape.push_back(val);
190     }
191     return true;
192   }
193   // for undefined op, return false.
194   return false;
195 }
196