//===- ConversionUtils.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility functions for TOSA lowering
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

using namespace mlir;
using namespace mlir::tosa;

SmallVector<utils::IteratorType>
mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<utils::IteratorType>(nParallelLoops,
                                          utils::IteratorType::parallel);
}

SmallVector<Value>
mlir::tosa::condenseValues(const SmallVector<Value> &values) {
  SmallVector<Value> condensedValues;
  for (auto value : values)
    if (value)
      condensedValues.push_back(value);
  return condensedValues;
}

// Clamp a floating-point value to the closed range [min, max].
Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
                                   Value max, OpBuilder &rewriter) {
  Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
  return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
}

// Clamp an integer value to the closed range [min, max], using unsigned or
// signed comparisons depending on `isUnsigned`.
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
                                 OpBuilder &rewriter, bool isUnsigned) {
  if (isUnsigned) {
    auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
    return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
  }
  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}

// Return true if `value` is representable in the given integer type,
// honoring its signedness.
bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
  if (ty.getSignedness() == IntegerType::Unsigned) {
    uint64_t uvalue = value;
    APInt intMin = APInt::getMinValue(bitwidth);
    APInt intMax = APInt::getMaxValue(bitwidth);
    return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
  }

  APInt intMin = APInt::getSignedMinValue(bitwidth);
  APInt intMax = APInt::getSignedMaxValue(bitwidth);
  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
}

namespace {
// Given two tensors of high and low ranks, derive the output shape
// to reshape the lower-rank tensor to.
// Examples:
// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
// If lower=[a], higher=[a, b, a], [a] reshaped into [1, 1, a].
// If lower=[], higher=[a, b, c], [] reshaped into [1, 1, 1].
LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                     ArrayRef<int64_t> lowerRankShape,
                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize new shapes with [1] * higherRank.
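  // Dimensions are matched right-to-left, in the style of numpy broadcasting:
  // the trailing dimension of the lower-rank shape lines up with the trailing
  // dimension of the higher-rank shape, and any leading output dimensions
  // without a lower-rank counterpart stay 1.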
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();

  reshapeOutputShape.assign(higherRank, 1);

  int64_t higherRankDim;
  int64_t lowerRankDim;

  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
       i--, j--) {
    higherRankDim = higherRankShape[i];
    lowerRankDim = lowerRankShape[j];

    if (lowerRankDim == 1 && higherRankDim > 1)
      reshapeOutputShape[i] = 1;
    else if ((lowerRankDim > 1 && higherRankDim == 1) ||
             (lowerRankDim == higherRankDim))
      reshapeOutputShape[i] = lowerRankDim;
    else if (higherRankDim != lowerRankDim)
      return failure();
  }
  return success();
}
} // namespace

LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
                                        Value &input1, Value &input2) {
  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());

  // Rank equalization is only defined for ranked tensor operands.
  if (!input1Ty || !input2Ty) {
    return failure();
  }

  int64_t input1Rank = input1Ty.getRank();
  int64_t input2Rank = input2Ty.getRank();

  if (input1Rank == input2Rank)
    return success();

  Value higherTensorValue, lowerTensorValue;
  if (input1Rank > input2Rank) {
    higherTensorValue = input1;
    lowerTensorValue = input2;
  } else {
    higherTensorValue = input2;
    lowerTensorValue = input1;
  }

  ArrayRef<int64_t> higherRankShape =
      llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
  ArrayRef<int64_t> lowerRankShape =
      llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();

  SmallVector<int64_t, 4> reshapeOutputShape;

  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
          .failed())
    return failure();

  auto reshapeInputType =
      llvm::cast<RankedTensorType>(lowerTensorValue.getType());
  auto reshapeOutputType = RankedTensorType::get(
      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());

  // Reshape the lower-rank operand to the broadcast-compatible shape computed
  // above, then hand the operands back in their original order.
  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
      loc, reshapeOutputType, lowerTensorValue,
      rewriter.getDenseI64ArrayAttr(reshapeOutputShape));

  if (input1Rank > input2Rank) {
    input1 = higherTensorValue;
    input2 = reshapeLower.getResult();
  } else {
    input1 = reshapeLower.getResult();
    input2 = higherTensorValue;
  }

  return success();
}
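
// The sketch below is illustrative only and not part of the upstream file: it
// shows one way EqualizeRanks might be used inside a rewrite pattern to
// broadcast the operands of a binary TOSA op before further lowering. The
// pattern name `ExampleEqualizeAddRanks` is hypothetical; such a pattern would
// be registered in a RewritePatternSet like any other, e.g.
// `patterns.add<ExampleEqualizeAddRanks>(ctx)`.
namespace {
struct ExampleEqualizeAddRanks : public OpRewritePattern<tosa::AddOp> {
  using OpRewritePattern<tosa::AddOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AddOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getInput1();
    Value rhs = op.getInput2();

    // Only fire when the operand ranks actually differ; otherwise the pattern
    // would keep recreating an identical op.
    auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
    auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
    if (!lhsTy || !rhsTy || lhsTy.getRank() == rhsTy.getRank())
      return failure();

    // Reshape the lower-ranked operand so both operands have the same rank.
    if (failed(EqualizeRanks(rewriter, op.getLoc(), lhs, rhs)))
      return failure();

    // Recreate the op with rank-equalized operands.
    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), lhs, rhs);
    return success();
  }
};
} // namespace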