//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.getInput1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}

struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int32_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
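    // The new permutation composes the two: element i of the result reads
    // from innerTransposePerms[transposePerms[i]]. For example (illustrative),
    // an inner transpose with perms [1, 2, 0] followed by an outer transpose
    // with perms [2, 0, 1] composes to [0, 1, 2], i.e. the identity.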
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
  }
};

// Determines the case when tosa.transpose is a tosa.reshape operation.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;
    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) &&
               !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
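    // Materialize the implicit pad value as an explicit rank-0 constant so
    // every tosa.pad carries a pad_const operand: 0 for plain int/float
    // inputs, the input zero-point for quantized inputs.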
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}

struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If the output and input shapes are 1x1, then this is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}

struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
// ---------------------------------------------
// | OUTER CLAMP | INNER CLAMP | RESULT MODE   |
// |-------------|-------------|---------------|
// | PROPAGATE   | PROPAGATE   | PROPAGATE     |
// | PROPAGATE   | IGNORE      | IGNORE        |
// | IGNORE      | PROPAGATE   | INVALID       |
// | IGNORE      | IGNORE      | IGNORE        |
// ---------------------------------------------
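//
// For example (illustrative): clamp(clamp(x, min=0, max=6), min=2, max=10)
// canonicalizes to clamp(x, min=2, max=6), since the composed bounds are
// [max(0, 2), min(6, 10)].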
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  // Helper structure to describe the range of a clamp operation.
  template <typename T>
  struct ClampRange {
    ClampRange(const T &start, const T &end) : start(start), end(end) {}
    T start;
    T end;

    // Helper function to determine if two Clamp ranges intersect.
    bool intersects(const ClampRange<T> &otherRange) {
      return start < otherRange.end && otherRange.start < end;
    }
  };

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp =
        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    // Check we have intersecting ranges.
    const auto opMinInt = op.getMinInt();
    const auto opMaxInt = op.getMaxInt();
    const auto clampOpMinInt = clampOp.getMinInt();
    const auto clampOpMaxInt = clampOp.getMaxInt();
    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
    if (!opRangeIntRange.intersects(clampRangeIntRange))
      return failure();

    const auto opMinFloat = op.getMinFp();
    const auto opMaxFloat = op.getMaxFp();
    const auto clampOpMinFloat = clampOp.getMinFp();
    const auto clampOpMaxFloat = clampOp.getMaxFp();
    ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
    ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
      return failure();

    // Run the transformation.
    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
    const auto minInt = std::max(opMinInt, clampOpMinInt);
    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate the slice on the concatenated axis: slicing along this axis
    // should span only one of the inputs to the concatenate operation.
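    // For example (illustrative): slicing rows [4, 6) of a concat of two
    // 4-row tensors along axis 0 reads only from the second input, so the
    // slice can be rewritten over that input with start 4 - 4 = 0.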
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, start_op, size_op)
                .getResult();
        break;
      }
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
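  // If the reduction axis has a single element, every result of the argmax
  // is the constant index 0, so the op folds to a zero-filled tensor.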
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type; no need to check for quantized
  // types.
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The result right-shift applies to the i32 data type only. For simplicity,
  // synthesize a zero shift for other data types.
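  // For i32 (illustrative): with shift = 1, splat inputs 6 and 4 fold to
  // (6 * 4) >> 1 = 12, computed at double bitwidth to avoid overflow.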
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // Cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value to itself is always true. We cannot fold this
  // for floats, since NaN compares unequal to itself.
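  // For example (illustrative): tosa.equal %x, %x on a statically shaped i32
  // tensor folds to a splat true tensor.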
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
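  // For example (illustrative): tensor<?x?xf32> -> tensor<?x?xf32> may still
  // swap the two runtime extents (say 2x3 -> 3x2), so it must not fold.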
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
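  // (Illustrative: reversing axis 0 of tensor<1x5xf32> cannot reorder
  // anything.)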
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (getInput1().getType() == getType()) {
    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getMultiples())) {
      if (multiples.isSplat() &&
          multiples.getSplatValue<APInt>().getSExtValue() == 1)
        return getInput1();
      if (auto int_array_attr =
              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
        if (llvm::all_of(int_array_attr.getValues<APInt>(),
                         [](APInt v) { return v.getSExtValue() == 1; }))
          return getInput1();
      }
    }
  }
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Fold only if this is the identity transpose.
  SmallVector<int32_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later.
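  // For example (illustrative): concat(concat(%a, %b) {axis = 0}, %c)
  // {axis = 0} becomes concat(%a, %b, %c) {axis = 0}.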
  // Conservatively assume that folding can double the number of operands.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}