1 //===- ConversionUtils.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utility functions for TOSA lowering 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 14 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15 16 using namespace mlir; 17 using namespace mlir::tosa; 18 19 SmallVector<utils::IteratorType> 20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { 21 return SmallVector<utils::IteratorType>(nParallelLoops, 22 utils::IteratorType::parallel); 23 } 24 25 SmallVector<Value> 26 mlir::tosa::condenseValues(const SmallVector<Value> &values) { 27 SmallVector<Value> condensedValues; 28 for (auto value : values) 29 if (value) 30 condensedValues.push_back(value); 31 return condensedValues; 32 } 33 34 Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, 35 Value max, OpBuilder &rewriter) { 36 Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max); 37 return rewriter.create<arith::MaximumFOp>(loc, minValue, min); 38 } 39 40 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, 41 OpBuilder &rewriter) { 42 auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg); 43 return rewriter.create<arith::MinSIOp>(loc, max, minOrArg); 44 } 45 46 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) { 47 uint64_t bitwidth = ty.getIntOrFloatBitWidth(); 48 if (ty.getSignedness() == IntegerType::Unsigned) { 49 uint64_t uvalue = value; 50 APInt intMin = APInt::getMinValue(bitwidth); 51 APInt intMax = APInt::getMaxValue(bitwidth); 52 return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue(); 53 } 54 55 APInt intMin = APInt::getSignedMinValue(bitwidth); 56 APInt intMax = APInt::getSignedMaxValue(bitwidth); 57 return value >= intMin.getSExtValue() && value <= intMax.getSExtValue(); 58 } 59 60 namespace { 61 // Given two tensors of high and low ranks, derive the output shape 62 // to reshape the lower rank to. 63 // Examples: 64 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. 65 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. 66 // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. 67 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. 68 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. 69 LogicalResult 70 computeReshapeOutput(ArrayRef<int64_t> higherRankShape, 71 ArrayRef<int64_t> lowerRankShape, 72 SmallVectorImpl<int64_t> &reshapeOutputShape) { 73 // Initialize new shapes with [1] * higherRank. 74 int64_t higherRank = higherRankShape.size(); 75 int64_t lowerRank = lowerRankShape.size(); 76 77 reshapeOutputShape.assign(higherRank, 1); 78 79 int64_t higherRankDim; 80 int64_t lowerRankDim; 81 82 for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; 83 i--, j--) { 84 higherRankDim = higherRankShape[i]; 85 lowerRankDim = lowerRankShape[j]; 86 87 if (lowerRankDim == 1 && higherRankDim > 1) 88 reshapeOutputShape[i] = 1; 89 else if ((lowerRankDim > 1 && higherRankDim == 1) || 90 (lowerRankDim == higherRankDim)) 91 reshapeOutputShape[i] = lowerRankDim; 92 else if (higherRankDim != lowerRankDim) 93 return failure(); 94 } 95 return success(); 96 } 97 } // namespace 98 99 LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc, 100 Value &input1, Value &input2) { 101 auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType()); 102 auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType()); 103 104 if (!input1Ty || !input2Ty) { 105 return failure(); 106 } 107 108 int64_t input1Rank = input1Ty.getRank(); 109 int64_t input2Rank = input2Ty.getRank(); 110 111 if (input1Rank == input2Rank) 112 return success(); 113 114 Value higherTensorValue, lowerTensorValue; 115 if (input1Rank > input2Rank) { 116 higherTensorValue = input1; 117 lowerTensorValue = input2; 118 } else { 119 higherTensorValue = input2; 120 lowerTensorValue = input1; 121 } 122 123 ArrayRef<int64_t> higherRankShape = 124 llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape(); 125 ArrayRef<int64_t> lowerRankShape = 126 llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape(); 127 128 SmallVector<int64_t, 4> reshapeOutputShape; 129 130 if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) 131 .failed()) 132 return failure(); 133 134 auto reshapeInputType = 135 llvm::cast<RankedTensorType>(lowerTensorValue.getType()); 136 auto reshapeOutputType = RankedTensorType::get( 137 ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType()); 138 139 auto reshapeLower = rewriter.create<tosa::ReshapeOp>( 140 loc, reshapeOutputType, lowerTensorValue, 141 rewriter.getDenseI64ArrayAttr(reshapeOutputShape)); 142 143 if (input1Rank > input2Rank) { 144 input1 = higherTensorValue; 145 input2 = reshapeLower.getResult(); 146 } else { 147 input1 = reshapeLower.getResult(); 148 input2 = higherTensorValue; 149 } 150 151 return success(); 152 } 153