1310e9636Snatashaknk //===- ConversionUtils.cpp ------------------------------------------------===// 2310e9636Snatashaknk // 3310e9636Snatashaknk // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4310e9636Snatashaknk // See https://llvm.org/LICENSE.txt for license information. 5310e9636Snatashaknk // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6310e9636Snatashaknk // 7310e9636Snatashaknk //===----------------------------------------------------------------------===// 8310e9636Snatashaknk // 9310e9636Snatashaknk // Utility functions for TOSA lowering 10310e9636Snatashaknk // 11310e9636Snatashaknk //===----------------------------------------------------------------------===// 12310e9636Snatashaknk 1378503e1aSRob Suderman #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 14e0537d1aSTai Ly #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15310e9636Snatashaknk 16310e9636Snatashaknk using namespace mlir; 17310e9636Snatashaknk using namespace mlir::tosa; 18310e9636Snatashaknk 19e6598b05SOleg Shyshkov SmallVector<utils::IteratorType> 20310e9636Snatashaknk mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { 21e6598b05SOleg Shyshkov return SmallVector<utils::IteratorType>(nParallelLoops, 22e6598b05SOleg Shyshkov utils::IteratorType::parallel); 23310e9636Snatashaknk } 24310e9636Snatashaknk 25310e9636Snatashaknk SmallVector<Value> 26310e9636Snatashaknk mlir::tosa::condenseValues(const SmallVector<Value> &values) { 27310e9636Snatashaknk SmallVector<Value> condensedValues; 28310e9636Snatashaknk for (auto value : values) 29310e9636Snatashaknk if (value) 30310e9636Snatashaknk condensedValues.push_back(value); 31310e9636Snatashaknk return condensedValues; 32310e9636Snatashaknk } 332eb50ceeSThomas Raoux 3478503e1aSRob Suderman Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, 3578503e1aSRob Suderman Value max, OpBuilder &rewriter) { 368a6e54c9SDaniil Dudkin Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max); 378a6e54c9SDaniil Dudkin return rewriter.create<arith::MaximumFOp>(loc, minValue, min); 382eb50ceeSThomas Raoux } 392eb50ceeSThomas Raoux 4078503e1aSRob Suderman Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, 418d237190SMatthias Gehre OpBuilder &rewriter, bool isUnsigned) { 428d237190SMatthias Gehre if (isUnsigned) { 438d237190SMatthias Gehre auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg); 448d237190SMatthias Gehre return rewriter.create<arith::MinUIOp>(loc, max, minOrArg); 458d237190SMatthias Gehre } 46d4fd2025Smlevesquedion auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg); 47d4fd2025Smlevesquedion return rewriter.create<arith::MinSIOp>(loc, max, minOrArg); 482eb50ceeSThomas Raoux } 4969c984b6SRob Suderman 5069c984b6SRob Suderman bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) { 5169c984b6SRob Suderman uint64_t bitwidth = ty.getIntOrFloatBitWidth(); 5269c984b6SRob Suderman if (ty.getSignedness() == IntegerType::Unsigned) { 5369c984b6SRob Suderman uint64_t uvalue = value; 5469c984b6SRob Suderman APInt intMin = APInt::getMinValue(bitwidth); 5569c984b6SRob Suderman APInt intMax = APInt::getMaxValue(bitwidth); 5669c984b6SRob Suderman return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue(); 5769c984b6SRob Suderman } 5869c984b6SRob Suderman 5969c984b6SRob Suderman APInt intMin = APInt::getSignedMinValue(bitwidth); 6069c984b6SRob Suderman APInt intMax = APInt::getSignedMaxValue(bitwidth); 6169c984b6SRob Suderman return value >= intMin.getSExtValue() && value <= intMax.getSExtValue(); 6269c984b6SRob Suderman } 63e0537d1aSTai Ly 64e0537d1aSTai Ly namespace { 65e0537d1aSTai Ly // Given two tensors of high and low ranks, derive the output shape 66e0537d1aSTai Ly // to reshape the lower rank to. 67e0537d1aSTai Ly // Examples: 68e0537d1aSTai Ly // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. 69e0537d1aSTai Ly // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. 70e0537d1aSTai Ly // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. 71e0537d1aSTai Ly // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. 72e0537d1aSTai Ly // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. 73e0537d1aSTai Ly LogicalResult 74e0537d1aSTai Ly computeReshapeOutput(ArrayRef<int64_t> higherRankShape, 75e0537d1aSTai Ly ArrayRef<int64_t> lowerRankShape, 76e0537d1aSTai Ly SmallVectorImpl<int64_t> &reshapeOutputShape) { 77e0537d1aSTai Ly // Initialize new shapes with [1] * higherRank. 78e0537d1aSTai Ly int64_t higherRank = higherRankShape.size(); 79e0537d1aSTai Ly int64_t lowerRank = lowerRankShape.size(); 80e0537d1aSTai Ly 81e0537d1aSTai Ly reshapeOutputShape.assign(higherRank, 1); 82e0537d1aSTai Ly 83e0537d1aSTai Ly int64_t higherRankDim; 84e0537d1aSTai Ly int64_t lowerRankDim; 85e0537d1aSTai Ly 86e0537d1aSTai Ly for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; 87e0537d1aSTai Ly i--, j--) { 88e0537d1aSTai Ly higherRankDim = higherRankShape[i]; 89e0537d1aSTai Ly lowerRankDim = lowerRankShape[j]; 90e0537d1aSTai Ly 91e0537d1aSTai Ly if (lowerRankDim == 1 && higherRankDim > 1) 92e0537d1aSTai Ly reshapeOutputShape[i] = 1; 93e0537d1aSTai Ly else if ((lowerRankDim > 1 && higherRankDim == 1) || 94e0537d1aSTai Ly (lowerRankDim == higherRankDim)) 95e0537d1aSTai Ly reshapeOutputShape[i] = lowerRankDim; 96e0537d1aSTai Ly else if (higherRankDim != lowerRankDim) 97e0537d1aSTai Ly return failure(); 98e0537d1aSTai Ly } 99e0537d1aSTai Ly return success(); 100e0537d1aSTai Ly } 101e0537d1aSTai Ly } // namespace 102e0537d1aSTai Ly 103e0537d1aSTai Ly LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc, 104e0537d1aSTai Ly Value &input1, Value &input2) { 105c8834527STai Ly ImplicitLocOpBuilder builder(loc, rewriter); 106c8834527STai Ly return EqualizeRanks(builder, input1, input2); 107c8834527STai Ly } 108c8834527STai Ly 109c8834527STai Ly LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder, 110c8834527STai Ly Value &input1, Value &input2) { 11168f58812STres Popp auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType()); 11268f58812STres Popp auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType()); 113e0537d1aSTai Ly 114e0537d1aSTai Ly if (!input1Ty || !input2Ty) { 115e0537d1aSTai Ly return failure(); 116e0537d1aSTai Ly } 117e0537d1aSTai Ly 118e0537d1aSTai Ly int64_t input1Rank = input1Ty.getRank(); 119e0537d1aSTai Ly int64_t input2Rank = input2Ty.getRank(); 120e0537d1aSTai Ly 121e0537d1aSTai Ly if (input1Rank == input2Rank) 122e0537d1aSTai Ly return success(); 123e0537d1aSTai Ly 124e0537d1aSTai Ly Value higherTensorValue, lowerTensorValue; 125e0537d1aSTai Ly if (input1Rank > input2Rank) { 126e0537d1aSTai Ly higherTensorValue = input1; 127e0537d1aSTai Ly lowerTensorValue = input2; 128e0537d1aSTai Ly } else { 129e0537d1aSTai Ly higherTensorValue = input2; 130e0537d1aSTai Ly lowerTensorValue = input1; 131e0537d1aSTai Ly } 132e0537d1aSTai Ly 133e0537d1aSTai Ly ArrayRef<int64_t> higherRankShape = 13468f58812STres Popp llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape(); 135e0537d1aSTai Ly ArrayRef<int64_t> lowerRankShape = 13668f58812STres Popp llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape(); 137e0537d1aSTai Ly 138e0537d1aSTai Ly SmallVector<int64_t, 4> reshapeOutputShape; 139e0537d1aSTai Ly 140e0537d1aSTai Ly if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) 141e0537d1aSTai Ly .failed()) 142e0537d1aSTai Ly return failure(); 143e0537d1aSTai Ly 14468f58812STres Popp auto reshapeInputType = 14568f58812STres Popp llvm::cast<RankedTensorType>(lowerTensorValue.getType()); 146e0537d1aSTai Ly auto reshapeOutputType = RankedTensorType::get( 147e0537d1aSTai Ly ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType()); 148e0537d1aSTai Ly 149c8834527STai Ly auto reshapeLower = builder.create<tosa::ReshapeOp>( 150c8834527STai Ly reshapeOutputType, lowerTensorValue, 151c8834527STai Ly builder.getDenseI64ArrayAttr(reshapeOutputShape)); 152e0537d1aSTai Ly 153e0537d1aSTai Ly if (input1Rank > input2Rank) { 154e0537d1aSTai Ly input1 = higherTensorValue; 155e0537d1aSTai Ly input2 = reshapeLower.getResult(); 156e0537d1aSTai Ly } else { 157e0537d1aSTai Ly input1 = reshapeLower.getResult(); 158e0537d1aSTai Ly input2 = higherTensorValue; 159e0537d1aSTai Ly } 160e0537d1aSTai Ly 161e0537d1aSTai Ly return success(); 162e0537d1aSTai Ly } 163*7e622b61SJerry-Ge 164*7e622b61SJerry-Ge Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc, 165*7e622b61SJerry-Ge llvm::ArrayRef<int64_t> shape) { 166*7e622b61SJerry-Ge auto attr = rewriter.getIndexTensorAttr(shape); 167*7e622b61SJerry-Ge auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size()); 168*7e622b61SJerry-Ge mlir::Operation *mlir_op = 169*7e622b61SJerry-Ge rewriter.create<tosa::ConstShapeOp>(loc, type, attr); 170*7e622b61SJerry-Ge return mlir_op->getResult(0); 171*7e622b61SJerry-Ge } 172*7e622b61SJerry-Ge 173*7e622b61SJerry-Ge SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) { 174*7e622b61SJerry-Ge return to_vector(llvm::map_range(shape, [](int64_t dim) { 175*7e622b61SJerry-Ge return ShapedType::isDynamic(dim) ? -1 : dim; 176*7e622b61SJerry-Ge })); 177*7e622b61SJerry-Ge } 178*7e622b61SJerry-Ge 179*7e622b61SJerry-Ge bool mlir::tosa::getConstShapeValue(Operation *op, 180*7e622b61SJerry-Ge llvm::SmallVector<int64_t> &result_shape) { 181*7e622b61SJerry-Ge if (!op) { 182*7e622b61SJerry-Ge return false; 183*7e622b61SJerry-Ge } 184*7e622b61SJerry-Ge if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) { 185*7e622b61SJerry-Ge Attribute constOpAttr = constOp->getAttr("value"); 186*7e622b61SJerry-Ge DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr); 187*7e622b61SJerry-Ge for (int i = 0; i < elementsAttr.size(); i++) { 188*7e622b61SJerry-Ge int64_t val = elementsAttr.getValues<int64_t>()[i]; 189*7e622b61SJerry-Ge result_shape.push_back(val); 190*7e622b61SJerry-Ge } 191*7e622b61SJerry-Ge return true; 192*7e622b61SJerry-Ge } 193*7e622b61SJerry-Ge // for undefined op, return false. 194*7e622b61SJerry-Ge return false; 195*7e622b61SJerry-Ge } 196