//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.getInput1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}

struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
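    // Illustrative example: transpose(transpose(%x, perms = [2, 1, 0]),
    // perms = [1, 0, 2]) is equivalent to a single transpose(%x,
    // perms = [1, 2, 0]); the combined permutation is computed below as
    // perms[i] = innerPerms[outerPerms[i]].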
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int32_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
  }
};

// Determines the case when tosa.transpose is a tosa.reshape operation.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    // Gather the permutation entries of all non-unit dimensions; unit
    // dimensions may be moved freely without changing the element order.
    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}

struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If the output and input shapes are 1x1, then this is a no-op.
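    // (tosa.max_pool2d operands use NHWC layout, so dims 1 and 2 checked
    // below are the spatial height and width.)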
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}

struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    // Guard against unranked inputs before querying the element type.
    if (!inputType || !inputType.hasStaticShape()) {
      return failure();
    }
    auto inputElementType = inputType.getElementType();

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
//
// --------------------------------------------
// | OUTER CLAMP | INNER CLAMP  | RESULT MODE |
// |-------------|--------------|-------------|
// | PROPAGATE   | PROPAGATE    | PROPAGATE   |
// | PROPAGATE   | IGNORE       | IGNORE      |
// | IGNORE      | PROPAGATE    | INVALID     |
// | IGNORE      | IGNORE       | IGNORE      |
// |------------------------------------------|

struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  // Helper structure to describe the range of a clamp operation.
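  // For example, ClampRange<int64_t>(0, 6) and ClampRange<int64_t>(4, 10)
  // intersect (over [4, 6]), while ClampRange<int64_t>(0, 4) and
  // ClampRange<int64_t>(6, 10) do not.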
  template <typename T>
  struct ClampRange {
    ClampRange(const T &start, const T &end) : start(start), end(end) {}
    T start;
    T end;

    // Helper function to determine if two Clamp ranges intersect.
    bool intersects(const ClampRange<T> &otherRange) {
      return start < otherRange.end && otherRange.start < end;
    }
  };

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp =
        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    // Check we have intersecting ranges.
    const auto opMinInt = op.getMinInt();
    const auto opMaxInt = op.getMaxInt();
    const auto clampOpMinInt = clampOp.getMinInt();
    const auto clampOpMaxInt = clampOp.getMaxInt();
    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
    if (!opRangeIntRange.intersects(clampRangeIntRange))
      return failure();

    const auto opMinFloat = op.getMinFp();
    const auto opMaxFloat = op.getMaxFp();
    const auto clampOpMinFloat = clampOp.getMinFp();
    const auto clampOpMaxFloat = clampOp.getMaxFp();
    ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
    ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
      return failure();

    // Run the transformation.
    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
    const auto minInt = std::max(opMinInt, clampOpMinInt);
    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 &&
          (sliceStarts[axis] + sliceSizes[axis]) <=
              inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, start_op, size_op)
                .getResult();
        break;
      }
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type; no need to check for quantized type.
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The result right shift only applies to the i32 data type. For simplicity,
  // synthesize a zero shift for other data types.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // Cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // If we are comparing an integer value to itself it is always true. We
  // cannot do this for floats because a NaN value never compares equal to
  // itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
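  // For example, reversing a tensor<4x1xf32> along axis 1, or reversing any
  // rank-0 tensor, returns the input unchanged.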
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (getInput1().getType() == getType()) {
    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getMultiples())) {
      if (multiples.isSplat() &&
          multiples.getSplatValue<APInt>().getSExtValue() == 1)
        return getInput1();
      if (auto int_array_attr =
              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
        if (llvm::all_of(int_array_attr.getValues<APInt>(),
                         [](APInt v) { return v.getSExtValue() == 1; }))
          return getInput1();
      }
    }
  }
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Otherwise the transpose folds away only when its permutation is the
  // identity.
  SmallVector<int32_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}