1 //===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // TOSA canonicalization patterns and folders.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Quant/IR/Quant.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
17 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20 #include "mlir/IR/BuiltinTypeInterfaces.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/FoldUtils.h"
26 #include "mlir/Transforms/InliningUtils.h"
27 #include "mlir/Transforms/RegionUtils.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <functional>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 //===----------------------------------------------------------------------===//
39 // Operator Canonicalizers.
40 //===----------------------------------------------------------------------===//
41 
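// A tosa.concat with a single input is a pass-through: replace it with that
// input, inserting a tensor.cast when the operand and result types differ.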
42 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
43   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
44 
45   LogicalResult matchAndRewrite(tosa::ConcatOp op,
46                                 PatternRewriter &rewriter) const override {
47     if (op.getInput1().size() != 1)
48       return failure();
49     if (op.getInput1().front().getType() != op.getType()) {
50       rewriter
51           .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
52                                               op.getInput1().front())
53           .getResult();
54       return success();
55     }
56 
57     rewriter.replaceOp(op, op.getInput1().front());
58     return success();
59   }
60 };
61 
62 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
63                                            MLIRContext *context) {
64   results.add<ConcatOptimization>(context);
65 }
66 
67 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
68   auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
69   if (!notOp)
70     return failure();
71   rewriter.modifyOpInPlace(op, [&]() {
72     op.getOperation()->setOperands(
73         {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
74   });
75   return success();
76 }
77 
78 struct ConsolidateTransposeOptimization
79     : public OpRewritePattern<tosa::TransposeOp> {
80   using OpRewritePattern::OpRewritePattern;
81 
82   LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
83                                 PatternRewriter &rewriter) const override {
84     // Input is also TransposeOp - transpose(transpose(A)).
85     auto innerTranspose =
86         transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
87     if (!innerTranspose)
88       return rewriter.notifyMatchFailure(transposeOp,
89                                          "input must be transpose operation");
90 
91     SmallVector<int32_t> transposePerms, innerTransposePerms;
92     if (transposeOp.getConstantPerms(transposePerms).failed())
93       return rewriter.notifyMatchFailure(transposeOp,
94                                          "transpose perms must be constant");
95     if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
96       return rewriter.notifyMatchFailure(
97           transposeOp, "inner transpose perms must be constant");
98     if (transposePerms.size() != innerTransposePerms.size())
99       return rewriter.notifyMatchFailure(
100           transposeOp,
101           "transpose and inner transpose perms sizes must be equal");
102     if (transposePerms.empty())
103       return rewriter.notifyMatchFailure(
104           transposeOp, "transpose perms sizes must be positive");
105 
106     // Consolidate transposes into one transpose.
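    // For example (illustrative): an outer permutation [2, 0, 1] over an inner
    // permutation [1, 2, 0] composes to perms[i] = inner[outer[i]], i.e.
    // [inner[2], inner[0], inner[1]] = [0, 1, 2], the identity permutation.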
107     SmallVector<int32_t> perms(transposePerms.size());
108     for (int i = 0, s = transposePerms.size(); i < s; ++i)
109       perms[i] = innerTransposePerms[transposePerms[i]];
110 
111     auto permsTy =
112         RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
113     auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
114     Value permsValue =
115         rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
116 
117     rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
118         transposeOp, transposeOp.getResult().getType(),
119         innerTranspose.getInput1(), permsValue);
120 
121     return success();
122   }
123 };
124 
125 // Determines the case when tosa.transpose is a tosa.reshape operation.
126 struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
127   using OpRewritePattern::OpRewritePattern;
128 
129   LogicalResult matchAndRewrite(tosa::TransposeOp op,
130                                 PatternRewriter &rewriter) const override {
131     DenseIntElementsAttr permAttr;
132     if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
133       return rewriter.notifyMatchFailure(op, "Non-constant permutation");
134 
135     if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
136       return rewriter.notifyMatchFailure(
137           op, "Src is from transpose, can compose transposes");
138 
139     Value result = op.getResult();
140     for (Operation *subop : result.getUsers()) {
141       if (dyn_cast_or_null<tosa::TransposeOp>(subop))
142         return rewriter.notifyMatchFailure(
143             op, "Dest is used by transpose, can compose transposes");
144     }
145 
146     auto input = op.getInput1();
147     auto inputTy = llvm::cast<ShapedType>(input.getType());
148     if (!inputTy.hasRank())
149       return rewriter.notifyMatchFailure(op, "Unranked input.");
150 
151     int64_t numDynDims = 0;
152     for (int i = 0; i < inputTy.getRank(); ++i)
153       if (inputTy.isDynamicDim(i))
154         numDynDims++;
155 
156     if (numDynDims > 1)
157       return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
158 
159     SmallVector<int64_t> permValues = llvm::to_vector<6>(
160         llvm::map_range(permAttr.getValues<APInt>(),
161                         [](const APInt &val) { return val.getSExtValue(); }));
162 
163     SmallVector<int64_t> nonZeroPerms;
164     nonZeroPerms.reserve(permValues.size());
165     for (auto idx : permValues) {
166       auto sz = inputTy.getDimSize(idx);
167       if (sz != 1)
168         nonZeroPerms.push_back(idx);
169     }
170 
171     for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
172       if (nonZeroPerms[i - 1] > nonZeroPerms[i])
173         return rewriter.notifyMatchFailure(op,
174                                            "Transpose changes memory layout.");
175 
176     SmallVector<int64_t> newShape;
177     newShape.reserve(inputTy.getRank());
178     for (int i = 0, s = inputTy.getRank(); i < s; ++i)
179       newShape.push_back(inputTy.getDimSize(permValues[i]));
180 
181     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
182         op, op.getType(), op.getInput1(),
183         rewriter.getDenseI64ArrayAttr(newShape));
184     return success();
185   }
186 };
187 
188 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
189                                               MLIRContext *context) {
190   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
191 }
192 
193 struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
194   using OpRewritePattern::OpRewritePattern;
195 
196   LogicalResult matchAndRewrite(tosa::PadOp op,
197                                 PatternRewriter &rewriter) const override {
198     if (op.getPadConst())
199       return failure();
200 
201     auto input = op.getInput1();
202     auto padding = op.getPadding();
203 
204     ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
205     Type elementTy = inputTy.getElementType();
206 
207     Attribute constantAttr;
208     if (llvm::isa<FloatType>(elementTy)) {
209       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
210     } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
211       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
212     } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
213       auto value = op.getQuantizationInfo()->getInputZp();
214       constantAttr = rewriter.getIntegerAttr(elementTy, value);
215     }
216 
217     if (!constantAttr) {
218       return rewriter.notifyMatchFailure(
219           op,
220           "tosa.pad to linalg lowering encountered an unknown element type");
221     }
222 
223     auto denseAttr = DenseElementsAttr::get(
224         RankedTensorType::get({}, elementTy), constantAttr);
225     auto constantVal = rewriter.create<tosa::ConstOp>(
226         op.getLoc(), denseAttr.getType(), denseAttr);
227 
228     rewriter.replaceOpWithNewOp<tosa::PadOp>(
229         op, op.getType(), ValueRange{input, padding, constantVal},
230         op->getAttrs());
231     return success();
232   }
233 };
234 
235 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
236                                         MLIRContext *context) {
237   results.add<MaterializePadValue>(context);
238 }
239 
240 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
241   using OpRewritePattern::OpRewritePattern;
242 
243   LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
244                                 PatternRewriter &rewriter) const override {
245     Value input = op.getInput();
246     Value output = op.getOutput();
247     ShapedType inputType = llvm::cast<ShapedType>(input.getType());
248     ShapedType outputType = llvm::cast<ShapedType>(output.getType());
249 
250     if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
251       return failure();
252     }
253 
254     // If the input and output spatial dimensions are both 1x1, this is a no-op.
255     ArrayRef<int64_t> outputShape = outputType.getShape();
256     if (outputShape[1] != 1 || outputShape[2] != 1) {
257       return failure();
258     }
259 
260     ArrayRef<int64_t> inputShape = inputType.getShape();
261     if (inputShape[1] != 1 || inputShape[2] != 1) {
262       return failure();
263     }
264 
265     rewriter.replaceOp(op, input);
266     return success();
267   }
268 };
269 
270 void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
271                                               MLIRContext *context) {
272   results.add<MaxPool2dIsNoOp>(context);
273 }
274 
275 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
276   using OpRewritePattern::OpRewritePattern;
277 
278   LogicalResult matchAndRewrite(tosa::ClampOp op,
279                                 PatternRewriter &rewriter) const override {
280     Value input = op.getInput();
281     auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
282     auto inputElementType = inputType.getElementType();
283 
284     if (!inputType.hasStaticShape()) {
285       return failure();
286     }
287 
288     if (isa<FloatType>(inputElementType)) {
289       // Unlike integer types, floating point types can represent infinity.
290       auto minClamp = op.getMinFp();
291       auto maxClamp = op.getMaxFp();
292       bool isMin = minClamp.isInfinity() && minClamp.isNegative();
293       bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
294 
295       if (isMin && isMax) {
296         rewriter.replaceOp(op, input);
297         return success();
298       }
299       return failure();
300     }
301 
302     if (inputElementType.isUnsignedInteger()) {
303       int64_t minClamp = op.getMinInt();
304       int64_t maxClamp = op.getMaxInt();
305 
306       int64_t intMin =
307           APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
308               .getZExtValue();
309       int64_t intMax =
310           APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
311               .getZExtValue();
312 
313       if (minClamp <= intMin && maxClamp >= intMax) {
314         rewriter.replaceOp(op, input);
315         return success();
316       }
317       return failure();
318     }
319 
320     if (llvm::isa<IntegerType>(inputElementType)) {
321       int64_t minClamp = op.getMinInt();
322       int64_t maxClamp = op.getMaxInt();
323 
324       int64_t intMin =
325           APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
326               .getSExtValue();
327       int64_t intMax =
328           APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
329               .getSExtValue();
330 
331       if (minClamp <= intMin && maxClamp >= intMax) {
332         rewriter.replaceOp(op, input);
333         return success();
334       }
335       return failure();
336     }
337 
338     return failure();
339   }
340 };
341 
342 // Attempts the following transformation:
343 //
344 // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
345 // tensor X the following identity holds:
346 //
347 // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'),  min(b, b'))
348 //
349 // subject to the following valid NaN propagation semantics:
350 // --------------------------------------------
351 // | OUTER CLAMP | INNER CLAMP  | RESULT MODE |
352 // |-------------|--------------|-------------|
353 // | PROPAGATE   | PROPAGATE    | PROPAGATE   |
354 // | PROPAGATE   | IGNORE       | IGNORE      |
355 // | IGNORE      | PROPAGATE    | INVALID     |
356 // | IGNORE      | IGNORE       | IGNORE      |
357 // |------------------------------------------|
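//
// For example (illustrative): CLAMP(CLAMP(X, 0, 10), 5, 20) == CLAMP(X, 5, 10),
// since [0, 10] and [5, 20] intersect in [5, 10].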
358 
359 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
360   using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
361 
362   // Helper structure to describe the range of a clamp operation.
363   template <typename T>
364   struct ClampRange {
365     ClampRange(const T &start, const T &end) : start(start), end(end) {}
366     T start;
367     T end;
368 
369     // Helper function to determine if two Clamp ranges intersect.
370     bool intersects(const ClampRange<T> &otherRange) {
371       return start < otherRange.end && otherRange.start < end;
372     }
373   };
374 
375   LogicalResult matchAndRewrite(tosa::ClampOp op,
376                                 PatternRewriter &rewriter) const override {
377     // Check the input to the CLAMP op is itself a CLAMP.
378     auto clampOp =
379         dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
380     if (!clampOp)
381       return failure();
382 
383     // Check we have a valid NaN propagation combination.
384     const auto opNanMode = op.getNanMode();
385     const auto clampNanMode = clampOp.getNanMode();
386     if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
387       return failure();
388 
389     // Check we have intersecting ranges.
390     const auto opMinInt = op.getMinInt();
391     const auto opMaxInt = op.getMaxInt();
392     const auto clampOpMinInt = clampOp.getMinInt();
393     const auto clampOpMaxInt = clampOp.getMaxInt();
394     ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
395     ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
396     if (!opRangeIntRange.intersects(clampRangeIntRange))
397       return failure();
398 
399     const auto opMinFloat = op.getMinFp();
400     const auto opMaxFloat = op.getMaxFp();
401     const auto clampOpMinFloat = clampOp.getMinFp();
402     const auto clampOpMaxFloat = clampOp.getMaxFp();
403     ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
404     ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
405     if (!opRangeFloatRange.intersects(clampRangeFloatRange))
406       return failure();
407 
408     // Run the transformation.
409     const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
410     const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
411     const auto minInt = std::max(opMinInt, clampOpMinInt);
412     const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
413     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
414         op, op.getType(), clampOp.getInput(),
415         rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
416         rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
417         rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
418                                                            : opNanMode));
419     return success();
420   }
421 };
422 
423 void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
424                                           MLIRContext *context) {
425   results.add<ClampIsNoOp>(context);
426   results.add<ClampClampOptimization>(context);
427 }
428 
429 struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
430   using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
431 
432   LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
433                                 PatternRewriter &rewriter) const override {
434     Value sliceInput = sliceOp.getInput1();
435     auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
436     if (!concatOp)
437       return rewriter.notifyMatchFailure(
438           sliceOp, "slice input must be concat operation");
439 
440     OperandRange inputs = concatOp.getInput1();
441     auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
442     if (!concatType || !concatType.hasStaticShape())
443       return rewriter.notifyMatchFailure(
444           sliceOp, "slice input must be a static ranked tensor");
445     int32_t axis = concatOp.getAxis();
446 
447     DenseElementsAttr startElems;
448     DenseElementsAttr sizeElems;
449 
450     if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
451       return rewriter.notifyMatchFailure(
452           sliceOp, "start of slice must be a static ranked shape");
453 
454     if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
455       return rewriter.notifyMatchFailure(
456           sliceOp, "size of slice must be a static ranked shape");
457 
458     llvm::SmallVector<int64_t> sliceStarts =
459         llvm::to_vector(startElems.getValues<int64_t>());
460     llvm::SmallVector<int64_t> sliceSizes =
461         llvm::to_vector(sizeElems.getValues<int64_t>());
462 
463     // Validate slice on the concatenated axis. Slicing along this
464     // axis should span only one of the inputs to the concatenate
465     // operation.
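    // For example (illustrative): slicing a concat of tensor<2x3xf32> and
    // tensor<2x5xf32> along axis 1 with start [0, 3] and size [2, 4] falls
    // entirely within the second input, so it becomes a slice of that input
    // with start [0, 0] and size [2, 4].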
466     std::optional<Value> replaceWithSlice;
467     for (auto input : inputs) {
468       auto inputType = dyn_cast<RankedTensorType>(input.getType());
469       if (!inputType || !inputType.hasStaticShape())
470         return rewriter.notifyMatchFailure(
471             sliceOp, "concat input must be a static ranked tensor");
472 
473       if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
474                                         inputType.getDimSize(axis)) {
475         auto start_op =
476             getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
477         auto size_op =
478             getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
479         replaceWithSlice =
480             rewriter
481                 .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
482                                        input, start_op, size_op)
483                 .getResult();
484         break;
485       }
486       sliceStarts[axis] -= inputType.getDimSize(axis);
487     }
488 
489     if (!replaceWithSlice)
490       return rewriter.notifyMatchFailure(
491           sliceOp, "corresponding concat input not found for slice");
492 
493     rewriter.replaceOp(sliceOp, replaceWithSlice.value());
494     return success();
495   }
496 };
497 
498 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
499                                           MLIRContext *context) {
500   results.add<ConcatSliceOptimization>(context);
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // Operator Folders.
505 //===----------------------------------------------------------------------===//
506 
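// Folds an elementwise binary op whose operands are both splat constants of
// the same element type, applying IntFolder or FloatFolder to the scalar
// values.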
507 template <typename IntFolder, typename FloatFolder>
508 DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
509                                RankedTensorType returnTy) {
510   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
511     auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
512     auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
513     if (lETy != rETy)
514       return {};
515 
516     if (llvm::isa<IntegerType>(lETy)) {
517       APInt l = lhs.getSplatValue<APInt>();
518       APInt r = rhs.getSplatValue<APInt>();
519       auto result = IntFolder()(l, r);
520       return DenseElementsAttr::get(returnTy, result);
521     }
522 
523     if (llvm::isa<FloatType>(lETy)) {
524       APFloat l = lhs.getSplatValue<APFloat>();
525       APFloat r = rhs.getSplatValue<APFloat>();
526       auto result = FloatFolder()(l, r);
527       return DenseElementsAttr::get(returnTy, result);
528     }
529   }
530 
531   return {};
532 }
533 
534 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
535   if (llvm::isa<FloatType>(elemType))
536     return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
537   if (llvm::isa<IntegerType>(elemType))
538     return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
539   return false;
540 }
541 
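// A splat is considered "one" if it equals 1.0 for floats, or 1 << shift for
// integers (e.g., with shift == 7 a splat value of 128 represents one).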
542 static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
543   if (llvm::isa<FloatType>(elemType))
544     return val && val.isSplat() &&
545            val.getSplatValue<APFloat>().isExactlyValue(1.0);
546   if (llvm::isa<IntegerType>(elemType)) {
547     const int64_t shifted = 1LL << shift;
548     return val && val.isSplat() &&
549            val.getSplatValue<APInt>().getSExtValue() == shifted;
550   }
551   return false;
552 }
553 
554 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
555   auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
556   auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
557   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
558   if (!lhsTy || !rhsTy || !resultTy)
559     return {};
560 
561   // Cannot create an ElementsAttr from non-int/float/index types
562   if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
563       !rhsTy.getElementType().isIntOrIndexOrFloat())
564     return {};
565 
566   auto resultETy = resultTy.getElementType();
567   auto lhsAttr =
568       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
569   auto rhsAttr =
570       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
571 
572   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
573     return getInput1();
574   if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
575     return getInput2();
576 
577   if (!lhsAttr || !rhsAttr)
578     return {};
579 
580   return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
581                                                             resultTy);
582 }
583 
584 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
585   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
586   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
587   if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
588       !outputTy.hasStaticShape())
589     return {};
590 
591   if (inputTy.getDimSize(getAxis()) == 1)
592     return DenseElementsAttr::get(outputTy, 0);
593 
594   return {};
595 }
596 
597 OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
598   auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
599   auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
600   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
601   if (!lhsTy || !rhsTy || !resultTy)
602     return {};
603   if (lhsTy != rhsTy)
604     return {};
605 
606   // IntDivOp inputs must be integer type; no need to check for quantized types.
607   auto resultETy = resultTy.getElementType();
608   auto lhsAttr =
609       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
610   auto rhsAttr =
611       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
612   if (lhsAttr && lhsAttr.isSplat()) {
613     if (llvm::isa<IntegerType>(resultETy) &&
614         lhsAttr.getSplatValue<APInt>().isZero())
615       return lhsAttr;
616   }
617 
618   if (rhsAttr && rhsAttr.isSplat()) {
619     if (llvm::isa<IntegerType>(resultETy) &&
620         rhsAttr.getSplatValue<APInt>().isOne())
621       return getInput1();
622   }
623 
624   if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
625     if (llvm::isa<IntegerType>(resultETy)) {
626       APInt l = lhsAttr.getSplatValue<APInt>();
627       APInt r = rhsAttr.getSplatValue<APInt>();
628       APInt result = l.sdiv(r);
629       return DenseElementsAttr::get(resultTy, result);
630     }
631   }
632 
633   return {};
634 }
635 
636 namespace {
637 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
638                                   RankedTensorType ty, int32_t shift) {
639   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
640     if (llvm::isa<IntegerType>(ty.getElementType())) {
641       APInt l = lhs.getSplatValue<APInt>();
642       APInt r = rhs.getSplatValue<APInt>();
643 
644       if (shift == 0) {
645         return DenseElementsAttr::get(ty, l * r);
646       }
647 
648       auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
649       l = l.sext(bitwidth * 2);
650       r = r.sext(bitwidth * 2);
651       auto result = l * r;
652       result.lshrInPlace(shift);
653       result = result.trunc(bitwidth);
654       return DenseElementsAttr::get(ty, result);
655     }
656 
657     if (llvm::isa<FloatType>(ty.getElementType())) {
658       APFloat l = lhs.getSplatValue<APFloat>();
659       APFloat r = rhs.getSplatValue<APFloat>();
660       APFloat result = l * r;
661       return DenseElementsAttr::get(ty, result);
662     }
663   }
664 
665   return {};
666 }
667 } // namespace
668 
669 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
670   auto lhs = getInput1();
671   auto rhs = getInput2();
672   auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
673   auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
674   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
675   if (!lhsTy || !rhsTy || !resultTy)
676     return {};
677 
678   auto resultETy = resultTy.getElementType();
679   auto lhsAttr =
680       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
681   auto rhsAttr =
682       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
683 
684   // The result right shift applies to the i32 data type only. For simplicity,
685   // synthesize a zero shift for other data types.
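  // For example (illustrative), with shift == 1 a splat 2 times a splat 3
  // folds to (2 * 3) >> 1 == 3.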
686   int32_t shift = 0;
687   if (resultETy.isInteger(32)) {
688     ElementsAttr shift_elem;
689     if (getShift().getImpl()) {
690       if (!matchPattern(getShift(), m_Constant(&shift_elem)))
691         // cannot be folded when the shift value is unknown.
692         return {};
693       shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
694     }
695   }
696 
697   if (rhsTy == resultTy) {
698     if (isSplatZero(resultETy, lhsAttr))
699       return lhsAttr.resizeSplat(resultTy);
700     if (isSplatOne(resultETy, lhsAttr, shift))
701       return rhs;
702   }
703   if (lhsTy == resultTy) {
704     if (isSplatZero(resultETy, rhsAttr))
705       return rhsAttr.resizeSplat(resultTy);
706     if (isSplatOne(resultETy, rhsAttr, shift))
707       return lhs;
708   }
709 
710   return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
711 }
712 
713 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
714   auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
715   auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
716   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
717   if (!lhsTy || !rhsTy || !resultTy)
718     return {};
719 
720   // Cannot create an ElementsAttr from non-int/float/index types
721   if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
722       !rhsTy.getElementType().isIntOrIndexOrFloat())
723     return {};
724 
725   auto resultETy = resultTy.getElementType();
726   auto lhsAttr =
727       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
728   auto rhsAttr =
729       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
730 
731   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
732     return getInput1();
733 
734   if (!lhsAttr || !rhsAttr)
735     return {};
736 
737   return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
738                                                               resultTy);
739 }
740 
741 namespace {
742 template <typename Cmp>
743 struct ComparisonFold {
744   ComparisonFold() = default;
745   APInt operator()(const APInt &l, const APInt &r) {
746     return APInt(1, Cmp()(l, r));
747   }
748 
749   APInt operator()(const APFloat &l, const APFloat &r) {
750     return APInt(1, Cmp()(l, r));
751   }
752 };
753 
754 struct APIntFoldGreater {
755   APIntFoldGreater() = default;
756   APInt operator()(const APInt &l, const APInt &r) {
757     return APInt(1, l.sgt(r));
758   }
759 };
760 
761 struct APIntFoldGreaterEqual {
762   APIntFoldGreaterEqual() = default;
763   APInt operator()(const APInt &l, const APInt &r) {
764     return APInt(1, l.sge(r));
765   }
766 };
767 } // namespace
768 
769 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
770   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
771   auto lhsAttr =
772       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
773   auto rhsAttr =
774       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
775 
776   if (!lhsAttr || !rhsAttr)
777     return {};
778 
779   return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
780       lhsAttr, rhsAttr, resultTy);
781 }
782 
783 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
784   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
785   auto lhsAttr =
786       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
787   auto rhsAttr =
788       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
789 
790   if (!lhsAttr || !rhsAttr)
791     return {};
792 
793   return binaryFolder<APIntFoldGreaterEqual,
794                       ComparisonFold<std::greater_equal<APFloat>>>(
795       lhsAttr, rhsAttr, resultTy);
796 }
797 
798 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
799   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
800   auto lhsAttr =
801       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
802   auto rhsAttr =
803       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
804   Value lhs = getInput1();
805   Value rhs = getInput2();
806   auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
807 
808   // Comparing an integer value to itself is always true. We cannot do the
809   // same for floats, since a NaN value compares unequal to itself.
810   if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
811       resultTy.hasStaticShape() && lhs == rhs) {
812     return DenseElementsAttr::get(resultTy, true);
813   }
814 
815   if (!lhsAttr || !rhsAttr)
816     return {};
817 
818   return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
819                       ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
820                                                               resultTy);
821 }
822 
823 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
824   if (getInput().getType() == getType())
825     return getInput();
826 
827   auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
828   if (!operand)
829     return {};
830 
831   auto inTy = llvm::cast<ShapedType>(getInput().getType());
832   auto outTy = llvm::cast<ShapedType>(getType());
833   auto inETy = inTy.getElementType();
834   auto outETy = outTy.getElementType();
835 
836   if (operand.isSplat()) {
837     if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
838       bool overflow;
839       auto splatVal = operand.getSplatValue<APFloat>();
840       auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
841       splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
842                        &overflow);
843       return SplatElementsAttr::get(outTy, splatVal);
844     }
845 
846     if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
847       auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
848       APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
849       splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
850                                 llvm::RoundingMode::NearestTiesToEven);
851       return SplatElementsAttr::get(outTy, splatVal);
852     }
853 
854     if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
855       auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
856       auto intVal = APSInt(
857           llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
858       auto floatVal = operand.getSplatValue<APFloat>();
859       bool exact;
860       floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
861                                 &exact);
862       return SplatElementsAttr::get(outTy, intVal);
863     }
864 
865     if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
866       auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
867       bool trunc =
868           inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
869       auto intVal = operand.getSplatValue<APInt>();
870       auto bitwidth = outETy.getIntOrFloatBitWidth();
871 
872       if (trunc) {
873         intVal = intVal.trunc(bitwidth);
874       } else if (unsignIn) {
875         intVal = intVal.zext(bitwidth);
876       } else {
877         intVal = intVal.sext(bitwidth);
878       }
879 
880       return SplatElementsAttr::get(outTy, intVal);
881     }
882   }
883 
884   return {};
885 }
886 
887 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
888 
889 OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
890 
891 #define REDUCE_FOLDER(OP)                                                      \
892   OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
893     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
894     if (!inputTy.hasRank())                                                    \
895       return {};                                                               \
896     if (inputTy != getType())                                                  \
897       return {};                                                               \
898     if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
899       return getInput();                                                       \
900     return {};                                                                 \
901   }
902 
903 REDUCE_FOLDER(ReduceAllOp)
904 REDUCE_FOLDER(ReduceAnyOp)
905 REDUCE_FOLDER(ReduceMaxOp)
906 REDUCE_FOLDER(ReduceMinOp)
907 REDUCE_FOLDER(ReduceProdOp)
908 REDUCE_FOLDER(ReduceSumOp)
909 #undef REDUCE_FOLDER
910 
911 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
912   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
913   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
914 
915   if (!inputTy || !outputTy)
916     return {};
917 
918   // Fold when the input and output types are the same. This is only safe when
919   // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
920   // there may still be a productive reshape.
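  // For example (illustrative): with a single dynamic dimension, as in
  // tensor<?x4xf32> -> tensor<?x4xf32>, the dynamic extent is pinned by the
  // element count, so the reshape is an identity; for tensor<?x?xf32> ->
  // tensor<?x?xf32> a runtime 2x6 input could still be reshaped to 3x4.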
921   if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
922     return getInput1();
923 
924   // reshape(reshape(x)) -> reshape(x)
925   if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
926           getInput1().getDefiningOp())) {
927     getInput1Mutable().assign(reshapeOp.getInput1());
928     return getResult();
929   }
930 
931   // Cannot create an ElementsAttr from non-int/float/index types
932   if (!inputTy.getElementType().isIntOrIndexOrFloat())
933     return {};
934 
935   // reshape(const(x)) -> const(reshape-attr(x))
936   if (auto operand =
937           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
938     // Constants must have static shape.
939     if (!outputTy.hasStaticShape())
940       return {};
941 
942     // Okay to duplicate splat constants.
943     if (operand.isSplat())
944       return SplatElementsAttr::get(outputTy,
945                                     operand.getSplatValue<Attribute>());
946 
947     // Don't duplicate other constants.
948     if (!getInput1().hasOneUse())
949       return {};
950 
951     return operand.reshape(
952         llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
953   }
954 
955   return {};
956 }
957 
958 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
959   // If the pad is all zeros we can fold this operation away.
960   if (adaptor.getPadding() && getInput1().getType() == getType()) {
961     auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
962     if (densePad && densePad.isSplat() &&
963         densePad.getSplatValue<APInt>().isZero()) {
964       return getInput1();
965     }
966   }
967 
968   return {};
969 }
970 
971 // Fold away cases where a tosa.resize operation returns a copy
972 // of the input image.
973 OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
974   ArrayRef<int64_t> offset = getOffset();
975   ArrayRef<int64_t> border = getBorder();
976   ArrayRef<int64_t> scale = getScale();
977 
978   // Check unit scaling.
979   if (scale[0] != scale[1] || scale[2] != scale[3]) {
980     return {};
981   }
982 
983   // There should be no offset.
984   if (offset[0] != 0 || offset[1] != 0) {
985     return {};
986   }
987 
988   // There should be no border.
989   if (border[0] != 0 || border[1] != 0) {
990     return {};
991   }
992 
993   auto input = getInput();
994   auto inputTy = llvm::cast<RankedTensorType>(input.getType());
995   auto resultTy = llvm::cast<RankedTensorType>(getType());
996   if (inputTy != resultTy)
997     return {};
998 
999   return input;
1000 }
1001 
1002 OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1003   auto operand = getInput1();
1004   auto operandTy = llvm::cast<ShapedType>(operand.getType());
1005   auto axis = getAxis();
1006   auto operandAttr =
1007       llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1008   if (operandAttr)
1009     return operandAttr;
1010 
1011   // If the dim-length is 1, tosa.reverse is a no-op.
1012   if (operandTy.hasRank() &&
1013       (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1014     return operand;
1015 
1016   return {};
1017 }
1018 
1019 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1020   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1021   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1022 
1023   if (!inputTy || !outputTy)
1024     return {};
1025 
1026   if (inputTy == outputTy && inputTy.hasStaticShape())
1027     return getInput1();
1028 
1029   if (!adaptor.getInput1())
1030     return {};
1031 
1032   // Cannot create an ElementsAttr from non-int/float/index types
1033   if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1034       !outputTy.getElementType().isIntOrIndexOrFloat())
1035     return {};
1036 
1037   auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1038   if (operand.isSplat() && outputTy.hasStaticShape()) {
1039     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1040   }
1041 
1042   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1043       outputTy.getNumElements() == 1) {
1044     DenseElementsAttr startElems;
1045     if (!matchPattern(getStart(), m_Constant(&startElems)))
1046       return {};
1047 
1048     llvm::SmallVector<uint64_t> indices =
1049         llvm::to_vector(startElems.getValues<uint64_t>());
1050     auto value = operand.getValues<Attribute>()[indices];
1051     return SplatElementsAttr::get(outputTy, value);
1052   }
1053 
1054   return {};
1055 }
1056 
1057 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1058   if (getOnTrue() == getOnFalse())
1059     return getOnTrue();
1060 
1061   auto predicate =
1062       llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
1063   if (!predicate)
1064     return {};
1065 
1066   if (!predicate.isSplat())
1067     return {};
1068   return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1069                                                          : getOnFalse();
1070 }
1071 
1072 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1073   if (getInput1().getType() == getType()) {
1074     if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1075             adaptor.getMultiples())) {
1076       if (multiples.isSplat() &&
1077           multiples.getSplatValue<APInt>().getSExtValue() == 1)
1078         return getInput1();
1079       if (auto int_array_attr =
1080               llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1081         if (llvm::all_of(int_array_attr.getValues<APInt>(),
1082                          [](APInt v) { return v.getSExtValue() == 1; }))
1083           return getInput1();
1084       }
1085     }
1086   }
1087   return {};
1088 }
1089 
1090 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1091   auto resultTy = llvm::cast<ShapedType>(getType());
1092 
1093   // Transposing splat values just means reshaping.
1094   if (auto input =
1095           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1096     if (input.isSplat() && resultTy.hasStaticShape() &&
1097         input.getType().getElementType() == resultTy.getElementType())
1098       return input.reshape(resultTy);
1099   }
1100 
1101   // Otherwise, fold only if the permutation is the identity.
1102   SmallVector<int32_t> perms;
1103   if (getConstantPerms(perms).failed())
1104     return {};
1105 
1106   if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1107     return {};
1108 
1109   return getInput1();
1110 }
1111 
1112 OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
1113   auto input = getInput1();
1114   // Element-wise log(exp(x)) = x
1115   if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
1116     return op.getInput1();
1117   }
1118 
1119   return {};
1120 }
1121 
1122 OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
1123   auto input = getInput1();
1124   // Element-wise exp(log(x)) = x
1125   if (auto op = input.getDefiningOp<tosa::LogOp>()) {
1126     return op.getInput1();
1127   }
1128 
1129   return {};
1130 }
1131 
1132 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1133   auto input = getInput1();
1134   // Element-wise negate(negate(x)) = x
1135   if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
1136     return op.getInput1();
1137   }
1138 
1139   return {};
1140 }
1141 
1142 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1143   auto input = getInput1();
1144   // Element-wise abs(abs(x)) = abs(x)
1145   if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1146     return input;
1147   }
1148 
1149   return {};
1150 }
1151 
1152 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1153   // Fold consecutive concats on the same axis into a single op.
1154   // Keep track of the operands so we are able to construct a new concat
1155   // later. Conservatively assume that we double the number of operands when
1156   // folding.
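  // For example (illustrative):
  //   concat(concat(a, b, axis = 0), c, axis = 0) -> concat(a, b, c, axis = 0)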
1157   SmallVector<Value, 8> concatOperands;
1158   concatOperands.reserve(2 * getNumOperands());
1159 
1160   // Find all operands that are foldable concats
1161   bool foundFoldableConcat = false;
1162   for (Value operand : getOperands()) {
1163     concatOperands.emplace_back(operand);
1164 
1165     auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1166     if (!producer)
1167       continue;
1168 
1169     // Not foldable if axes are not the same
1170     if (getAxis() != producer.getAxis())
1171       continue;
1172 
1173     // Replace the original operand with all incoming operands
1174     foundFoldableConcat = true;
1175     concatOperands.pop_back();
1176     llvm::append_range(concatOperands, producer->getOperands());
1177   }
1178 
1179   if (!foundFoldableConcat)
1180     return {};
1181 
1182   getOperation()->setOperands(concatOperands);
1183   return getResult();
1184 }
1185 
1186 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1187   auto input = adaptor.getInput1();
1188 
1189   auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1190   // Fold splat inputs only.
1191   if (!inputAttr || !inputAttr.isSplat())
1192     return {};
1193 
1194   auto shapeType = llvm::cast<ShapedType>(getType());
1195   if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1196     auto floatVal = inputAttr.getSplatValue<APFloat>();
1197     return DenseElementsAttr::get(shapeType,
1198                                   ReciprocalOp::calcOneElement(floatVal));
1199   }
1200 
1201   return {};
1202 }
1203