1 //===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utility functions for TOSA lowering 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_ 14 #define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_ 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Tensor/IR/Tensor.h" 18 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/IR/ImplicitLocOpBuilder.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include <optional> 23 24 namespace mlir { 25 namespace tosa { 26 27 // Creates a SmallVector of Stringrefs for N parallel loops 28 SmallVector<utils::IteratorType> 29 getNParallelLoopsAttrs(unsigned nParallelLoops); 30 31 // Takes a vector of values and condenses them to a vector with no gaps. 32 SmallVector<Value> condenseValues(const SmallVector<Value> &values); 33 34 // Takes the parameters for a clamp and turns it into a series of ops for float 35 // inputs. 36 Value clampFloatHelper(Location loc, Value arg, Value min, Value max, 37 OpBuilder &rewriter); 38 39 // Takes the parameters for a clamp and turns it into a series of ops for 40 // integer inputs. 41 Value clampIntHelper(Location loc, Value arg, Value min, Value max, 42 OpBuilder &rewriter, bool isUnsigned); 43 44 // Determines whether the integer value falls witin the range of integer type. 45 bool validIntegerRange(IntegerType ty, int64_t value); 46 47 // Checks for a dynamic batch dim in any of the passed parameters of an op. 48 // The batch dimention must be #0 and the rest of the dimensions must be static. 49 template <typename Op> 50 std::optional<SmallVector<Value>> 51 checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, 52 ArrayRef<Value> params) { 53 SmallVector<ShapedType> dynTypes; 54 SmallVector<Value> dynamicDims; 55 for (const Value ¶m : params) { 56 auto paramTy = cast<ShapedType>(param.getType()); 57 if (!paramTy.hasStaticShape()) 58 dynTypes.push_back(paramTy); 59 } 60 61 if (dynTypes.empty()) 62 return dynamicDims; 63 64 for (const ShapedType &dynTy : dynTypes) { 65 if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) { 66 (void)rewriter.notifyMatchFailure( 67 op, "input can only be dynamic for batch size"); 68 return std::nullopt; 69 } 70 } 71 72 dynamicDims.push_back( 73 rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0)); 74 return dynamicDims; 75 } 76 77 /// Common code to create the reshape op where necessary to make the rank of two 78 /// values equal. input1 and input2 will be updated when the rank has 79 /// changed. The caller is expected to use these to rewrite the original 80 /// operator with the RESHAPE now in the graph. 81 LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, 82 Value &input1, Value &input2); 83 84 LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1, 85 Value &input2); 86 87 namespace { 88 89 // Creates a TOSA operation and performs shape inference on the individual 90 // op. This allows shape inference when lowering down to TOSA. 91 template <typename TosaOp, typename... Args> 92 TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy, 93 Args &&...args) { 94 auto op = builder.create<TosaOp>(resultTy, args...); 95 96 InferShapedTypeOpInterface shapeInterface = 97 dyn_cast<InferShapedTypeOpInterface>(op.getOperation()); 98 if (!shapeInterface) 99 return op; 100 101 SmallVector<ShapedTypeComponents> returnedShapes; 102 if (shapeInterface 103 .inferReturnTypeComponents(op.getContext(), builder.getLoc(), 104 op->getOperands(), op->getAttrDictionary(), 105 op->getPropertiesStorage(), 106 op->getRegions(), returnedShapes) 107 .failed()) 108 return op; 109 110 // We need to use the element type of the existing result type to generate 111 // the new result shaped type. This is because rescale can include a cast to 112 // different bit-width types and does not have a TypeAttr to define the 113 // target type. 114 auto result = op->getResult(0); 115 auto predictedShape = returnedShapes[0]; 116 auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy); 117 118 // Compute the knowledge based on the inferred type. 119 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); 120 inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType(); 121 inferredKnowledge.hasRank = predictedShape.hasRank(); 122 if (predictedShape.hasRank()) { 123 for (auto dim : predictedShape.getDims()) { 124 inferredKnowledge.sizes.push_back(dim); 125 } 126 } 127 128 // Compute the new type based on the joined version. 129 auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge); 130 Type newTy = 131 newKnowledge.hasRank 132 ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes), 133 newKnowledge.dtype)} 134 : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)}; 135 result.setType(newTy); 136 return op; 137 } 138 139 } // namespace 140 141 // Creates a TOSA operation by: 142 // - first equalize ranks for ops with SameOperandsAndResultRank trait 143 // - create operator 144 // - performs shape inference on this operator 145 template <typename TosaOp, typename... Args> 146 TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy, 147 Args &&...args) { 148 if (TosaOp::template hasTrait<OpTrait::SameOperandsAndResultRank>()) { 149 // op requires same ranks for tensor operands 150 if constexpr (sizeof...(Args) == 2) { 151 auto argX = std::get<0>(std::tie(args...)); 152 auto argY = std::get<1>(std::tie(args...)); 153 using ArgX = decltype(argX); 154 using ArgY = decltype(argY); 155 if constexpr (std::is_same_v<ArgX, Value> && 156 std::is_same_v<ArgY, Value>) { 157 Value x = std::get<0>(std::tie(args...)); 158 Value y = std::get<1>(std::tie(args...)); 159 if (EqualizeRanks(builder, x, y).failed()) { 160 // incompatible broadcast shapes, no reshape is inserted 161 // ResultsBroadcastableShape verify will handle this 162 } 163 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y); 164 } 165 } 166 if constexpr (sizeof...(Args) == 3) { 167 auto argX = std::get<0>(std::tie(args...)); 168 auto argY = std::get<1>(std::tie(args...)); 169 auto argZ = std::get<2>(std::tie(args...)); 170 using ArgX = decltype(argX); 171 using ArgY = decltype(argY); 172 using ArgZ = decltype(argZ); 173 if constexpr (std::is_same_v<ArgX, Value> && 174 std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) { 175 // special case for ArithmeticRightShiftOp 176 Value x = std::get<0>(std::tie(args...)); 177 Value y = std::get<1>(std::tie(args...)); 178 bool round = std::get<2>(std::tie(args...)); 179 if (EqualizeRanks(builder, x, y).failed()) { 180 // incompatible broadcast shapes, no reshape is inserted 181 // ResultsBroadcastableShape verify will handle this 182 } 183 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, round); 184 } 185 if constexpr (std::is_same_v<ArgX, Value> && 186 std::is_same_v<ArgY, Value> && 187 std::is_same_v<ArgZ, Value>) { 188 // special case for Select 189 Value x = std::get<0>(std::tie(args...)); 190 Value y = std::get<1>(std::tie(args...)); 191 Value z = std::get<2>(std::tie(args...)); 192 193 if (EqualizeRanks(builder, x, y).failed() || 194 EqualizeRanks(builder, x, z).failed() || 195 EqualizeRanks(builder, y, z).failed()) { 196 // incompatible broadcast shapes, no reshape is inserted 197 // ResultsBroadcastableShape verify will handle this 198 } 199 200 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, z); 201 } 202 } 203 } 204 205 return createOpAndInferShape<TosaOp>(builder, resultTy, args...); 206 } 207 208 // Creates a TOSA operation by: 209 // - first equalize ranks for ops with SameOperandsAndResultRank trait 210 // - create operator 211 // - performs shape inference on this operator 212 template <typename TosaOp, typename... Args> 213 TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc, 214 Type resultTy, Args &&...args) { 215 ImplicitLocOpBuilder builder(loc, rewriter); 216 return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...); 217 } 218 219 // Apply an int32_t permutation to some input, that should be of the same 220 // size as perms. Perms should contain some permutation of 0 - perms.size() - 1. 221 template <typename T> 222 SmallVector<T> applyTOSAPermutation(ArrayRef<T> input, 223 ArrayRef<int32_t> perms) { 224 SmallVector<T> permuted; 225 size_t N = input.size(); 226 permuted.resize_for_overwrite(N); 227 for (size_t i = 0; i < N; i++) 228 permuted[i] = input[perms[i]]; 229 return permuted; 230 } 231 232 // Computes shape value using tosa const_shape op. 233 Value getTosaConstShape(PatternRewriter &rewriter, Location loc, 234 llvm::ArrayRef<int64_t> shape); 235 SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape); 236 237 bool getConstShapeValue(Operation *op, 238 llvm::SmallVector<int64_t> &result_shape); 239 240 } // namespace tosa 241 } // namespace mlir 242 243 #endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_ 244