1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Arith/Utils/Utils.h"
16 #include "mlir/Dialect/Index/IR/IndexOps.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Math/IR/Math.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
22 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
23 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24 #include "mlir/Dialect/Utils/StaticValueUtils.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
33 
34 #include <numeric>
35 
36 using namespace mlir;
37 using namespace mlir::tosa;
38 
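// Create an 'arith.constant' from the integer attribute named 'attrName' on
// 'op', casting its value to T and materializing it with 'requiredAttrType'.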
39 template <typename T>
40 static arith::ConstantOp
41 createConstFromIntAttribute(Operation *op, const std::string &attrName,
42                             Type requiredAttrType, OpBuilder &rewriter) {
43   auto castedN = static_cast<T>(
44       cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
45   return rewriter.create<arith::ConstantOp>(
46       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
47 }
48 
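// Create the scalar computation for one element of a TOSA elementwise op.
// Returns the computed Value, or nullptr (after notifying a match failure)
// when the op/element-type combination is not handled.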
49 static Value createLinalgBodyCalculationForElementwiseOp(
50     Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
51     ConversionPatternRewriter &rewriter) {
52   Location loc = op->getLoc();
53   auto elementTy =
54       cast<ShapedType>(op->getOperand(0).getType()).getElementType();
55 
56   // tosa::AbsOp
57   if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
58     return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
59 
60   if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
61     auto zero = rewriter.create<arith::ConstantOp>(
62         loc, rewriter.getZeroAttr(elementTy));
63     auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
64     return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
65   }
66 
67   // tosa::AddOp
68   if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
69     return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
70 
71   if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
72     return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
73 
74   // tosa::SubOp
75   if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
76     return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
77 
78   if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
79     return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
80 
81   // tosa::IntDivOp
82   if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
83     return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
84 
85   // tosa::ReciprocalOp
86   if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
87     auto one =
88         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
89     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
90   }
91 
92   // tosa::MulOp
93   if (isa<tosa::MulOp>(op)) {
94     auto shift_val = cast<tosa::MulOp>(op).getShift();
95 
96     if (isa<FloatType>(elementTy)) {
97       return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
98     }
99 
100     if (isa<IntegerType>(elementTy)) {
101       int32_t shift = 0;
102       ElementsAttr shift_elem;
103       if (shift_val.getImpl() &&
104           matchPattern(shift_val, m_Constant(&shift_elem))) {
105         // Explicit shift is set.
106         shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
107       }
108 
109       Value a = args[0];
110       Value b = args[1];
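      // A positive shift selects a scaled multiply: extend the operands to i32
      // if needed, combine them with tosa.apply_scale, and truncate the result
      // back to the element type when needed.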
111       if (shift > 0) {
112         auto shiftConst =
113             rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
114         if (!a.getType().isInteger(32))
115           a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
116 
117         if (!b.getType().isInteger(32))
118           b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
119 
120         auto result = rewriter.create<tosa::ApplyScaleOp>(
121             loc, rewriter.getI32Type(), a, b, shiftConst,
122             rewriter.getBoolAttr(false));
123 
124         if (elementTy.isInteger(32))
125           return result;
126 
127         return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
128       }
129 
130       int aWidth = a.getType().getIntOrFloatBitWidth();
131       int bWidth = b.getType().getIntOrFloatBitWidth();
132       int cWidth = resultTypes[0].getIntOrFloatBitWidth();
133 
134       if (aWidth < cWidth)
135         a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
136       if (bWidth < cWidth)
137         b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
138 
139       return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
140     }
141   }
142 
143   // tosa::NegateOp
144   if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
145     return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
146 
147   if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
148     int64_t inZp = 0, outZp = 0;
149 
150     if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
151       auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
152       inZp = quantizationInfo.value().getInputZp();
153       outZp = quantizationInfo.value().getOutputZp();
154     }
155 
156     int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
157     if (!inZp && !outZp) {
158       auto constant = rewriter.create<arith::ConstantOp>(
159           loc, IntegerAttr::get(elementTy, 0));
160       return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
161                                             args[0]);
162     }
163 
164     // Compute the maximum value that can occur in the intermediate buffer.
165     int64_t zpAdd = inZp + outZp;
166     int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
167                        std::abs(zpAdd) + 1;
168 
169     // Convert that maximum value into the maximum bitwidth needed to represent
170     // it. We assume 48-bit numbers may be supported further in the pipeline.
171     int intermediateBitWidth = 64;
172     if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
173       intermediateBitWidth = 16;
174     } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
175       intermediateBitWidth = 32;
176     } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
177       intermediateBitWidth = 48;
178     }
179 
180     Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
181     Value zpAddValue = rewriter.create<arith::ConstantOp>(
182         loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
183 
184     // The negation can be applied by doing:
185     //  outputValue = inZp + outZp - inputValue
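    // For example, with i8, inZp = -128, and outZp = 0, an input of 100 yields
    // -128 - 100 = -228 in the i16 intermediate, which the clamp below brings
    // back into the i8 range as -128.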
186     auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
187     auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
188 
189     // Clamp to the negation range.
190     Value min = rewriter.create<arith::ConstantIntOp>(
191         loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
192         intermediateType);
193     Value max = rewriter.create<arith::ConstantIntOp>(
194         loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
195         intermediateType);
196     auto clamp =
197         clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
198 
199     // Truncate to the final value.
200     return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
201   }
202 
203   // tosa::BitwiseAndOp
204   if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
205     return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
206 
207   // tosa::BitwiseOrOp
208   if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
209     return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
210 
211   // tosa::BitwiseNotOp
212   if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
213     auto allOnesAttr = rewriter.getIntegerAttr(
214         elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
215     auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
216     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
217   }
218 
219   // tosa::BitwiseXOrOp
220   if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
221     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
222 
223   // tosa::LogicalLeftShiftOp
224   if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
225     return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
226 
227   // tosa::LogicalRightShiftOp
228   if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
229     return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
230 
231   // tosa::ArithmeticRightShiftOp
232   if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
233     auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
234     auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
235     if (!round) {
236       return result;
237     }
238 
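    // Rounding adds 1 to the shifted result when the shift amount is positive
    // and the most significant discarded bit of input1 was 1.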
239     Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
240     auto one =
241         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
242     auto zero =
243         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
244     auto i1one =
245         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
246 
247     // Check whether the shift amount (input2) is greater than zero.
248     auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
249         loc, arith::CmpIPredicate::sgt, args[1], zero);
250 
251     // Check whether the last bit shifted out of input1 is 1.
252     auto subtract =
253         rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
254     auto shifted =
255         rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
256             ->getResults();
257     auto truncated =
258         rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
259     auto isInputOdd =
260         rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
261 
262     auto shouldRound = rewriter.create<arith::AndIOp>(
263         loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
264     auto extended =
265         rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
266     return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
267   }
268 
269   // tosa::ClzOp
270   if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
271     return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
272   }
273 
274   // tosa::LogicalAnd
275   if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
276     return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
277 
278   // tosa::LogicalNot
279   if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
280     auto one = rewriter.create<arith::ConstantOp>(
281         loc, rewriter.getIntegerAttr(elementTy, 1));
282     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
283   }
284 
285   // tosa::LogicalOr
286   if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
287     return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
288 
289   // tosa::LogicalXor
290   if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
291     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
292 
293   // tosa::PowOp
294   if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
295     return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
296 
297   // tosa::RsqrtOp
298   if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
299     return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
300 
301   // tosa::LogOp
302   if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
303     return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
304 
305   // tosa::ExpOp
306   if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
307     return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
308 
309   // tosa::SinOp
310   if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
311     return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
312 
313   // tosa::CosOp
314   if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
315     return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
316 
317   // tosa::TanhOp
318   if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
319     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
320 
321   // tosa::ErfOp
322   if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
323     return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
324 
325   // tosa::GreaterOp
326   if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
327     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
328                                           args[0], args[1]);
329 
330   if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
331     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
332                                           args[0], args[1]);
333 
334   // tosa::GreaterEqualOp
335   if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
336     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
337                                           args[0], args[1]);
338 
339   if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
340     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
341                                           args[0], args[1]);
342 
343   // tosa::EqualOp
344   if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
345     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
346                                           args[0], args[1]);
347 
348   if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
349     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
350                                           args[0], args[1]);
351 
352   // tosa::SelectOp
353   if (isa<tosa::SelectOp>(op)) {
354     elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
355     if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
356       return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
357   }
358 
359   // tosa::MaximumOp
360   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
361     return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
362   }
363 
364   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
365     return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
366   }
367 
368   // tosa::MinimumOp
369   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
370     return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
371   }
372 
373   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
374     return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
375   }
376 
377   // tosa::CeilOp
378   if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
379     return rewriter.create<math::CeilOp>(loc, resultTypes, args);
380 
381   // tosa::FloorOp
382   if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
383     return rewriter.create<math::FloorOp>(loc, resultTypes, args);
384 
385   // tosa::ClampOp
386   if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
387     bool losesInfo = false;
388     APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
389     APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
390     minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
391                    APFloat::rmNearestTiesToEven, &losesInfo);
392     maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
393                    APFloat::rmNearestTiesToEven, &losesInfo);
394     auto min = rewriter.create<arith::ConstantOp>(
395         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
396     auto max = rewriter.create<arith::ConstantOp>(
397         loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
398     return clampFloatHelper(loc, args[0], min, max, rewriter);
399   }
400 
401   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
402     auto intTy = cast<IntegerType>(elementTy);
403     int64_t min =
404         cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
405     int64_t max =
406         cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
407 
408     int64_t minRepresentable = std::numeric_limits<int64_t>::min();
409     int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
410     if (intTy.isUnsignedInteger()) {
411       minRepresentable = 0;
412       if (intTy.getIntOrFloatBitWidth() <= 63) {
413         maxRepresentable =
414             (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
415                 .getZExtValue();
416       }
417     } else if (intTy.getIntOrFloatBitWidth() <= 64) {
418       // Ensure that min & max fit into signed n-bit constants.
419       minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
420                              .getSExtValue();
421       maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
422                              .getSExtValue();
423     }
424     // Ensure that the bounds are representable as n-bit signed/unsigned
425     // integers.
426     min = std::max(min, minRepresentable);
427     max = std::max(max, minRepresentable);
428     min = std::min(min, maxRepresentable);
429     max = std::min(max, maxRepresentable);
430 
431     auto minVal = rewriter.create<arith::ConstantIntOp>(
432         loc, min, intTy.getIntOrFloatBitWidth());
433     auto maxVal = rewriter.create<arith::ConstantIntOp>(
434         loc, max, intTy.getIntOrFloatBitWidth());
435     return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
436                           intTy.isUnsignedInteger());
437   }
438 
439   // tosa::SigmoidOp
440   if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
441     auto one =
442         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
443     auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
444     auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
445     auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
446     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
447   }
448 
449   // tosa::CastOp
450   if (isa<tosa::CastOp>(op)) {
451     Type srcTy = elementTy;
452     Type dstTy = resultTypes.front();
453     bool bitExtend =
454         srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
455 
456     if (srcTy == dstTy)
457       return args.front();
458 
459     if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
460       return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
461                                             std::nullopt);
462 
463     if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
464       return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
465                                               std::nullopt);
466 
467     // 1-bit integers need to be treated as signless.
468     if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
469       return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
470                                               std::nullopt);
471 
472     if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
473       return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
474                                              std::nullopt);
475 
476     // Unsigned integers need an unrealized cast so that they can be passed
477     // to UIToFP.
478     if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
479       auto unrealizedCast =
480           rewriter
481               .create<UnrealizedConversionCastOp>(
482                   loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
483                   args[0])
484               .getResult(0);
485       return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
486                                               unrealizedCast);
487     }
488 
489     // All other si-to-fp conversions should be handled by SIToFP.
490     if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
491       return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
492                                               std::nullopt);
493 
494     // When casting to boolean, floats only need to be checked as not equal to zero.
495     if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
496       Value zero = rewriter.create<arith::ConstantOp>(
497           loc, rewriter.getFloatAttr(srcTy, 0.0));
498       return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
499                                             args.front(), zero);
500     }
501 
502     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
503       auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
504 
505       const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
506       // Check whether neither int min nor int max can be represented in the
507       // input floating-point type because its exponent range is too small.
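      // (For example, an f16 input, whose largest finite value is 65504, cannot
      // reach the i32 extremes.)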
508       if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
509           APFloat::semanticsMaxExponent(fltSemantics)) {
510         // Use cmp + select to replace infinities by int min / int max. Other
511         // integral values can be represented in the integer space.
512         auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
513         auto posInf = rewriter.create<arith::ConstantOp>(
514             loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
515                                        APFloat::getInf(fltSemantics)));
516         auto negInf = rewriter.create<arith::ConstantOp>(
517             loc, rewriter.getFloatAttr(
518                      getElementTypeOrSelf(srcTy),
519                      APFloat::getInf(fltSemantics, /*Negative=*/true)));
520         auto overflow = rewriter.create<arith::CmpFOp>(
521             loc, arith::CmpFPredicate::UEQ, rounded, posInf);
522         auto underflow = rewriter.create<arith::CmpFOp>(
523             loc, arith::CmpFPredicate::UEQ, rounded, negInf);
524         auto intMin = rewriter.create<arith::ConstantOp>(
525             loc, rewriter.getIntegerAttr(
526                      getElementTypeOrSelf(dstTy),
527                      APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
528         auto intMax = rewriter.create<arith::ConstantOp>(
529             loc, rewriter.getIntegerAttr(
530                      getElementTypeOrSelf(dstTy),
531                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
532         auto maxClamped =
533             rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
534         return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
535                                                 maxClamped);
536       }
537 
538       auto intMinFP = rewriter.create<arith::ConstantOp>(
539           loc, rewriter.getFloatAttr(
540                    getElementTypeOrSelf(srcTy),
541                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
542                        .getSExtValue()));
543 
544       // Check whether the mantissa has enough bits to represent int max.
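      // (For example, f32, with a 24-bit significand, covers the full i16 range
      // exactly but not the full i32 range.)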
545       if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
546           dstTy.getIntOrFloatBitWidth() - 1) {
547         // Int min can also be represented since it is a power of two and thus
548         // consists of a single leading bit. Therefore we can clamp the input
549         // in the floating-point domain.
550 
551         auto intMaxFP = rewriter.create<arith::ConstantOp>(
552             loc, rewriter.getFloatAttr(
553                      getElementTypeOrSelf(srcTy),
554                      APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
555                          .getSExtValue()));
556 
557         Value clamped =
558             clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
559         return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
560       }
561 
562       // Due to the earlier check we know the exponent range is big enough to represent
563       // int min. We can therefore rely on int max + 1 being representable as
564       // well because it's just int min with a positive sign. So clamp the min
565       // value and compare against that to select the max int value if needed.
566       auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
567           loc, rewriter.getFloatAttr(
568                    getElementTypeOrSelf(srcTy),
569                    static_cast<double>(
570                        APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
571                            .getSExtValue()) +
572                        1.0f));
573 
574       auto intMax = rewriter.create<arith::ConstantOp>(
575           loc, rewriter.getIntegerAttr(
576                    getElementTypeOrSelf(dstTy),
577                    APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
578       auto minClampedFP =
579           rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
580       auto minClamped =
581           rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
582       auto overflow = rewriter.create<arith::CmpFOp>(
583           loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
584       return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
585                                               minClamped);
586     }
587 
588     // When casting to boolean, integers only need to be compared as not equal
589     // to zero.
590     if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
591       Value zero = rewriter.create<arith::ConstantIntOp>(
592           loc, 0, srcTy.getIntOrFloatBitWidth());
593       return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
594                                             args.front(), zero);
595     }
596 
597     if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
598       return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
599                                              std::nullopt);
600 
601     if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
602       return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
603     }
604   }
605 
606   (void)rewriter.notifyMatchFailure(
607       op, "unhandled op for linalg body calculation for elementwise op");
608   return nullptr;
609 }
610 
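// Reshape 'tensor' to the requested rank by prepending unit dimensions with a
// 'tensor.expand_shape' op. For example, expanding a tensor<3x4xf32> to rank 4
// yields a tensor<1x1x3x4xf32> with reassociation [[0, 1, 2], [3]].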
611 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
612                         int64_t rank) {
613   // No need to expand if we are already at the desired rank
614   auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
615   assert(tensorType && "expected a ranked tensor type");
616   int64_t tensorRank = tensorType.getRank();
617   int64_t numExtraDims = rank - tensorRank;
618   assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
619   if (!numExtraDims)
620     return tensor;
621 
622   // Compute reassociation indices
623   SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
624   int64_t index = 0;
625   if (tensorRank != 0) {
626     for (index = 0; index <= numExtraDims; index++)
627       reassociationIndices[0].push_back(index);
628     for (size_t position = 1; position < reassociationIndices.size();
629          position++)
630       reassociationIndices[position].push_back(index++);
631   }
632 
633   // Compute result type
634   SmallVector<int64_t> resultShape;
635   for (index = 0; index < numExtraDims; index++)
636     resultShape.push_back(1);
637   for (auto size : tensorType.getShape())
638     resultShape.push_back(size);
639   auto resultType =
640       RankedTensorType::get(resultShape, tensorType.getElementType());
641 
642   // Emit 'tensor.expand_shape' op
643   return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
644                                                 reassociationIndices);
645 }
646 
647 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
648                                            Location loc, ValueRange operands,
649                                            int64_t rank) {
650   return llvm::map_to_vector(operands, [&](Value operand) {
651     return expandRank(rewriter, loc, operand, rank);
652   });
653 }
654 
655 using IndexPool = DenseMap<int64_t, Value>;
656 
657 // Emit an 'arith.constant' op for the given index if it has not been created
658 // yet, or return an existing constant. This prevents excessive creation of
659 // redundant constants and eases readability of the emitted code in unit tests.
660 static Value createIndex(PatternRewriter &rewriter, Location loc,
661                          IndexPool &indexPool, int64_t index) {
662   auto [it, inserted] = indexPool.try_emplace(index);
663   if (inserted)
664     it->second =
665         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
666   return it->second;
667 }
668 
669 static Value getTensorDim(PatternRewriter &rewriter, Location loc,
670                           IndexPool &indexPool, Value tensor, int64_t index) {
671   auto indexValue = createIndex(rewriter, loc, indexPool, index);
672   return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
673 }
674 
675 static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
676                                        IndexPool &indexPool, Value tensor,
677                                        int64_t index) {
678   auto shapedType = dyn_cast<ShapedType>(tensor.getType());
679   assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
680   assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
681   if (shapedType.isDynamicDim(index))
682     return getTensorDim(rewriter, loc, indexPool, tensor, index);
683   return rewriter.getIndexAttr(shapedType.getDimSize(index));
684 }
685 
686 static bool operandsAndResultsRanked(Operation *operation) {
687   auto isRanked = [](Value value) {
688     return isa<RankedTensorType>(value.getType());
689   };
690   return llvm::all_of(operation->getOperands(), isRanked) &&
691          llvm::all_of(operation->getResults(), isRanked);
692 }
693 
694 // Compute the runtime dimension size for dimension 'dim' of the output by
695 // inspecting input 'operands', all of which are expected to have the same rank.
696 // This function returns a pair {targetSize, masterOperand}.
697 //
698 // The runtime size of the output dimension is returned either as a statically
699 // computed attribute or as a runtime SSA value.
700 //
701 // If the target size was inferred directly from one dominating operand, that
702 // operand is returned in 'masterOperand'. If the target size is inferred from
703 // multiple operands, 'masterOperand' is set to nullptr.
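// For example, with operands of types tensor<?x1xf32> and tensor<?x5xf32>,
// dimension 1 resolves statically to 5 with the second operand as master
// operand, while dimension 0 is computed at runtime as the arith.maxui of the
// two tensor.dim values and has no master operand.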
704 static std::pair<OpFoldResult, Value>
705 computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
706                   ValueRange operands, int64_t dim) {
707   // If any input operand contains a static size greater than 1 for this
708   // dimension, that is the target size. An occurrence of an additional static
709   // dimension greater than 1 with a different value is undefined behavior.
710   for (auto operand : operands) {
711     auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
712     if (!ShapedType::isDynamic(size) && size > 1)
713       return {rewriter.getIndexAttr(size), operand};
714   }
715 
716   // Filter operands with dynamic dimension
717   auto operandsWithDynamicDim =
718       llvm::filter_to_vector(operands, [&](Value operand) {
719         return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
720       });
721 
722   // If no operand has a dynamic dimension, it means all sizes were 1
723   if (operandsWithDynamicDim.empty())
724     return {rewriter.getIndexAttr(1), operands.front()};
725 
726   // Emit code that computes the runtime size for this dimension. If there is
727   // only one operand with a dynamic dimension, it is considered the master
728   // operand that determines the runtime size of the output dimension.
729   auto targetSize =
730       getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
731   if (operandsWithDynamicDim.size() == 1)
732     return {targetSize, operandsWithDynamicDim[0]};
733 
734   // Calculate maximum size among all dynamic dimensions
735   for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
736     auto nextSize =
737         getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
738     targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
739   }
740   return {targetSize, nullptr};
741 }
742 
743 // Compute the runtime output size for all dimensions. This function returns
744 // a pair {targetShape, masterOperands}.
745 static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
746 computeTargetShape(PatternRewriter &rewriter, Location loc,
747                    IndexPool &indexPool, ValueRange operands) {
748   assert(!operands.empty());
749   auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
750   SmallVector<OpFoldResult> targetShape;
751   SmallVector<Value> masterOperands;
752   for (auto dim : llvm::seq<int64_t>(0, rank)) {
753     auto [targetSize, masterOperand] =
754         computeTargetSize(rewriter, loc, indexPool, operands, dim);
755     targetShape.push_back(targetSize);
756     masterOperands.push_back(masterOperand);
757   }
758   return {targetShape, masterOperands};
759 }
760 
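// Broadcast dimension 'dim' of 'operand' to 'targetSize' when needed. The
// broadcast is guarded by an 'scf.if' that checks at runtime whether the
// dimension equals 1; the 'then' region materializes the broadcast with a
// 'linalg.generic', while the 'else' region yields the operand unchanged.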
761 static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
762                                        IndexPool &indexPool, Value operand,
763                                        int64_t dim, OpFoldResult targetSize,
764                                        Value masterOperand) {
765   // Nothing to do if this is a static dimension
766   auto rankedTensorType = cast<RankedTensorType>(operand.getType());
767   if (!rankedTensorType.isDynamicDim(dim))
768     return operand;
769 
770   // If the target size for this dimension was directly inferred by only taking
771   // this operand into account, there is no need to broadcast. This is an
772   // optimization that will prevent redundant control flow, and constitutes the
773   // main motivation for tracking "master operands".
774   if (operand == masterOperand)
775     return operand;
776 
777   // Affine maps for 'linalg.generic' op
778   auto rank = rankedTensorType.getRank();
779   SmallVector<AffineExpr> affineExprs;
780   for (auto index : llvm::seq<int64_t>(0, rank)) {
781     auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
782                                    : rewriter.getAffineDimExpr(index);
783     affineExprs.push_back(affineExpr);
784   }
785   auto broadcastAffineMap =
786       AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
787   auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
788   SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
789 
790   // Check if broadcast is necessary
791   auto one = createIndex(rewriter, loc, indexPool, 1);
792   auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
793   auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
794       loc, arith::CmpIPredicate::eq, runtimeSize, one);
795 
796   // Emit 'then' region of 'scf.if'
797   auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
798     // It is not safe to cache constants across regions.
799     // New constants could potentially violate dominance requirements.
800     IndexPool localPool;
801 
802     // Emit 'tensor.empty' op
803     SmallVector<OpFoldResult> outputTensorShape;
804     for (auto index : llvm::seq<int64_t>(0, rank)) {
805       auto size = index == dim ? targetSize
806                                : getOrFoldTensorDim(rewriter, loc, localPool,
807                                                     operand, index);
808       outputTensorShape.push_back(size);
809     }
810     Value outputTensor = opBuilder.create<tensor::EmptyOp>(
811         loc, outputTensorShape, rankedTensorType.getElementType());
812 
813     // Emit 'linalg.generic' op
814     auto resultTensor =
815         opBuilder
816             .create<linalg::GenericOp>(
817                 loc, outputTensor.getType(), operand, outputTensor, affineMaps,
818                 getNParallelLoopsAttrs(rank),
819                 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
820                   // Emit 'linalg.yield' op
821                   opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
822                 })
823             .getResult(0);
824 
825     // Cast to original operand type if necessary
826     auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
827         loc, operand.getType(), resultTensor);
828 
829     // Emit 'scf.yield' op
830     opBuilder.create<scf::YieldOp>(loc, castResultTensor);
831   };
832 
833   // Emit 'else' region of 'scf.if'
834   auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
835     opBuilder.create<scf::YieldOp>(loc, operand);
836   };
837 
838   // Emit 'scf.if' op
839   auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
840                                          emitThenRegion, emitElseRegion);
841   return ifOp.getResult(0);
842 }
843 
844 static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
845                                         IndexPool &indexPool, Value operand,
846                                         ArrayRef<OpFoldResult> targetShape,
847                                         ArrayRef<Value> masterOperands) {
848   int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
849   assert((int64_t)targetShape.size() == rank);
850   assert((int64_t)masterOperands.size() == rank);
851   for (auto index : llvm::seq<int64_t>(0, rank))
852     operand =
853         broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
854                                   targetShape[index], masterOperands[index]);
855   return operand;
856 }
857 
858 static SmallVector<Value>
859 broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
860                            IndexPool &indexPool, ValueRange operands,
861                            ArrayRef<OpFoldResult> targetShape,
862                            ArrayRef<Value> masterOperands) {
863   // No need to broadcast for unary operations
864   if (operands.size() == 1)
865     return operands;
866 
867   // Broadcast dynamic dimensions operand by operand
868   return llvm::map_to_vector(operands, [&](Value operand) {
869     return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
870                                       targetShape, masterOperands);
871   });
872 }
873 
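// Emit the 'linalg.generic' implementing the elementwise op on operands that
// have already been rank-expanded and broadcast. Static unit dimensions that
// still require broadcasting are handled with a constant-0 affine expression
// in the corresponding input indexing map.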
874 static LogicalResult
875 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
876                            Operation *operation, ValueRange operands,
877                            ArrayRef<OpFoldResult> targetShape,
878                            const TypeConverter &converter) {
879   // Generate output tensor
880   auto resultType = cast_or_null<RankedTensorType>(
881       converter.convertType(operation->getResultTypes().front()));
882   if (!resultType) {
883     return rewriter.notifyMatchFailure(operation, "failed to convert type");
884   }
885   Value outputTensor = rewriter.create<tensor::EmptyOp>(
886       loc, targetShape, resultType.getElementType());
887 
888   // Create affine maps. Input affine maps broadcast static dimensions of size
889   // 1. The output affine map is an identity map.
890   //
891   auto rank = resultType.getRank();
892   auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
893     auto shape = cast<ShapedType>(operand.getType()).getShape();
894     SmallVector<AffineExpr> affineExprs;
895     for (auto it : llvm::enumerate(shape)) {
896       // Prefer producing identity maps whenever possible (i.e. no broadcasting
897       // needed) because some transforms (like reshape folding)
898       // do not support affine constant exprs.
899       bool requiresBroadcast =
900           (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
901       auto affineExpr = requiresBroadcast
902                             ? rewriter.getAffineConstantExpr(0)
903                             : rewriter.getAffineDimExpr(it.index());
904       affineExprs.push_back(affineExpr);
905     }
906     return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
907   });
908   affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
909 
910   // Emit 'linalg.generic' op
911   bool encounteredError = false;
912   auto linalgOp = rewriter.create<linalg::GenericOp>(
913       loc, outputTensor.getType(), operands, outputTensor, affineMaps,
914       getNParallelLoopsAttrs(rank),
915       [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
916         Value opResult = createLinalgBodyCalculationForElementwiseOp(
917             operation, blockArgs.take_front(operation->getNumOperands()),
918             {resultType.getElementType()}, rewriter);
919         if (!opResult) {
920           encounteredError = true;
921           return;
922         }
923         opBuilder.create<linalg::YieldOp>(loc, opResult);
924       });
925   if (encounteredError)
926     return rewriter.notifyMatchFailure(
927         operation, "unable to create linalg.generic body for elementwise op");
928 
929   // Cast 'linalg.generic' result into original result type if needed
930   auto castResult = rewriter.createOrFold<tensor::CastOp>(
931       loc, resultType, linalgOp->getResult(0));
932   rewriter.replaceOp(operation, castResult);
933   return success();
934 }
935 
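// Common lowering path for all elementwise ops: expand operand ranks to the
// result rank, compute the broadcast target shape, broadcast any dynamic
// dimensions, and emit the elementwise computation.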
936 static LogicalResult
937 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
938                                  ConversionPatternRewriter &rewriter,
939                                  const TypeConverter &converter) {
940 
941   // Collect op properties
942   assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
943   assert(operation->getNumOperands() >= 1 &&
944          "elementwise op expects at least 1 operand");
945   if (!operandsAndResultsRanked(operation))
946     return rewriter.notifyMatchFailure(operation,
947                                        "Unranked tensors not supported");
948 
949   // Lower operation
950   IndexPool indexPool;
951   auto loc = operation->getLoc();
952   auto rank =
953       cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
954   // For the mul op we need to avoid expanding the rank of the optional shift
955   // input.
956   auto operandsToExpand =
957       isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
958 
959   auto expandedOperands =
960       expandInputRanks(rewriter, loc, operandsToExpand, rank);
961   auto [targetShape, masterOperands] =
962       computeTargetShape(rewriter, loc, indexPool, expandedOperands);
963   auto broadcastOperands = broadcastDynamicDimensions(
964       rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
965   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
966                                     targetShape, converter);
967 }
968 
969 // Returns the constant initial value for a given reduction operation. The
970 // attribute type varies depending on the element type required.
971 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
972                                                PatternRewriter &rewriter) {
973   if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
974     return rewriter.getFloatAttr(elementTy, 0.0);
975 
976   if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
977     return rewriter.getIntegerAttr(elementTy, 0);
978 
979   if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
980     return rewriter.getFloatAttr(elementTy, 1.0);
981 
982   if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
983     return rewriter.getIntegerAttr(elementTy, 1);
984 
985   if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
986     return rewriter.getFloatAttr(
987         elementTy, APFloat::getLargest(
988                        cast<FloatType>(elementTy).getFloatSemantics(), false));
989 
990   if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
991     return rewriter.getIntegerAttr(
992         elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
993 
994   if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
995     return rewriter.getFloatAttr(
996         elementTy, APFloat::getLargest(
997                        cast<FloatType>(elementTy).getFloatSemantics(), true));
998 
999   if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1000     return rewriter.getIntegerAttr(
1001         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1002 
1003   if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1004     return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1005 
1006   if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1007     return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1008 
1009   if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1010     return rewriter.getFloatAttr(
1011         elementTy, APFloat::getLargest(
1012                        cast<FloatType>(elementTy).getFloatSemantics(), true));
1013 
1014   if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1015     return rewriter.getIntegerAttr(
1016         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1017 
1018   return {};
1019 }
1020 
1021 // Creates the body calculation for a reduction. The operations vary depending
1022 // on the input type.
1023 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1024                                                     ValueRange args,
1025                                                     Type elementTy,
1026                                                     PatternRewriter &rewriter) {
1027   Location loc = op->getLoc();
1028   if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1029     return rewriter.create<arith::AddFOp>(loc, args);
1030   }
1031 
1032   if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1033     return rewriter.create<arith::AddIOp>(loc, args);
1034   }
1035 
1036   if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
1037     return rewriter.create<arith::MulFOp>(loc, args);
1038   }
1039 
1040   if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
1041     return rewriter.create<arith::MulIOp>(loc, args);
1042   }
1043 
1044   if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1045     return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1046   }
1047 
1048   if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1049     return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1050   }
1051 
1052   if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1053     return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1054   }
1055 
1056   if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1057     return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1058   }
1059 
1060   if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1061     return rewriter.create<arith::AndIOp>(loc, args);
1062 
1063   if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1064     return rewriter.create<arith::OrIOp>(loc, args);
1065 
1066   return {};
1067 }
1068 
1069 // Performs the match and rewrite for reduction operations. This includes
1070 // declaring a correctly sized initial value and creating the linalg.reduce
1071 // operation that reduces across the specified axis.
1072 static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1073                                                  PatternRewriter &rewriter) {
1074   auto loc = op->getLoc();
1075   auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
1076   auto resultTy = cast<ShapedType>(op->getResult(0).getType());
1077   auto elementTy = resultTy.getElementType();
1078   Value input = op->getOperand(0);
1079 
1080   SmallVector<int64_t> reduceShape;
1081   SmallVector<Value> dynDims;
1082   for (unsigned i = 0; i < inputTy.getRank(); i++) {
1083     if (axis != i) {
1084       reduceShape.push_back(inputTy.getDimSize(i));
1085       if (inputTy.isDynamicDim(i))
1086         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1087     }
1088   }
1089 
1090   // First fill the output buffer with the init value.
1091   auto emptyTensor =
1092       rewriter
1093           .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1094                                    dynDims)
1095           .getResult();
1096 
1097   auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1098   if (!fillValueAttr)
1099     return rewriter.notifyMatchFailure(
1100         op, "No initial value found for reduction operation");
1101 
1102   auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
1103   auto filledTensor = rewriter
1104                           .create<linalg::FillOp>(loc, ValueRange{fillValue},
1105                                                   ValueRange{emptyTensor})
1106                           .result();
1107 
1108   bool didEncounterError = false;
1109   auto linalgOp = rewriter.create<linalg::ReduceOp>(
1110       loc, input, filledTensor, axis,
1111       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1112         auto result = createLinalgBodyCalculationForReduceOp(
1113             op, blockArgs, elementTy, rewriter);
1114         if (result)
1115           didEncounterError = true;
1116 
1117         nestedBuilder.create<linalg::YieldOp>(loc, result);
1118       });
1119 
1120   if (!didEncounterError)
1121     return rewriter.notifyMatchFailure(
1122         op, "unable to create linalg.generic body for reduce op");
1123 
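  // The reduced dimension is reinserted as a unit dimension via
  // tensor.expand_shape, since the TOSA result type keeps the reduced axis with
  // size 1.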
1124   SmallVector<ReassociationExprs, 4> reassociationMap;
1125   uint64_t expandInputRank =
1126       cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
1127   reassociationMap.resize(expandInputRank);
1128 
1129   for (uint64_t i = 0; i < expandInputRank; i++) {
1130     int32_t dimToPush = i > axis ? i + 1 : i;
1131     reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1132   }
1133 
1134   if (expandInputRank != 0) {
1135     int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1136     reassociationMap[expandedDim].push_back(
1137         rewriter.getAffineDimExpr(expandedDim + 1));
1138   }
1139 
1140   // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1141   // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1142   // not have access to such information. This matters when handling dynamically
1143   // sized tensors.
1144   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1145       op, resultTy, linalgOp.getResults()[0], reassociationMap);
1146   return success();
1147 }
1148 
1149 namespace {
1150 
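// Converts a TOSA elementwise op to a 'linalg.generic' via
// elementwiseMatchAndRewriteHelper.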
1151 template <typename SrcOp>
1152 class PointwiseConverter : public OpConversionPattern<SrcOp> {
1153 public:
1154   using OpConversionPattern<SrcOp>::OpConversionPattern;
1155   using typename OpConversionPattern<SrcOp>::OpAdaptor;
1156 
1157   LogicalResult
1158   matchAndRewrite(SrcOp op, OpAdaptor operands,
1159                   ConversionPatternRewriter &rewriter) const final {
1160     return elementwiseMatchAndRewriteHelper(
1161         op, operands.getOperands(), rewriter, *this->getTypeConverter());
1162   }
1163 };
1164 
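// Lowers tosa.rescale to a 'linalg.generic' that rescales each element with
// tosa.apply_scale, supporting both per-tensor and per-channel multiplier and
// shift values.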
1165 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1166 public:
1167   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1168 
1169   LogicalResult matchAndRewrite(tosa::RescaleOp op,
1170                                 PatternRewriter &rewriter) const final {
1171     auto loc = op.getLoc();
1172     auto input = op.getInput();
1173     auto inputTy = cast<ShapedType>(op.getInput().getType());
1174     auto outputTy = cast<ShapedType>(op.getOutput().getType());
1175     unsigned rank = inputTy.getRank();
1176 
1177     // This is an illegal configuration: terminate and log an error.
1178     if (op.getDoubleRound() && !op.getScale32())
1179       return rewriter.notifyMatchFailure(
1180           op, "tosa.rescale requires scale32 for double_round to be true");
1181 
1182     if (!isa<IntegerType>(inputTy.getElementType()))
1183       return rewriter.notifyMatchFailure(op, "only support integer type");
1184 
1185     SmallVector<Value> dynDims;
1186     for (int i = 0; i < outputTy.getRank(); i++) {
1187       if (outputTy.isDynamicDim(i)) {
1188         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1189       }
1190     }
1191 
1192     // The shift and multiplier values.
1193     SmallVector<int32_t> multiplierValues(op.getMultiplier());
1194     SmallVector<int8_t> shiftValues(op.getShift());
1195 
1196     // Shifting by more than the bitwidth (> 63) always produces 0, so zero out the entry.
1197     for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1198       if (shiftValues[i] > 63) {
1199         shiftValues[i] = 0;
1200         multiplierValues[i] = 0;
1201       }
1202     }
1203 
1204     // Double rounding only takes effect if the shift is greater than 31, so
1205     // check whether that is ever the case.
1206     bool doubleRound =
1207         op.getDoubleRound() &&
1208         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1209 
1210     SmallVector<AffineMap> indexingMaps = {
1211         rewriter.getMultiDimIdentityMap(rank)};
1212     SmallVector<Value, 4> genericInputs = {input};
1213 
1214     // If we are rescaling per-channel then we need to store the multiplier
1215     // values in a buffer.
1216     Value multiplierConstant;
1217     int64_t multiplierArg = 0;
1218     if (multiplierValues.size() == 1) {
1219       multiplierConstant = rewriter.create<arith::ConstantOp>(
1220           loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1221     } else {
1222       SmallVector<AffineExpr, 2> multiplierExprs{
1223           rewriter.getAffineDimExpr(rank - 1)};
1224       auto multiplierType =
1225           RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1226                                 rewriter.getI32Type());
1227       genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1228           loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1229 
1230       indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1231                                             /*symbolCount=*/0, multiplierExprs,
1232                                             rewriter.getContext()));
1233 
1234       multiplierArg = indexingMaps.size() - 1;
1235     }
1236 
1237     // If we are rescaling per-channel then we need to store the shift
1238     // values in a buffer.
1239     Value shiftConstant;
1240     int64_t shiftArg = 0;
1241     if (shiftValues.size() == 1) {
1242       shiftConstant = rewriter.create<arith::ConstantOp>(
1243           loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1244     } else {
1245       SmallVector<AffineExpr, 2> shiftExprs = {
1246           rewriter.getAffineDimExpr(rank - 1)};
1247       auto shiftType =
1248           RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1249                                 rewriter.getIntegerType(8));
1250       genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1251           loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1252       indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1253                                             /*symbolCount=*/0, shiftExprs,
1254                                             rewriter.getContext()));
1255       shiftArg = indexingMaps.size() - 1;
1256     }
1257 
1258     // Indexing maps for output values.
1259     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1260 
1261     // Construct the empty init tensor for the linalg.generic op.
1262     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1263         loc, outputTy.getShape(), outputTy.getElementType(),
1264         ArrayRef<Value>({dynDims}));
1265 
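    // Per-element sketch of the generic body below (illustrative only; the
    // exact widths depend on the input/output element types):
    //   v = (int32_t)in - input_zp;
    //   v = apply_scale(v, multiplier, shift);  // optionally double-rounded
    //   v = clamp(v + output_zp, output_min, output_max);
    //   out = (out_t)v;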
1266     auto linalgOp = rewriter.create<linalg::GenericOp>(
1267         loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
1268         getNParallelLoopsAttrs(rank),
1269         [&](OpBuilder &nestedBuilder, Location nestedLoc,
1270             ValueRange blockArgs) {
1271           Value value = blockArgs[0];
1272           Type valueTy = value.getType();
1273 
1274           // For now the math is done in 32 bits, or 48 bits for inputs wider
1275           // than 32 bits. This is not optimal but should be correct for now;
1276           // consider computing the exact required bit depth later.
1277           int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1278 
1279           auto inputZp = createConstFromIntAttribute<int32_t>(
1280               op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1281               nestedBuilder);
1282           auto outputZp = createConstFromIntAttribute<int32_t>(
1283               op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1284 
1285           Value multiplier = multiplierConstant ? multiplierConstant
1286                                                 : blockArgs[multiplierArg];
1287           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1288 
1289           if (valueTy.getIntOrFloatBitWidth() < 32) {
1290             if (op.getInputUnsigned()) {
1291               value = nestedBuilder.create<arith::ExtUIOp>(
1292                   nestedLoc, nestedBuilder.getI32Type(), value);
1293             } else {
1294               value = nestedBuilder.create<arith::ExtSIOp>(
1295                   nestedLoc, nestedBuilder.getI32Type(), value);
1296             }
1297           }
1298 
1299           value =
1300               nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1301 
1302           value = nestedBuilder.create<tosa::ApplyScaleOp>(
1303               loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1304               nestedBuilder.getBoolAttr(doubleRound));
1305 
1306           // Move to the new zero-point.
1307           value =
1308               nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1309 
1310           // Saturate to the output size.
1311           IntegerType outIntType =
1312               cast<IntegerType>(blockArgs.back().getType());
1313           unsigned outBitWidth = outIntType.getWidth();
1314 
1315           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1316           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1317 
1318           // Unsigned integers have a different output range.
1319           if (op.getOutputUnsigned()) {
1320             intMin = 0;
1321             intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1322           }
1323 
1324           auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1325               loc, nestedBuilder.getI32IntegerAttr(intMin));
1326           auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1327               loc, nestedBuilder.getI32IntegerAttr(intMax));
1328 
1329           value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1330                                  nestedBuilder, /*isUnsigned=*/false);
1331 
1332           if (outIntType.getWidth() < 32) {
1333             value = nestedBuilder.create<arith::TruncIOp>(
1334                 nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1335                 value);
1336           }
1337 
1338           nestedBuilder.create<linalg::YieldOp>(loc, value);
1339         });
1340 
1341     rewriter.replaceOp(op, linalgOp->getResults());
1342     return success();
1343   }
1344 };
1345 
1346 // Handle the resize case where the input is a 1x1 image. This case can
1347 // entirely avoid the extract operations, which are much more difficult to
1348 // optimize away.
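//
// Per-element sketch of the rewrite (illustrative; the multiplications only
// apply in the quantized BILINEAR case, using the scale numerators):
//   out[n, 0, 0, c] = (out_t)in[n, 0, 0, c] * scale_y_n * scale_x_n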
1349 class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1350 public:
1351   using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1352 
1353   LogicalResult matchAndRewrite(tosa::ResizeOp op,
1354                                 PatternRewriter &rewriter) const final {
1355     Location loc = op.getLoc();
1356     ImplicitLocOpBuilder builder(loc, rewriter);
1357     auto input = op.getInput();
1358     auto inputTy = cast<RankedTensorType>(input.getType());
1359     auto resultTy = cast<RankedTensorType>(op.getType());
1360     const bool isBilinear = op.getMode() == "BILINEAR";
1361 
1362     auto inputH = inputTy.getDimSize(1);
1363     auto inputW = inputTy.getDimSize(2);
1364     auto outputH = resultTy.getDimSize(1);
1365     auto outputW = resultTy.getDimSize(2);
1366 
1367     if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1368       return rewriter.notifyMatchFailure(
1369           op, "tosa.resize is not a pure 1x1->1x1 image operation");
1370 
1371     // TODO(suderman): These strings should be declared in the TOSA dialect.
1372     if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1373       return rewriter.notifyMatchFailure(
1374           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1375 
1376     if (inputTy == resultTy) {
1377       rewriter.replaceOp(op, input);
1378       return success();
1379     }
1380 
1381     ArrayRef<int64_t> scale = op.getScale();
1382 
1383     // Collapse the unit width and height away.
1384     SmallVector<ReassociationExprs, 4> reassociationMap(2);
1385     reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1386     reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1387     reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1388     reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1389 
1390     auto collapseTy =
1391         RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1392                               inputTy.getElementType());
1393     Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1394                                                              reassociationMap);
1395 
1396     // Get any dynamic shapes that appear in the input type.
1397     llvm::SmallVector<Value> outputDynSize;
1398     if (inputTy.isDynamicDim(0))
1399       outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1400     if (inputTy.isDynamicDim(3))
1401       outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1402 
1403     // Generate the elementwise op for casting and scaling the input value.
1404     auto genericTy = collapseTy.clone(resultTy.getElementType());
1405     Value empty = builder.create<tensor::EmptyOp>(
1406         genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1407     auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1408     SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1409                                                utils::IteratorType::parallel);
1410 
1411     auto generic = builder.create<linalg::GenericOp>(
1412         genericTy, ValueRange{collapse}, ValueRange{empty},
1413         ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1414         [=](OpBuilder &b, Location loc, ValueRange args) {
1415           Value value = args[0];
1416           // This is the quantized case.
1417           if (inputTy.getElementType() != resultTy.getElementType()) {
1418             value =
1419                 b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1420 
1421             if (isBilinear && scale[0] != 0) {
1422               Value scaleY = b.create<arith::ConstantOp>(
1423                   loc, b.getI32IntegerAttr(scale[0]));
1424               value = b.create<arith::MulIOp>(loc, value, scaleY);
1425             }
1426 
1427             if (isBilinear && scale[2] != 0) {
1428               Value scaleX = b.create<arith::ConstantOp>(
1429                   loc, b.getI32IntegerAttr(scale[2]));
1430               value = b.create<arith::MulIOp>(loc, value, scaleX);
1431             }
1432           }
1433 
1434           b.create<linalg::YieldOp>(loc, value);
1435         });
1436 
1437     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1438         op, resultTy, generic.getResults()[0], reassociationMap);
1439     return success();
1440   }
1441 };
1442 
1443 // A TOSA resize with a width or height of 1 may be broadcast to a wider
1444 // dimension. This is done by materializing a new tosa.resize without the
1445 // broadcasting behavior, followed by an explicit broadcast.
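//
// For example (illustrative shapes): resizing tensor<1x1x8x3xf32> to
// tensor<1x4x16x3xf32> becomes a tosa.resize to tensor<1x1x16x3xf32>, a
// collapse of the unit height dim, and a broadcasting linalg.generic that
// yields the final tensor<1x4x16x3xf32>.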
1446 class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1447 public:
1448   using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1449 
1450   LogicalResult matchAndRewrite(tosa::ResizeOp op,
1451                                 PatternRewriter &rewriter) const final {
1452     Location loc = op.getLoc();
1453     ImplicitLocOpBuilder builder(loc, rewriter);
1454     auto input = op.getInput();
1455     auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1456     auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1457 
1458     if (!inputTy || !resultTy)
1459       return rewriter.notifyMatchFailure(op,
1460                                          "requires ranked input/output types");
1461 
1462     auto batch = inputTy.getDimSize(0);
1463     auto channels = inputTy.getDimSize(3);
1464     auto inputH = inputTy.getDimSize(1);
1465     auto inputW = inputTy.getDimSize(2);
1466     auto outputH = resultTy.getDimSize(1);
1467     auto outputW = resultTy.getDimSize(2);
1468 
1469     if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1470       return rewriter.notifyMatchFailure(
1471           op, "tosa.resize has no broadcasting behavior");
1472 
1473     // For any dimension that is broadcastable we generate a size of 1 on
1474     // the output.
1475     llvm::SmallVector<int64_t> resizeShape;
1476     resizeShape.push_back(batch);
1477     resizeShape.push_back(inputH == 1 ? 1 : outputH);
1478     resizeShape.push_back(inputW == 1 ? 1 : outputW);
1479     resizeShape.push_back(channels);
1480 
1481     auto resizeTy = resultTy.clone(resizeShape);
1482     auto resize =
1483         builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1484 
1485     // Collapse any unit result dims.
1486     SmallVector<ReassociationExprs, 4> reassociationMap(2);
1487     reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1488     reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1489     if (inputH != 1)
1490       reassociationMap.push_back({});
1491     reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1492     if (inputW != 1)
1493       reassociationMap.push_back({});
1494     reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1495 
1496     llvm::SmallVector<int64_t> collapseShape = {batch};
1497     if (inputH != 1)
1498       collapseShape.push_back(outputH);
1499     if (inputW != 1)
1500       collapseShape.push_back(outputW);
1501     collapseShape.push_back(channels);
1502 
1503     auto collapseTy = resultTy.clone(collapseShape);
1504     Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1505                                                              reassociationMap);
1506 
1507     // Broadcast the collapsed shape to the output result.
1508     llvm::SmallVector<Value> outputDynSize;
1509     if (inputTy.isDynamicDim(0))
1510       outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1511     if (inputTy.isDynamicDim(3))
1512       outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1513 
1514     SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1515                                                utils::IteratorType::parallel);
1516     Value empty = builder.create<tensor::EmptyOp>(
1517         resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1518 
1519     SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1520     if (inputH != 1)
1521       inputExprs.push_back(rewriter.getAffineDimExpr(1));
1522     if (inputW != 1)
1523       inputExprs.push_back(rewriter.getAffineDimExpr(2));
1524     inputExprs.push_back(rewriter.getAffineDimExpr(3));
1525 
1526     auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1527                                    inputExprs, rewriter.getContext());
1528 
1529     auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1530     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1531         op, resultTy, ValueRange{collapse}, ValueRange{empty},
1532         ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1533         [=](OpBuilder &b, Location loc, ValueRange args) {
1534           Value value = args[0];
1535           b.create<linalg::YieldOp>(loc, value);
1536         });
1537 
1538     return success();
1539   }
1540 };
1541 
1542 class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1543 public:
1544   using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1545 
1546   LogicalResult matchAndRewrite(tosa::ResizeOp op,
1547                                 PatternRewriter &rewriter) const final {
1548     Location loc = op.getLoc();
1549     ImplicitLocOpBuilder b(loc, rewriter);
1550     auto input = op.getInput();
1551     auto inputTy = cast<ShapedType>(input.getType());
1552     auto resultTy = cast<ShapedType>(op.getType());
1553     auto resultETy = resultTy.getElementType();
1554 
1555     bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1556     auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1557 
1558     auto imageH = inputTy.getShape()[1];
1559     auto imageW = inputTy.getShape()[2];
1560 
1561     auto dynamicDimsOr =
1562         checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1563     if (!dynamicDimsOr.has_value())
1564       return rewriter.notifyMatchFailure(
1565           op, "unable to get dynamic dimensions of tosa.resize");
1566 
1567     if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1568       return rewriter.notifyMatchFailure(
1569           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1570 
1571     SmallVector<AffineMap, 2> affineMaps = {
1572         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1573     auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1574                                                  *dynamicDimsOr);
1575     auto genericOp = b.create<linalg::GenericOp>(
1576         resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1577         getNParallelLoopsAttrs(resultTy.getRank()));
1578     Value resize = genericOp.getResult(0);
1579 
1580     {
1581       OpBuilder::InsertionGuard regionGuard(b);
1582       b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1583                     TypeRange({resultETy}), loc);
1584       Value batch = b.create<linalg::IndexOp>(0);
1585       Value y = b.create<linalg::IndexOp>(1);
1586       Value x = b.create<linalg::IndexOp>(2);
1587       Value channel = b.create<linalg::IndexOp>(3);
1588 
1589       Value zeroI32 =
1590           b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1591       Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1592       Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1593       Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1594 
1595       Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1596       Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1597 
1598       ArrayRef<int64_t> offset = op.getOffset();
1599       ArrayRef<int64_t> border = op.getBorder();
1600       ArrayRef<int64_t> scale = op.getScale();
1601 
1602       Value yScaleN, yScaleD, xScaleN, xScaleD;
1603       yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1604       yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1605       xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1606       xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1607 
1608       Value yOffset, xOffset, yBorder, xBorder;
1609       yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1610       xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1611       yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1612       xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1613 
1614       // Compute the ix and dx values for both the X and Y dimensions.
1615       auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1616                                     Value scaleN, Value scaleD, Value offset,
1617                                     int size, ImplicitLocOpBuilder &b) {
1618         if (size == 1) {
1619           index = zeroI32;
1620           delta = zeroFp;
1621           return;
1622         }
1623         // x = x * scale_d + offset;
1624         // ix = floor(x / scale_n)
1625         Value val = b.create<arith::MulIOp>(in, scaleD);
1626         val = b.create<arith::AddIOp>(val, offset);
1627         index = b.create<arith::FloorDivSIOp>(val, scaleN);
1628 
1629         // rx = x % scale_n
1630         // dx = rx / scale_n
1631         Value r = b.create<arith::RemSIOp>(val, scaleN);
1632         Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1633         Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1634         delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1635       };
1636 
1637       // Compute the ix and dx values for the X and Y dimensions - int case.
1638       auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1639                                      Value scaleN, Value scaleD, Value offset,
1640                                      int size, ImplicitLocOpBuilder &b) {
1641         if (size == 1) {
1642           index = zeroI32;
1643           delta = zeroI32;
1644           return;
1645         }
1646         // x = x * scale_d + offset;
1647         // ix = floor(x / scale_n)
1648         //  dx = x - ix * scale_n;
1649         Value val = b.create<arith::MulIOp>(in, scaleD);
1650         val = b.create<arith::AddIOp>(val, offset);
1651         index = b.create<arith::DivSIOp>(val, scaleN);
1652         delta = b.create<arith::MulIOp>(index, scaleN);
1653         delta = b.create<arith::SubIOp>(val, delta);
1654       };
1655 
1656       Value ix, iy, dx, dy;
1657       if (floatingPointMode) {
1658         getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1659         getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1660       } else {
1661         getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1662         getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1663       }
1664 
1665       if (op.getMode() == "NEAREST_NEIGHBOR") {
1666         auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1667 
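        // Scalar sketch of the rounding performed below (illustrative):
        //   if (dval >= 0.5)        // integer mode: 2 * dval >= scale_n
        //     val += 1;
        //   val = clamp(val, 0, max);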
1668         auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1669                                            Value max, int size,
1670                                            ImplicitLocOpBuilder &b) -> Value {
1671           if (size == 1) {
1672             return b.create<arith::ConstantIndexOp>(0);
1673           }
1674 
1675           Value pred;
1676           if (floatingPointMode) {
1677             auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1678             pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1679           } else {
1680             Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1681             pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1682                                            dvalDouble, scale);
1683           }
1684 
1685           auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1686           val = b.create<arith::AddIOp>(val, offset);
1687           val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
1688           return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1689         };
1690 
1691         iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1692         ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1693 
1694         Value result = b.create<tensor::ExtractOp>(
1695             input, ValueRange{batch, iy, ix, channel});
1696 
1697         b.create<linalg::YieldOp>(result);
1698       } else {
1699         // The mode here must be BILINEAR.
1700         assert(op.getMode() == "BILINEAR");
1701 
1702         auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1703 
1704         auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1705                                   Value max, ImplicitLocOpBuilder &b) {
1706           val0 = in;
1707           val1 = b.create<arith::AddIOp>(val0, oneVal);
1708           val0 =
1709               clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
1710           val1 =
1711               clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
1712           val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1713           val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1714         };
1715 
1716         // Linalg equivalent to the section below:
1717         //    int16_t iy0 = apply_max(iy, 0);
1718         //    int16_t iy1 = apply_min(iy + 1, IH - 1);
1719         //    int16_t ix0 = apply_max(ix, 0);
1720         //    int16_t ix1 = apply_min(ix + 1, IW - 1);
1721         Value x0, x1, y0, y1;
1722         getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1723         getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1724 
1725         Value y0x0 = b.create<tensor::ExtractOp>(
1726             input, ValueRange{batch, y0, x0, channel});
1727         Value y0x1 = b.create<tensor::ExtractOp>(
1728             input, ValueRange{batch, y0, x1, channel});
1729         Value y1x0 = b.create<tensor::ExtractOp>(
1730             input, ValueRange{batch, y1, x0, channel});
1731         Value y1x1 = b.create<tensor::ExtractOp>(
1732             input, ValueRange{batch, y1, x1, channel});
1733 
1734         if (floatingPointMode) {
1735           auto oneVal =
1736               b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1737           auto interpolate = [&](Value val0, Value val1, Value delta,
1738                                  int inputSize,
1739                                  ImplicitLocOpBuilder &b) -> Value {
1740             if (inputSize == 1)
1741               return val0;
1742             Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1743             Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1744             Value mul1 = b.create<arith::MulFOp>(val1, delta);
1745             return b.create<arith::AddFOp>(mul0, mul1);
1746           };
1747 
1748           // Linalg equivalent to the section below:
1749           //   topAcc = v00 * (unit_x - dx);
1750           //   topAcc += v01 * dx;
1751           Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1752 
1753           // Linalg equivalent to the section below:
1754           //   bottomAcc = v10 * (unit_x - dx);
1755           //   bottomAcc += v11 * dx;
1756           Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1757 
1758           // Linalg equivalent to the section below:
1759           //   result = topAcc * (unit_y - dy) + bottomAcc * dy
1760           Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1761           b.create<linalg::YieldOp>(result);
1762         } else {
1763           // Perform in quantized space.
1764           y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1765           y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1766           y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1767           y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1768 
1769           const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1770           if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1771             dx = b.create<arith::ExtSIOp>(resultETy, dx);
1772             dy = b.create<arith::ExtSIOp>(resultETy, dy);
1773           }
1774 
1775           Value yScaleNExt = yScaleN;
1776           Value xScaleNExt = xScaleN;
1777 
1778           const int64_t scaleBitwidth =
1779               xScaleN.getType().getIntOrFloatBitWidth();
1780           if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1781             yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1782             xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1783           }
1784 
1785           auto interpolate = [](Value val0, Value val1, Value weight1,
1786                                 Value scale, int inputSize,
1787                                 ImplicitLocOpBuilder &b) -> Value {
1788             if (inputSize == 1)
1789               return b.create<arith::MulIOp>(val0, scale);
1790             Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1791             Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1792             Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1793             return b.create<arith::AddIOp>(mul0, mul1);
1794           };
1795 
1796           Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1797           Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1798           Value result =
1799               interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1800           b.create<linalg::YieldOp>(result);
1801         }
1802       }
1803     }
1804 
1805     rewriter.replaceOp(op, resize);
1806     return success();
1807   }
1808 };
1809 
1810 // At the codegen level any identity operations should be removed. Any cases
1811 // where identity is load-bearing (e.g. cross device computation) should be
1812 // handled before lowering to codegen.
1813 template <typename SrcOp>
1814 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1815 public:
1816   using OpRewritePattern<SrcOp>::OpRewritePattern;
1817 
1818   LogicalResult matchAndRewrite(SrcOp op,
1819                                 PatternRewriter &rewriter) const final {
1820     rewriter.replaceOp(op, op.getOperation()->getOperands());
1821     return success();
1822   }
1823 };
1824 
1825 template <typename SrcOp>
1826 class ReduceConverter : public OpRewritePattern<SrcOp> {
1827 public:
1828   using OpRewritePattern<SrcOp>::OpRewritePattern;
1829 
1830   LogicalResult matchAndRewrite(SrcOp reduceOp,
1831                                 PatternRewriter &rewriter) const final {
1832     return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1833   }
1834 };
1835 
1836 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1837 public:
1838   using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1839 
1840   LogicalResult matchAndRewrite(tosa::ReverseOp op,
1841                                 PatternRewriter &rewriter) const final {
1842     auto loc = op.getLoc();
1843     Value input = op.getInput1();
1844     auto inputTy = cast<ShapedType>(input.getType());
1845     auto resultTy = cast<ShapedType>(op.getType());
1846     auto axis = op.getAxis();
1847 
1848     SmallVector<Value> dynDims;
1849     for (int i = 0; i < inputTy.getRank(); i++) {
1850       if (inputTy.isDynamicDim(i)) {
1851         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1852       }
1853     }
1854 
1855     Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1856 
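    // The generic body below reads the input back-to-front along `axis`
    // (sketch):
    //   out[..., i, ...] = in[..., (axisDimSize - 1) - i, ...]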
1857     // Create the empty output tensor for the reversed result.
1858     auto emptyTensor = rewriter
1859                            .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1860                                                     inputTy.getElementType(),
1861                                                     ArrayRef<Value>({dynDims}))
1862                            .getResult();
1863     SmallVector<AffineMap, 2> affineMaps = {
1864         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1865 
1866     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1867         op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
1868         getNParallelLoopsAttrs(resultTy.getRank()),
1869         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1870           llvm::SmallVector<Value> indices;
1871           for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1872             Value index =
1873                 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1874             if (i == axis) {
1875               auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1876               auto sizeMinusOne =
1877                   rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1878               index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1879                                                      index);
1880             }
1881 
1882             indices.push_back(index);
1883           }
1884 
1885           auto extract = nestedBuilder.create<tensor::ExtractOp>(
1886               nestedLoc, input, indices);
1887           nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1888                                                 extract.getResult());
1889         });
1890     return success();
1891   }
1892 };
1893 
1894 // This converter translates a tile operation into a reshape, broadcast, and
1895 // reshape. The first reshape minimally expands each tiled dimension to
1896 // include a preceding size-1 dim. This dim is then broadcast to the
1897 // appropriate multiple.
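//
// For example (illustrative): tiling a tensor<3x4xf32> with multiples [2, 3]
// broadcasts the input into a tensor<2x3x3x4xf32> via linalg.generic and then
// reshapes that result to the final tensor<6x12xf32>.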
1898 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1899   using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1900 
1901   LogicalResult
1902   matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1903                   ConversionPatternRewriter &rewriter) const override {
1904     auto loc = op.getLoc();
1905     auto input = op.getInput1();
1906     auto inputTy = cast<ShapedType>(input.getType());
1907     auto inputShape = inputTy.getShape();
1908     auto resultTy = cast<ShapedType>(op.getType());
1909     auto elementTy = inputTy.getElementType();
1910     int64_t rank = inputTy.getRank();
1911 
1912     SmallVector<int64_t> multiples;
1913     if (failed(op.getConstantMultiples(multiples)))
1914       return failure();
1915 
1916     // Broadcast the newly added dimensions to their appropriate multiple.
1917     SmallVector<int64_t, 2> genericShape;
1918     for (int i = 0; i < rank; i++) {
1919       int64_t dim = multiples[i];
1920       genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1921       genericShape.push_back(inputShape[i]);
1922     }
1923 
1924     SmallVector<Value> dynDims;
1925     for (int i = 0; i < inputTy.getRank(); i++) {
1926       if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1927         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1928       }
1929     }
1930 
1931     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
1932         op.getLoc(), genericShape, elementTy, dynDims);
1933 
1934     // We need to map the input shape to the non-broadcast dimensions.
1935     SmallVector<AffineExpr, 4> dimExprs;
1936     dimExprs.reserve(rank);
1937     for (unsigned i = 0; i < rank; ++i)
1938       dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1939 
1940     auto readAffineMap =
1941         AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1942                        rewriter.getContext());
1943 
1944     SmallVector<AffineMap, 2> affineMaps = {
1945         readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1946 
1947     auto genericOp = rewriter.create<linalg::GenericOp>(
1948         loc, RankedTensorType::get(genericShape, elementTy), input,
1949         ValueRange{emptyTensor}, affineMaps,
1950         getNParallelLoopsAttrs(genericShape.size()),
1951         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1952           nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1953         });
1954 
1955     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1956         op, resultTy, genericOp.getResult(0),
1957         rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1958     return success();
1959   }
1960 };
1961 
1962 // Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
1963 // producing two output buffers.
1964 //
1965 // The first output buffer contains the index of the found maximum value. It is
1966 // initialized to 0 and has the resulting integer type.
1967 //
1968 // The second output buffer contains the maximum value found. It is initialized
1969 // to the minimum representable value of the input element type. After being
1970 // populated by the generic op, this buffer is discarded as only the index is
1971 // requested.
1972 //
1973 // The generic op updates both the maximum value and index if the current
1974 // value exceeds the running max.
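//
// Per-element sketch of the reduction body (illustrative):
//   if (in[.., k, ..] > runningMax) { runningMax = in[.., k, ..]; idx = k; }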
1975 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1976 public:
1977   using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
1978 
1979   LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
1980                                 PatternRewriter &rewriter) const final {
1981     auto loc = argmaxOp.getLoc();
1982     Value input = argmaxOp.getInput();
1983     auto inputTy = cast<ShapedType>(input.getType());
1984     auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1985     auto inElementTy = inputTy.getElementType();
1986     auto outElementTy = resultTy.getElementType();
1987     int axis = argmaxOp.getAxis();
1988     auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1989 
1990     if (!isa<IntegerType>(outElementTy))
1991       return rewriter.notifyMatchFailure(
1992           argmaxOp,
1993           "tosa.arg_max to linalg.* requires integer-like result type");
1994 
1995     SmallVector<Value> dynDims;
1996     for (int i = 0; i < inputTy.getRank(); i++) {
1997       if (inputTy.isDynamicDim(i) && i != axis) {
1998         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1999       }
2000     }
2001 
2002     // First fill the output buffer for the index.
2003     auto emptyTensorIdx = rewriter
2004                               .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2005                                                        outElementTy, dynDims)
2006                               .getResult();
2007     auto fillValueIdx = rewriter.create<arith::ConstantOp>(
2008         loc, rewriter.getIntegerAttr(outElementTy, 0));
2009     auto filledTensorIdx =
2010         rewriter
2011             .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
2012                                     ValueRange{emptyTensorIdx})
2013             .result();
2014 
2015     // Second fill the output buffer for the running max.
2016     auto emptyTensorMax = rewriter
2017                               .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2018                                                        inElementTy, dynDims)
2019                               .getResult();
2020     auto fillValueMaxAttr =
2021         createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2022 
2023     if (!fillValueMaxAttr)
2024       return rewriter.notifyMatchFailure(
2025           argmaxOp, "unsupported tosa.argmax element type");
2026 
2027     auto fillValueMax =
2028         rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2029     auto filledTensorMax =
2030         rewriter
2031             .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
2032                                     ValueRange{emptyTensorMax})
2033             .result();
2034 
2035     // We need to reduce along the arg-max axis, with parallel operations along
2036     // the rest.
2037     SmallVector<utils::IteratorType, 4> iteratorTypes;
2038     iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2039     iteratorTypes[axis] = utils::IteratorType::reduction;
2040 
2041     SmallVector<AffineExpr, 2> srcExprs;
2042     SmallVector<AffineExpr, 2> dstExprs;
2043     for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2044       srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2045       if (axis != i)
2046         dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2047     }
2048 
2049     bool didEncounterError = false;
2050     auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2051                                              rewriter.getContext());
2052     auto linalgOp = rewriter.create<linalg::GenericOp>(
2053         loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2054         ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2055         [&](OpBuilder &nestedBuilder, Location nestedLoc,
2056             ValueRange blockArgs) {
2057           auto newValue = blockArgs[0];
2058           auto oldIndex = blockArgs[1];
2059           auto oldValue = blockArgs[2];
2060 
2061           Value newIndex = rewriter.create<arith::IndexCastOp>(
2062               nestedLoc, oldIndex.getType(),
2063               rewriter.create<linalg::IndexOp>(loc, axis));
2064 
2065           Value predicate;
2066           if (isa<FloatType>(inElementTy)) {
2067             predicate = rewriter.create<arith::CmpFOp>(
2068                 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2069           } else if (isa<IntegerType>(inElementTy)) {
2070             predicate = rewriter.create<arith::CmpIOp>(
2071                 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2072           } else {
2073             didEncounterError = true;
2074             return;
2075           }
2076 
2077           auto resultMax = rewriter.create<arith::SelectOp>(
2078               nestedLoc, predicate, newValue, oldValue);
2079           auto resultIndex = rewriter.create<arith::SelectOp>(
2080               nestedLoc, predicate, newIndex, oldIndex);
2081           nestedBuilder.create<linalg::YieldOp>(
2082               nestedLoc, ValueRange({resultIndex, resultMax}));
2083         });
2084 
2085     if (didEncounterError)
2086       return rewriter.notifyMatchFailure(
2087           argmaxOp, "unsupported tosa.argmax element type");
2088 
2089     rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2090     return success();
2091   }
2092 };
2093 
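// tosa.gather is lowered to a linalg.generic whose body performs, per output
// element (sketch):
//   result[b, i, c] = values[b, indices[b, i], c]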
2094 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2095 public:
2096   using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2097   LogicalResult
2098   matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2099                   ConversionPatternRewriter &rewriter) const final {
2100     auto input = adaptor.getOperands()[0];
2101     auto indices = adaptor.getOperands()[1];
2102 
2103     auto valuesTy =
2104         dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2105     auto resultTy = cast<ShapedType>(op.getType());
2106 
2107     if (!valuesTy)
2108       return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2109 
2110     auto dynamicDims = inferDynamicDimsForGather(
2111         rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2112 
2113     auto resultElementTy = resultTy.getElementType();
2114 
2115     auto loc = op.getLoc();
2116     auto emptyTensor =
2117         rewriter
2118             .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2119                                      dynamicDims)
2120             .getResult();
2121 
2122     SmallVector<AffineMap, 2> affineMaps = {
2123         AffineMap::get(
2124             /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2125             {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2126             rewriter.getContext()),
2127         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2128 
2129     auto genericOp = rewriter.create<linalg::GenericOp>(
2130         loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2131         ValueRange{emptyTensor}, affineMaps,
2132         getNParallelLoopsAttrs(resultTy.getRank()),
2133         [&](OpBuilder &b, Location loc, ValueRange args) {
2134           auto indexValue = args[0];
2135           auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2136           Value index1 = rewriter.create<arith::IndexCastOp>(
2137               loc, rewriter.getIndexType(), indexValue);
2138           auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2139           Value extract = rewriter.create<tensor::ExtractOp>(
2140               loc, input, ValueRange{index0, index1, index2});
2141           rewriter.create<linalg::YieldOp>(loc, extract);
2142         });
2143     rewriter.replaceOp(op, genericOp.getResult(0));
2144     return success();
2145   }
2146 
2147   static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2148                                                             Location loc,
2149                                                             Value values,
2150                                                             Value indices) {
2151     llvm::SmallVector<Value> results;
2152 
2153     auto addDynamicDimension = [&](Value source, int64_t dim) {
2154       auto sz = tensor::getMixedSize(builder, loc, source, dim);
2155       if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2156         results.push_back(dimValue);
2157     };
2158 
2159     addDynamicDimension(values, 0);
2160     addDynamicDimension(indices, 1);
2161     addDynamicDimension(values, 2);
2162     return results;
2163   }
2164 };
2165 
2166 // Lowers the TableOp to a series of gathers and numeric operations. This
2167 // includes interpolation between the high/low values. For the I8 variant, this
2168 // simplifies to a single gather operation.
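//
// For the i8 case the generic body reduces to a single table lookup (sketch):
//   result = table[(int32_t)value + 128]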
2169 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2170 public:
2171   using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2172 
2173   LogicalResult matchAndRewrite(tosa::TableOp op,
2174                                 PatternRewriter &rewriter) const final {
2175     auto loc = op.getLoc();
2176     Value input = op.getInput1();
2177     Value table = op.getTable();
2178     auto inputTy = cast<ShapedType>(input.getType());
2179     auto tableTy = cast<ShapedType>(table.getType());
2180     auto resultTy = cast<ShapedType>(op.getType());
2181 
2182     auto inputElementTy = inputTy.getElementType();
2183     auto tableElementTy = tableTy.getElementType();
2184     auto resultElementTy = resultTy.getElementType();
2185 
2186     SmallVector<Value> dynDims;
2187     for (int i = 0; i < resultTy.getRank(); ++i) {
2188       if (inputTy.isDynamicDim(i)) {
2189         dynDims.push_back(
2190             rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2191       }
2192     }
2193 
2194     auto emptyTensor = rewriter
2195                            .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2196                                                     resultElementTy, dynDims)
2197                            .getResult();
2198 
2199     SmallVector<AffineMap, 2> affineMaps = {
2200         rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2201         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2202 
2203     auto genericOp = rewriter.create<linalg::GenericOp>(
2204         loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2205         getNParallelLoopsAttrs(resultTy.getRank()));
2206     rewriter.replaceOp(op, genericOp.getResult(0));
2207 
2208     {
2209       OpBuilder::InsertionGuard regionGuard(rewriter);
2210       Block *block = rewriter.createBlock(
2211           &genericOp.getRegion(), genericOp.getRegion().end(),
2212           TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2213 
2214       auto inputValue = block->getArgument(0);
2215       rewriter.setInsertionPointToStart(block);
2216       if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2217           resultElementTy.isInteger(8)) {
2218         Value index = rewriter.create<arith::IndexCastOp>(
2219             loc, rewriter.getIndexType(), inputValue);
2220         Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2221         index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2222                                                index, offset);
2223         Value extract =
2224             rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2225         rewriter.create<linalg::YieldOp>(loc, extract);
2226         return success();
2227       }
2228 
2229       if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2230           resultElementTy.isInteger(32)) {
2231         Value extend = rewriter.create<arith::ExtSIOp>(
2232             loc, rewriter.getI32Type(), inputValue);
2233 
2234         auto offset = rewriter.create<arith::ConstantOp>(
2235             loc, rewriter.getI32IntegerAttr(32768));
2236         auto seven = rewriter.create<arith::ConstantOp>(
2237             loc, rewriter.getI32IntegerAttr(7));
2238         auto one = rewriter.create<arith::ConstantOp>(
2239             loc, rewriter.getI32IntegerAttr(1));
2240         auto b1111111 = rewriter.create<arith::ConstantOp>(
2241             loc, rewriter.getI32IntegerAttr(127));
2242 
2243         // Compute the index and fractional part from the input value:
2244         // value = value + 32768
2245         // index = value >> 7;
2246         // fraction = value & 0b1111111;
2247         auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2248         Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2249         Value fraction =
2250             rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2251 
2252         // Extract the base and next values from the table.
2253         // base = (int32_t) table[index];
2254         // next = (int32_t) table[index + 1];
2255         Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2256 
2257         index = rewriter.create<arith::IndexCastOp>(
2258             loc, rewriter.getIndexType(), index);
2259         indexPlusOne = rewriter.create<arith::IndexCastOp>(
2260             loc, rewriter.getIndexType(), indexPlusOne);
2261 
2262         Value base =
2263             rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2264         Value next = rewriter.create<tensor::ExtractOp>(
2265             loc, table, ValueRange{indexPlusOne});
2266 
2267         base =
2268             rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2269         next =
2270             rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2271 
2272         // Use the fractional part to interpolate between the input values:
2273         // result = (base << 7) + (next - base) * fraction
2274         Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2275         Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2276         Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2277         Value result =
2278             rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2279 
2280         rewriter.create<linalg::YieldOp>(loc, result);
2281 
2282         return success();
2283       }
2284     }
2285 
2286     return rewriter.notifyMatchFailure(
2287         op, "unable to create body for tosa.table op");
2288   }
2289 };
2290 
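// tosa.rfft2d is lowered to a linalg.generic reduction that accumulates, per
// output element (sketch; H and W are the input height and width):
//   angle              = 2 * pi * ((iy * oy % H) / H + (ix * ox % W) / W)
//   out_real[n,oy,ox] += in[n,iy,ix] * cos(angle)
//   out_imag[n,oy,ox] -= in[n,iy,ix] * sin(angle)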
2291 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2292   using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2293 
2294   static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2295 
2296   static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2297                                   OpFoldResult ofr) {
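    // For example, an input width of 8 yields an output width of 8 / 2 + 1 = 5.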
2298     auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2299     auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2300 
2301     auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2302     auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2303     auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2304     return getAsOpFoldResult(plusOne);
2305   }
2306 
2307   static RankedTensorType
2308   computeOutputShape(OpBuilder &builder, Location loc, Value input,
2309                      llvm::SmallVectorImpl<Value> &dynamicSizes) {
2310     // Get [N, H, W]
2311     auto dims = tensor::getMixedSizes(builder, loc, input);
2312 
2313     // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2314     // output tensors.
2315     dims[2] = halfPlusOne(builder, loc, dims[2]);
2316 
2317     llvm::SmallVector<int64_t, 3> staticSizes;
2318     dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2319 
2320     auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2321     return RankedTensorType::get(staticSizes, elementType);
2322   }
2323 
2324   static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2325                                 RankedTensorType type,
2326                                 llvm::ArrayRef<Value> dynamicSizes) {
2327     auto emptyTensor =
2328         rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2329     auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2330     auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2331     auto filledTensor = rewriter
2332                             .create<linalg::FillOp>(loc, ValueRange{fillValue},
2333                                                     ValueRange{emptyTensor})
2334                             .result();
2335     return filledTensor;
2336   }
2337 
2338   static Value castIndexToFloat(OpBuilder &builder, Location loc,
2339                                 FloatType type, Value value) {
2340     auto integerVal = builder.create<arith::IndexCastUIOp>(
2341         loc,
2342         type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2343                                           : builder.getI32Type(),
2344         value);
2345 
2346     return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2347   }
2348 
2349   static Value createLinalgIndex(OpBuilder &builder, Location loc,
2350                                  FloatType type, int64_t index) {
2351     auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2352     return castIndexToFloat(builder, loc, type, indexVal);
2353   }
2354 
2355   template <typename... Args>
2356   static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2357                                                          Args... args) {
2358     return {builder.getAffineDimExpr(args)...};
2359   }
2360 
2361   LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2362                                 PatternRewriter &rewriter) const override {
2363     if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2364         !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2365       return rewriter.notifyMatchFailure(rfft2d,
2366                                          "only supports ranked tensors");
2367     }
2368 
2369     auto loc = rfft2d.getLoc();
2370     auto input = rfft2d.getInput();
2371     auto elementType =
2372         dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2373     if (!elementType)
2374       return rewriter.notifyMatchFailure(rfft2d,
2375                                          "only supports float element types");
2376 
2377     // Compute the output type and set of dynamic sizes
2378     llvm::SmallVector<Value> dynamicSizes;
2379     auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2380 
2381     // Iterator types for the linalg.generic implementation
2382     llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2383         utils::IteratorType::parallel, utils::IteratorType::parallel,
2384         utils::IteratorType::parallel, utils::IteratorType::reduction,
2385         utils::IteratorType::reduction};
2386 
2387     // Inputs/outputs to the linalg.generic implementation
2388     llvm::SmallVector<Value> genericOpInputs = {input};
2389     llvm::SmallVector<Value> genericOpOutputs = {
2390         createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2391         createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2392 
2393     // Indexing maps for input and output tensors
2394     auto indexingMaps = AffineMap::inferFromExprList(
2395         llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2396                        affineDimsExpr(rewriter, 0, 1, 2),
2397                        affineDimsExpr(rewriter, 0, 1, 2)},
2398         rewriter.getContext());
2399 
2400     // Height and width dimensions of the original input.
2401     auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2402     auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2403 
2404     // Constants and dimension sizes
2405     auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2406     auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2407     auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2408     auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2409 
2410     auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2411       Value valReal = args[0];
2412       Value sumReal = args[1];
2413       Value sumImag = args[2];
2414 
2415       // Indices for angle computation
2416       Value oy = builder.create<linalg::IndexOp>(loc, 1);
2417       Value ox = builder.create<linalg::IndexOp>(loc, 2);
2418       Value iy = builder.create<linalg::IndexOp>(loc, 3);
2419       Value ix = builder.create<linalg::IndexOp>(loc, 4);
2420 
2421       // Calculate the angle without the integer parts of the components, as
2422       // sin/cos are periodic:
2423       //   angle = 2 * pi() * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
2424       auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2425       auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2426 
2427       auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2428       auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2429 
2430       auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2431       auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2432 
2433       auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2434       auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2435       auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2436       auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2437 
2438       // realComponent = valReal * cos(angle)
2439       // imagComponent = valReal * sin(angle)
2440       auto cosAngle = builder.create<math::CosOp>(loc, angle);
2441       auto sinAngle = builder.create<math::SinOp>(loc, angle);
2442       auto realComponent =
2443           builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2444       auto imagComponent =
2445           builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2446 
2447       // outReal = sumReal + realComponent
2448       // outImag = sumImag - imagComponent
2449       auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2450       auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2451 
2452       builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2453     };
2454 
2455     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2456         rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2457         indexingMaps, iteratorTypes, buildBody);
2458 
2459     return success();
2460   }
2461 };
2462 
2463 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2464   using OpRewritePattern::OpRewritePattern;
2465 
2466   LogicalResult matchAndRewrite(FFT2dOp fft2d,
2467                                 PatternRewriter &rewriter) const override {
2468     if (!llvm::all_of(fft2d->getOperandTypes(),
2469                       RFFT2dConverter::isRankedTensor) ||
2470         !llvm::all_of(fft2d->getResultTypes(),
2471                       RFFT2dConverter::isRankedTensor)) {
2472       return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2473     }
2474 
2475     Location loc = fft2d.getLoc();
2476     Value input_real = fft2d.getInputReal();
2477     Value input_imag = fft2d.getInputImag();
2478     BoolAttr inverse = fft2d.getInverseAttr();
2479 
2480     auto real_el_ty = cast<FloatType>(
2481         cast<ShapedType>(input_real.getType()).getElementType());
2482     [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2483         cast<ShapedType>(input_imag.getType()).getElementType());
2484 
2485     assert(real_el_ty == imag_el_ty);
2486 
2487     // Compute the output type and set of dynamic sizes
2488     SmallVector<Value> dynamicSizes;
2489 
2490     // Get [N, H, W]
2491     auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2492 
2493     SmallVector<int64_t, 3> staticSizes;
2494     dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2495 
2496     auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2497 
2498     // Iterator types for the linalg.generic implementation
2499     SmallVector<utils::IteratorType, 5> iteratorTypes = {
2500         utils::IteratorType::parallel, utils::IteratorType::parallel,
2501         utils::IteratorType::parallel, utils::IteratorType::reduction,
2502         utils::IteratorType::reduction};
2503 
2504     // Inputs/outputs to the linalg.generic implementation
2505     SmallVector<Value> genericOpInputs = {input_real, input_imag};
2506     SmallVector<Value> genericOpOutputs = {
2507         RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2508                                           dynamicSizes),
2509         RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2510                                           dynamicSizes)};
2511 
2512     // Indexing maps for input and output tensors
2513     auto indexingMaps = AffineMap::inferFromExprList(
2514         ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2515                  RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2516                  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2517                  RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2518         rewriter.getContext());
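         // Both complex input planes are read at (d0, d3, d4) = (batch, iy, ix);
         // both result planes are written at (d0, d1, d2) = (batch, oy, ox).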
2519 
2520     // Width and height dimensions of the original input.
2521     auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2522     auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2523 
2524     // Constants and dimension sizes
2525     auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2526     auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2527     Value constH =
2528         RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2529     Value constW =
2530         RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2531 
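         // The body below accumulates, for every output element (n, oy, ox),
         //   (sumReal, sumImag) += (valReal + i * valImag) * e^(-i * angle)
         // over all (iy, ix).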
2532     auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2533       Value valReal = args[0];
2534       Value valImag = args[1];
2535       Value sumReal = args[2];
2536       Value sumImag = args[3];
2537 
2538       // Indices for angle computation
2539       Value oy = builder.create<linalg::IndexOp>(loc, 1);
2540       Value ox = builder.create<linalg::IndexOp>(loc, 2);
2541       Value iy = builder.create<linalg::IndexOp>(loc, 3);
2542       Value ix = builder.create<linalg::IndexOp>(loc, 4);
2543 
2544       // float_t angle = sign_val * 2 * pi() *
2545       //                 (((iy * oy) % H) / H + ((ix * ox) % W) / W);
2546       auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2547       auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2548 
2549       auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2550       auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2551 
2552       auto iyRemFloat =
2553           RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2554       auto ixRemFloat =
2555           RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2556 
2557       auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2558       auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2559 
2560       auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2561       auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2562 
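           // For the inverse transform, negate the angle so that e^(-i * angle)
           // becomes e^(+i * angle).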
2563       if (inverse.getValue()) {
2564         angle = builder.create<arith::MulFOp>(
2565             loc, angle,
2566             rewriter.create<arith::ConstantOp>(
2567                 loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2568       }
2569 
2570       // realComponent = val_real * cos(a) + val_imag * sin(a);
2571       // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2572       auto cosAngle = builder.create<math::CosOp>(loc, angle);
2573       auto sinAngle = builder.create<math::SinOp>(loc, angle);
2574 
2575       auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2576       auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2577       auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2578 
2579       auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2580       auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2581 
2582       auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2583 
2584       // outReal = sumReal + realComponent
2585       // outImag = sumImag + imagComponent
2586       auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2587       auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2588 
2589       builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2590     };
2591 
2592     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2593         fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2594         indexingMaps, iteratorTypes, buildBody);
2595 
2596     return success();
2597   }
2598 };
2599 
2600 } // namespace
2601 
2602 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2603     const TypeConverter &converter, RewritePatternSet *patterns) {
2604 
2605   // We have multiple resize converters to handle degenerate cases. The more
       // specialized patterns carry a higher benefit, so the driver tries them
       // before falling back to the generic lowering.
2606   patterns->add<GenericResizeConverter>(patterns->getContext(),
2607                                         /*benefit=*/100);
2608   patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2609                                       /*benefit=*/200);
2610   patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2611                                             /*benefit=*/300);
2612 
2613   patterns->add<
2614       // clang-format off
2615       PointwiseConverter<tosa::AddOp>,
2616       PointwiseConverter<tosa::SubOp>,
2617       PointwiseConverter<tosa::MulOp>,
2618       PointwiseConverter<tosa::IntDivOp>,
2619       PointwiseConverter<tosa::NegateOp>,
2620       PointwiseConverter<tosa::PowOp>,
2621       PointwiseConverter<tosa::ReciprocalOp>,
2622       PointwiseConverter<tosa::RsqrtOp>,
2623       PointwiseConverter<tosa::LogOp>,
2624       PointwiseConverter<tosa::ExpOp>,
2625       PointwiseConverter<tosa::AbsOp>,
2626       PointwiseConverter<tosa::SinOp>,
2627       PointwiseConverter<tosa::CosOp>,
2628       PointwiseConverter<tosa::TanhOp>,
2629       PointwiseConverter<tosa::ErfOp>,
2630       PointwiseConverter<tosa::BitwiseAndOp>,
2631       PointwiseConverter<tosa::BitwiseOrOp>,
2632       PointwiseConverter<tosa::BitwiseNotOp>,
2633       PointwiseConverter<tosa::BitwiseXorOp>,
2634       PointwiseConverter<tosa::LogicalAndOp>,
2635       PointwiseConverter<tosa::LogicalNotOp>,
2636       PointwiseConverter<tosa::LogicalOrOp>,
2637       PointwiseConverter<tosa::LogicalXorOp>,
2638       PointwiseConverter<tosa::CastOp>,
2639       PointwiseConverter<tosa::LogicalLeftShiftOp>,
2640       PointwiseConverter<tosa::LogicalRightShiftOp>,
2641       PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2642       PointwiseConverter<tosa::ClzOp>,
2643       PointwiseConverter<tosa::SelectOp>,
2644       PointwiseConverter<tosa::GreaterOp>,
2645       PointwiseConverter<tosa::GreaterEqualOp>,
2646       PointwiseConverter<tosa::EqualOp>,
2647       PointwiseConverter<tosa::MaximumOp>,
2648       PointwiseConverter<tosa::MinimumOp>,
2649       PointwiseConverter<tosa::CeilOp>,
2650       PointwiseConverter<tosa::FloorOp>,
2651       PointwiseConverter<tosa::ClampOp>,
2652       PointwiseConverter<tosa::SigmoidOp>
2653         >(converter, patterns->getContext());
2654 
2655   patterns->add<
2656       IdentityNConverter<tosa::IdentityOp>,
2657       ReduceConverter<tosa::ReduceAllOp>,
2658       ReduceConverter<tosa::ReduceAnyOp>,
2659       ReduceConverter<tosa::ReduceMinOp>,
2660       ReduceConverter<tosa::ReduceMaxOp>,
2661       ReduceConverter<tosa::ReduceSumOp>,
2662       ReduceConverter<tosa::ReduceProdOp>,
2663       ArgMaxConverter,
2664       GatherConverter,
2665       RescaleConverter,
2666       ReverseConverter,
2667       RFFT2dConverter,
2668       FFT2dConverter,
2669       TableConverter,
2670       TileConverter>(patterns->getContext());
2671   // clang-format on
2672 }
2673