1 //===- ConversionUtils.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utility functions for TOSA lowering 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 14 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15 16 using namespace mlir; 17 using namespace mlir::tosa; 18 19 SmallVector<utils::IteratorType> 20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { 21 return SmallVector<utils::IteratorType>(nParallelLoops, 22 utils::IteratorType::parallel); 23 } 24 25 SmallVector<Value> 26 mlir::tosa::condenseValues(const SmallVector<Value> &values) { 27 SmallVector<Value> condensedValues; 28 for (auto value : values) 29 if (value) 30 condensedValues.push_back(value); 31 return condensedValues; 32 } 33 34 Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, 35 Value max, OpBuilder &rewriter) { 36 Value minValue = rewriter.create<arith::MinFOp>(loc, arg, max); 37 return rewriter.create<arith::MaxFOp>(loc, minValue, min); 38 } 39 40 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, 41 OpBuilder &rewriter) { 42 auto smallerThanMin = 43 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min); 44 auto minOrArg = 45 rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg); 46 auto largerThanMax = 47 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg); 48 return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg); 49 } 50 51 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) { 52 uint64_t bitwidth = ty.getIntOrFloatBitWidth(); 53 if (ty.getSignedness() == IntegerType::Unsigned) { 54 uint64_t uvalue = value; 55 APInt intMin = APInt::getMinValue(bitwidth); 56 APInt intMax = APInt::getMaxValue(bitwidth); 57 return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue(); 58 } 59 60 APInt intMin = APInt::getSignedMinValue(bitwidth); 61 APInt intMax = APInt::getSignedMaxValue(bitwidth); 62 return value >= intMin.getSExtValue() && value <= intMax.getSExtValue(); 63 } 64 65 namespace { 66 // Given two tensors of high and low ranks, derive the output shape 67 // to reshape the lower rank to. 68 // Examples: 69 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. 70 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. 71 // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. 72 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. 73 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. 74 LogicalResult 75 computeReshapeOutput(ArrayRef<int64_t> higherRankShape, 76 ArrayRef<int64_t> lowerRankShape, 77 SmallVectorImpl<int64_t> &reshapeOutputShape) { 78 // Initialize new shapes with [1] * higherRank. 79 int64_t higherRank = higherRankShape.size(); 80 int64_t lowerRank = lowerRankShape.size(); 81 82 reshapeOutputShape.assign(higherRank, 1); 83 84 int64_t higherRankDim; 85 int64_t lowerRankDim; 86 87 for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; 88 i--, j--) { 89 higherRankDim = higherRankShape[i]; 90 lowerRankDim = lowerRankShape[j]; 91 92 if (lowerRankDim == 1 && higherRankDim > 1) 93 reshapeOutputShape[i] = 1; 94 else if ((lowerRankDim > 1 && higherRankDim == 1) || 95 (lowerRankDim == higherRankDim)) 96 reshapeOutputShape[i] = lowerRankDim; 97 else if (higherRankDim != lowerRankDim) 98 return failure(); 99 } 100 return success(); 101 } 102 } // namespace 103 104 LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc, 105 Value &input1, Value &input2) { 106 auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType()); 107 auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType()); 108 109 if (!input1Ty || !input2Ty) { 110 return failure(); 111 } 112 113 int64_t input1Rank = input1Ty.getRank(); 114 int64_t input2Rank = input2Ty.getRank(); 115 116 if (input1Rank == input2Rank) 117 return success(); 118 119 Value higherTensorValue, lowerTensorValue; 120 if (input1Rank > input2Rank) { 121 higherTensorValue = input1; 122 lowerTensorValue = input2; 123 } else { 124 higherTensorValue = input2; 125 lowerTensorValue = input1; 126 } 127 128 ArrayRef<int64_t> higherRankShape = 129 llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape(); 130 ArrayRef<int64_t> lowerRankShape = 131 llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape(); 132 133 SmallVector<int64_t, 4> reshapeOutputShape; 134 135 if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) 136 .failed()) 137 return failure(); 138 139 auto reshapeInputType = 140 llvm::cast<RankedTensorType>(lowerTensorValue.getType()); 141 auto reshapeOutputType = RankedTensorType::get( 142 ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType()); 143 144 auto reshapeLower = rewriter.create<tosa::ReshapeOp>( 145 loc, reshapeOutputType, lowerTensorValue, 146 rewriter.getDenseI64ArrayAttr(reshapeOutputShape)); 147 148 if (input1Rank > input2Rank) { 149 input1 = higherTensorValue; 150 input2 = reshapeLower.getResult(); 151 } else { 152 input1 = reshapeLower.getResult(); 153 input2 = higherTensorValue; 154 } 155 156 return success(); 157 } 158