xref: /llvm-project/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (revision 7e622b61320543b3706711609f1f32fd9ea3788d)
1310e9636Snatashaknk //===- ConversionUtils.cpp ------------------------------------------------===//
2310e9636Snatashaknk //
3310e9636Snatashaknk // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4310e9636Snatashaknk // See https://llvm.org/LICENSE.txt for license information.
5310e9636Snatashaknk // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6310e9636Snatashaknk //
7310e9636Snatashaknk //===----------------------------------------------------------------------===//
8310e9636Snatashaknk //
9310e9636Snatashaknk // Utility functions for TOSA lowering
10310e9636Snatashaknk //
11310e9636Snatashaknk //===----------------------------------------------------------------------===//
12310e9636Snatashaknk 
1378503e1aSRob Suderman #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
14e0537d1aSTai Ly #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15310e9636Snatashaknk 
16310e9636Snatashaknk using namespace mlir;
17310e9636Snatashaknk using namespace mlir::tosa;
18310e9636Snatashaknk 
19e6598b05SOleg Shyshkov SmallVector<utils::IteratorType>
20310e9636Snatashaknk mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
21e6598b05SOleg Shyshkov   return SmallVector<utils::IteratorType>(nParallelLoops,
22e6598b05SOleg Shyshkov                                           utils::IteratorType::parallel);
23310e9636Snatashaknk }
24310e9636Snatashaknk 
25310e9636Snatashaknk SmallVector<Value>
26310e9636Snatashaknk mlir::tosa::condenseValues(const SmallVector<Value> &values) {
27310e9636Snatashaknk   SmallVector<Value> condensedValues;
28310e9636Snatashaknk   for (auto value : values)
29310e9636Snatashaknk     if (value)
30310e9636Snatashaknk       condensedValues.push_back(value);
31310e9636Snatashaknk   return condensedValues;
32310e9636Snatashaknk }
332eb50ceeSThomas Raoux 
3478503e1aSRob Suderman Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
3578503e1aSRob Suderman                                    Value max, OpBuilder &rewriter) {
368a6e54c9SDaniil Dudkin   Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
378a6e54c9SDaniil Dudkin   return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
382eb50ceeSThomas Raoux }
392eb50ceeSThomas Raoux 
4078503e1aSRob Suderman Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
418d237190SMatthias Gehre                                  OpBuilder &rewriter, bool isUnsigned) {
428d237190SMatthias Gehre   if (isUnsigned) {
438d237190SMatthias Gehre     auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
448d237190SMatthias Gehre     return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
458d237190SMatthias Gehre   }
46d4fd2025Smlevesquedion   auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
47d4fd2025Smlevesquedion   return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
482eb50ceeSThomas Raoux }
4969c984b6SRob Suderman 
5069c984b6SRob Suderman bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
5169c984b6SRob Suderman   uint64_t bitwidth = ty.getIntOrFloatBitWidth();
5269c984b6SRob Suderman   if (ty.getSignedness() == IntegerType::Unsigned) {
5369c984b6SRob Suderman     uint64_t uvalue = value;
5469c984b6SRob Suderman     APInt intMin = APInt::getMinValue(bitwidth);
5569c984b6SRob Suderman     APInt intMax = APInt::getMaxValue(bitwidth);
5669c984b6SRob Suderman     return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
5769c984b6SRob Suderman   }
5869c984b6SRob Suderman 
5969c984b6SRob Suderman   APInt intMin = APInt::getSignedMinValue(bitwidth);
6069c984b6SRob Suderman   APInt intMax = APInt::getSignedMaxValue(bitwidth);
6169c984b6SRob Suderman   return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
6269c984b6SRob Suderman }
63e0537d1aSTai Ly 
64e0537d1aSTai Ly namespace {
65e0537d1aSTai Ly // Given two tensors of high and low ranks, derive the output shape
66e0537d1aSTai Ly // to reshape the lower rank to.
67e0537d1aSTai Ly // Examples:
68e0537d1aSTai Ly // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
69e0537d1aSTai Ly // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
70e0537d1aSTai Ly // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
71e0537d1aSTai Ly // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
72e0537d1aSTai Ly // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
73e0537d1aSTai Ly LogicalResult
74e0537d1aSTai Ly computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
75e0537d1aSTai Ly                      ArrayRef<int64_t> lowerRankShape,
76e0537d1aSTai Ly                      SmallVectorImpl<int64_t> &reshapeOutputShape) {
77e0537d1aSTai Ly   // Initialize new shapes with [1] * higherRank.
78e0537d1aSTai Ly   int64_t higherRank = higherRankShape.size();
79e0537d1aSTai Ly   int64_t lowerRank = lowerRankShape.size();
80e0537d1aSTai Ly 
81e0537d1aSTai Ly   reshapeOutputShape.assign(higherRank, 1);
82e0537d1aSTai Ly 
83e0537d1aSTai Ly   int64_t higherRankDim;
84e0537d1aSTai Ly   int64_t lowerRankDim;
85e0537d1aSTai Ly 
86e0537d1aSTai Ly   for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
87e0537d1aSTai Ly        i--, j--) {
88e0537d1aSTai Ly     higherRankDim = higherRankShape[i];
89e0537d1aSTai Ly     lowerRankDim = lowerRankShape[j];
90e0537d1aSTai Ly 
91e0537d1aSTai Ly     if (lowerRankDim == 1 && higherRankDim > 1)
92e0537d1aSTai Ly       reshapeOutputShape[i] = 1;
93e0537d1aSTai Ly     else if ((lowerRankDim > 1 && higherRankDim == 1) ||
94e0537d1aSTai Ly              (lowerRankDim == higherRankDim))
95e0537d1aSTai Ly       reshapeOutputShape[i] = lowerRankDim;
96e0537d1aSTai Ly     else if (higherRankDim != lowerRankDim)
97e0537d1aSTai Ly       return failure();
98e0537d1aSTai Ly   }
99e0537d1aSTai Ly   return success();
100e0537d1aSTai Ly }
101e0537d1aSTai Ly } // namespace
102e0537d1aSTai Ly 
103e0537d1aSTai Ly LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
104e0537d1aSTai Ly                                         Value &input1, Value &input2) {
105c8834527STai Ly   ImplicitLocOpBuilder builder(loc, rewriter);
106c8834527STai Ly   return EqualizeRanks(builder, input1, input2);
107c8834527STai Ly }
108c8834527STai Ly 
109c8834527STai Ly LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
110c8834527STai Ly                                         Value &input1, Value &input2) {
11168f58812STres Popp   auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
11268f58812STres Popp   auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
113e0537d1aSTai Ly 
114e0537d1aSTai Ly   if (!input1Ty || !input2Ty) {
115e0537d1aSTai Ly     return failure();
116e0537d1aSTai Ly   }
117e0537d1aSTai Ly 
118e0537d1aSTai Ly   int64_t input1Rank = input1Ty.getRank();
119e0537d1aSTai Ly   int64_t input2Rank = input2Ty.getRank();
120e0537d1aSTai Ly 
121e0537d1aSTai Ly   if (input1Rank == input2Rank)
122e0537d1aSTai Ly     return success();
123e0537d1aSTai Ly 
124e0537d1aSTai Ly   Value higherTensorValue, lowerTensorValue;
125e0537d1aSTai Ly   if (input1Rank > input2Rank) {
126e0537d1aSTai Ly     higherTensorValue = input1;
127e0537d1aSTai Ly     lowerTensorValue = input2;
128e0537d1aSTai Ly   } else {
129e0537d1aSTai Ly     higherTensorValue = input2;
130e0537d1aSTai Ly     lowerTensorValue = input1;
131e0537d1aSTai Ly   }
132e0537d1aSTai Ly 
133e0537d1aSTai Ly   ArrayRef<int64_t> higherRankShape =
13468f58812STres Popp       llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
135e0537d1aSTai Ly   ArrayRef<int64_t> lowerRankShape =
13668f58812STres Popp       llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
137e0537d1aSTai Ly 
138e0537d1aSTai Ly   SmallVector<int64_t, 4> reshapeOutputShape;
139e0537d1aSTai Ly 
140e0537d1aSTai Ly   if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
141e0537d1aSTai Ly           .failed())
142e0537d1aSTai Ly     return failure();
143e0537d1aSTai Ly 
14468f58812STres Popp   auto reshapeInputType =
14568f58812STres Popp       llvm::cast<RankedTensorType>(lowerTensorValue.getType());
146e0537d1aSTai Ly   auto reshapeOutputType = RankedTensorType::get(
147e0537d1aSTai Ly       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
148e0537d1aSTai Ly 
149c8834527STai Ly   auto reshapeLower = builder.create<tosa::ReshapeOp>(
150c8834527STai Ly       reshapeOutputType, lowerTensorValue,
151c8834527STai Ly       builder.getDenseI64ArrayAttr(reshapeOutputShape));
152e0537d1aSTai Ly 
153e0537d1aSTai Ly   if (input1Rank > input2Rank) {
154e0537d1aSTai Ly     input1 = higherTensorValue;
155e0537d1aSTai Ly     input2 = reshapeLower.getResult();
156e0537d1aSTai Ly   } else {
157e0537d1aSTai Ly     input1 = reshapeLower.getResult();
158e0537d1aSTai Ly     input2 = higherTensorValue;
159e0537d1aSTai Ly   }
160e0537d1aSTai Ly 
161e0537d1aSTai Ly   return success();
162e0537d1aSTai Ly }
163*7e622b61SJerry-Ge 
164*7e622b61SJerry-Ge Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
165*7e622b61SJerry-Ge                                     llvm::ArrayRef<int64_t> shape) {
166*7e622b61SJerry-Ge   auto attr = rewriter.getIndexTensorAttr(shape);
167*7e622b61SJerry-Ge   auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
168*7e622b61SJerry-Ge   mlir::Operation *mlir_op =
169*7e622b61SJerry-Ge       rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
170*7e622b61SJerry-Ge   return mlir_op->getResult(0);
171*7e622b61SJerry-Ge }
172*7e622b61SJerry-Ge 
173*7e622b61SJerry-Ge SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
174*7e622b61SJerry-Ge   return to_vector(llvm::map_range(shape, [](int64_t dim) {
175*7e622b61SJerry-Ge     return ShapedType::isDynamic(dim) ? -1 : dim;
176*7e622b61SJerry-Ge   }));
177*7e622b61SJerry-Ge }
178*7e622b61SJerry-Ge 
179*7e622b61SJerry-Ge bool mlir::tosa::getConstShapeValue(Operation *op,
180*7e622b61SJerry-Ge                                     llvm::SmallVector<int64_t> &result_shape) {
181*7e622b61SJerry-Ge   if (!op) {
182*7e622b61SJerry-Ge     return false;
183*7e622b61SJerry-Ge   }
184*7e622b61SJerry-Ge   if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
185*7e622b61SJerry-Ge     Attribute constOpAttr = constOp->getAttr("value");
186*7e622b61SJerry-Ge     DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
187*7e622b61SJerry-Ge     for (int i = 0; i < elementsAttr.size(); i++) {
188*7e622b61SJerry-Ge       int64_t val = elementsAttr.getValues<int64_t>()[i];
189*7e622b61SJerry-Ge       result_shape.push_back(val);
190*7e622b61SJerry-Ge     }
191*7e622b61SJerry-Ge     return true;
192*7e622b61SJerry-Ge   }
193*7e622b61SJerry-Ge   // for undefined op, return false.
194*7e622b61SJerry-Ge   return false;
195*7e622b61SJerry-Ge }
196