xref: /llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (revision 7e622b61320543b3706711609f1f32fd9ea3788d)
1 //===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
14 #define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
15 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
#include <tuple>
#include <type_traits>
23 
24 namespace mlir {
25 namespace tosa {
26 
// Creates a SmallVector of utils::IteratorType entries for N parallel loops.
28 SmallVector<utils::IteratorType>
29 getNParallelLoopsAttrs(unsigned nParallelLoops);
30 
31 // Takes a vector of values and condenses them to a vector with no gaps.
32 SmallVector<Value> condenseValues(const SmallVector<Value> &values);
33 
34 // Takes the parameters for a clamp and turns it into a series of ops for float
35 // inputs.
36 Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
37                        OpBuilder &rewriter);
38 
39 // Takes the parameters for a clamp and turns it into a series of ops for
40 // integer inputs.
41 Value clampIntHelper(Location loc, Value arg, Value min, Value max,
42                      OpBuilder &rewriter, bool isUnsigned);
43 
// Determines whether the integer value falls within the range of integer type.
45 bool validIntegerRange(IntegerType ty, int64_t value);
46 
47 // Checks for a dynamic batch dim in any of the passed parameters of an op.
48 // The batch dimention must be #0 and the rest of the dimensions must be static.
49 template <typename Op>
50 std::optional<SmallVector<Value>>
51 checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
52                          ArrayRef<Value> params) {
53   SmallVector<ShapedType> dynTypes;
54   SmallVector<Value> dynamicDims;
55   for (const Value &param : params) {
56     auto paramTy = cast<ShapedType>(param.getType());
57     if (!paramTy.hasStaticShape())
58       dynTypes.push_back(paramTy);
59   }
60 
61   if (dynTypes.empty())
62     return dynamicDims;
63 
64   for (const ShapedType &dynTy : dynTypes) {
65     if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
66       (void)rewriter.notifyMatchFailure(
67           op, "input can only be dynamic for batch size");
68       return std::nullopt;
69     }
70   }
71 
72   dynamicDims.push_back(
73       rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0));
74   return dynamicDims;
75 }
76 
77 /// Common code to create the reshape op where necessary to make the rank of two
78 /// values equal. input1 and input2 will be updated when the rank has
79 /// changed. The caller is expected to use these to rewrite the original
80 /// operator with the RESHAPE now in the graph.
81 LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
82                             Value &input1, Value &input2);
83 
84 LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
85                             Value &input2);
86 
87 namespace {
88 
89 // Creates a TOSA operation and performs shape inference on the individual
90 // op. This allows shape inference when lowering down to TOSA.
91 template <typename TosaOp, typename... Args>
92 TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
93                              Args &&...args) {
94   auto op = builder.create<TosaOp>(resultTy, args...);
95 
96   InferShapedTypeOpInterface shapeInterface =
97       dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
98   if (!shapeInterface)
99     return op;
100 
101   SmallVector<ShapedTypeComponents> returnedShapes;
102   if (shapeInterface
103           .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
104                                      op->getOperands(), op->getAttrDictionary(),
105                                      op->getPropertiesStorage(),
106                                      op->getRegions(), returnedShapes)
107           .failed())
108     return op;
109 
110   // We need to use the element type of the existing result type to generate
111   // the new result shaped type. This is because rescale can include a cast to
112   // different bit-width types and does not have a TypeAttr to define the
113   // target type.
114   auto result = op->getResult(0);
115   auto predictedShape = returnedShapes[0];
116   auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);
117 
118   // Compute the knowledge based on the inferred type.
119   auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
120   inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType();
121   inferredKnowledge.hasRank = predictedShape.hasRank();
122   if (predictedShape.hasRank()) {
123     for (auto dim : predictedShape.getDims()) {
124       inferredKnowledge.sizes.push_back(dim);
125     }
126   }
127 
128   // Compute the new type based on the joined version.
129   auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
130   Type newTy =
131       newKnowledge.hasRank
132           ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
133                                              newKnowledge.dtype)}
134           : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
135   result.setType(newTy);
136   return op;
137 }
138 
139 } // namespace
140 
141 // Creates a TOSA operation by:
142 //   - first equalize ranks for ops with SameOperandsAndResultRank trait
143 //   - create operator
144 //   - performs shape inference on this operator
145 template <typename TosaOp, typename... Args>
146 TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
147                              Args &&...args) {
148   if (TosaOp::template hasTrait<OpTrait::SameOperandsAndResultRank>()) {
149     // op requires same ranks for tensor operands
150     if constexpr (sizeof...(Args) == 2) {
151       auto argX = std::get<0>(std::tie(args...));
152       auto argY = std::get<1>(std::tie(args...));
153       using ArgX = decltype(argX);
154       using ArgY = decltype(argY);
155       if constexpr (std::is_same_v<ArgX, Value> &&
156                     std::is_same_v<ArgY, Value>) {
157         Value x = std::get<0>(std::tie(args...));
158         Value y = std::get<1>(std::tie(args...));
159         if (EqualizeRanks(builder, x, y).failed()) {
160           // incompatible broadcast shapes, no reshape is inserted
161           // ResultsBroadcastableShape verify will handle this
162         }
163         return createOpAndInferShape<TosaOp>(builder, resultTy, x, y);
164       }
165     }
166     if constexpr (sizeof...(Args) == 3) {
167       auto argX = std::get<0>(std::tie(args...));
168       auto argY = std::get<1>(std::tie(args...));
169       auto argZ = std::get<2>(std::tie(args...));
170       using ArgX = decltype(argX);
171       using ArgY = decltype(argY);
172       using ArgZ = decltype(argZ);
173       if constexpr (std::is_same_v<ArgX, Value> &&
174                     std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
175         // special case for ArithmeticRightShiftOp
176         Value x = std::get<0>(std::tie(args...));
177         Value y = std::get<1>(std::tie(args...));
178         bool round = std::get<2>(std::tie(args...));
179         if (EqualizeRanks(builder, x, y).failed()) {
180           // incompatible broadcast shapes, no reshape is inserted
181           // ResultsBroadcastableShape verify will handle this
182         }
183         return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, round);
184       }
185       if constexpr (std::is_same_v<ArgX, Value> &&
186                     std::is_same_v<ArgY, Value> &&
187                     std::is_same_v<ArgZ, Value>) {
188         // special case for Select
189         Value x = std::get<0>(std::tie(args...));
190         Value y = std::get<1>(std::tie(args...));
191         Value z = std::get<2>(std::tie(args...));
192 
193         if (EqualizeRanks(builder, x, y).failed() ||
194             EqualizeRanks(builder, x, z).failed() ||
195             EqualizeRanks(builder, y, z).failed()) {
196           // incompatible broadcast shapes, no reshape is inserted
197           // ResultsBroadcastableShape verify will handle this
198         }
199 
200         return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, z);
201       }
202     }
203   }
204 
205   return createOpAndInferShape<TosaOp>(builder, resultTy, args...);
206 }
207 
208 // Creates a TOSA operation by:
209 //   - first equalize ranks for ops with SameOperandsAndResultRank trait
210 //   - create operator
211 //   - performs shape inference on this operator
212 template <typename TosaOp, typename... Args>
213 TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
214                              Type resultTy, Args &&...args) {
215   ImplicitLocOpBuilder builder(loc, rewriter);
216   return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
217 }
218 
219 // Apply an int32_t permutation to some input, that should be of the same
220 // size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
221 template <typename T>
222 SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
223                                     ArrayRef<int32_t> perms) {
224   SmallVector<T> permuted;
225   size_t N = input.size();
226   permuted.resize_for_overwrite(N);
227   for (size_t i = 0; i < N; i++)
228     permuted[i] = input[perms[i]];
229   return permuted;
230 }
231 
232 // Computes shape value using tosa const_shape op.
233 Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
234                         llvm::ArrayRef<int64_t> shape);
235 SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
236 
237 bool getConstShapeValue(Operation *op,
238                         llvm::SmallVector<int64_t> &result_shape);
239 
240 } // namespace tosa
241 } // namespace mlir
242 
243 #endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
244