xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (revision 935810c4de56235ddfaffbe02b2ae3cd0b72d2e0)
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/MathExtras.h"
23 #include <cassert>
24 
25 namespace mlir::arith {
26 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
28 } // namespace mlir::arith
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Common Helper Functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
37 /// Treats `value` as a 2*N bits-wide integer.
38 /// The bottom bits are returned in the first pair element, while the top bits
39 /// in the second one.
40 static std::pair<APInt, APInt> getHalves(const APInt &value,
41                                          unsigned newBitWidth) {
42   APInt low = value.extractBits(newBitWidth, 0);
43   APInt high = value.extractBits(newBitWidth, newBitWidth);
44   return {std::move(low), std::move(high)};
45 }
46 
47 /// Returns the type with the last (innermost) dimension reduced to x1.
48 /// Scalarizes 1D vector inputs to match how we extract/insert vector values,
49 /// e.g.:
50 ///   - vector<3x2xi16> --> vector<3x1xi16>
51 ///   - vector<2xi16>   --> i16
52 static Type reduceInnermostDim(VectorType type) {
53   if (type.getShape().size() == 1)
54     return type.getElementType();
55 
56   auto newShape = to_vector(type.getShape());
57   newShape.back() = 1;
58   return VectorType::get(newShape, type.getElementType());
59 }
60 
61 /// Extracts the `input` vector slice with elements at the last dimension offset
62 /// by `lastOffset`. Returns a value of vector type with the last dimension
63 /// reduced to x1 or fully scalarized, e.g.:
64 ///   - vector<3x2xi16> --> vector<3x1xi16>
65 ///   - vector<2xi16>   --> i16
66 static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
67                                  Location loc, Value input,
68                                  int64_t lastOffset) {
69   ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
70   assert(lastOffset < shape.back() && "Offset out of bounds");
71 
72   // Scalarize the result in case of 1D vectors.
73   if (shape.size() == 1)
74     return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
75 
76   SmallVector<int64_t> offsets(shape.size(), 0);
77   offsets.back() = lastOffset;
78   auto sizes = llvm::to_vector(shape);
79   sizes.back() = 1;
80   SmallVector<int64_t> strides(shape.size(), 1);
81 
82   return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
83                                                         sizes, strides);
84 }
85 
86 /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
87 /// with the first element at offset 0 and the second element at offset 1.
88 static std::pair<Value, Value>
89 extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
90                      Value input) {
91   return {extractLastDimSlice(rewriter, loc, input, 0),
92           extractLastDimSlice(rewriter, loc, input, 1)};
93 }
94 
95 // Performs a vector shape cast to drop the trailing x1 dimension. If the
96 // `input` is a scalar, this is a noop.
97 static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
98                                Location loc, Value input) {
99   auto vecTy = dyn_cast<VectorType>(input.getType());
100   if (!vecTy)
101     return input;
102 
103   // Shape cast to drop the last x1 dimension.
104   ArrayRef<int64_t> shape = vecTy.getShape();
105   assert(shape.size() >= 2 && "Expected vector with at list two dims");
106   assert(shape.back() == 1 && "Expected the last vector dim to be x1");
107 
108   auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
109   return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
110 }
111 
112 /// Performs a vector shape cast to append an x1 dimension. If the
113 /// `input` is a scalar, this is a noop.
114 static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
115                          Value input) {
116   auto vecTy = dyn_cast<VectorType>(input.getType());
117   if (!vecTy)
118     return input;
119 
120   // Add a trailing x1 dim.
121   auto newShape = llvm::to_vector(vecTy.getShape());
122   newShape.push_back(1);
123   auto newTy = VectorType::get(newShape, vecTy.getElementType());
124   return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
125 }
126 
127 /// Inserts the `source` vector slice into the `dest` vector at offset
128 /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
129 /// a 1D vector.
130 static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
131                                 Location loc, Value source, Value dest,
132                                 int64_t lastOffset) {
133   ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
134   assert(lastOffset < shape.back() && "Offset out of bounds");
135 
136   // Handle scalar source.
137   if (isa<IntegerType>(source.getType()))
138     return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
139 
140   SmallVector<int64_t> offsets(shape.size(), 0);
141   offsets.back() = lastOffset;
142   SmallVector<int64_t> strides(shape.size(), 1);
143   return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
144                                                        offsets, strides);
145 }
146 
147 /// Constructs a new vector of type `resultType` by creating a series of
148 /// insertions of `resultComponents`, each at the next offset of the last vector
149 /// dimension.
150 /// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
151 /// when `resultComponents` are `vector<...x1xT>`s, the result type is
152 /// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
153 static Value constructResultVector(ConversionPatternRewriter &rewriter,
154                                    Location loc, VectorType resultType,
155                                    ValueRange resultComponents) {
156   llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
157   (void)resultShape;
158   assert(!resultShape.empty() && "Result expected to have dimensions");
159   assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
160          "Wrong number of result components");
161 
162   Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
163   for (auto [i, component] : llvm::enumerate(resultComponents))
164     resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
165 
166   return resultVec;
167 }
168 
169 namespace {
170 //===----------------------------------------------------------------------===//
171 // ConvertConstant
172 //===----------------------------------------------------------------------===//
173 
174 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
175   using OpConversionPattern::OpConversionPattern;
176 
177   LogicalResult
178   matchAndRewrite(arith::ConstantOp op, OpAdaptor,
179                   ConversionPatternRewriter &rewriter) const override {
180     Type oldType = op.getType();
181     auto newType = getTypeConverter()->convertType<VectorType>(oldType);
182     if (!newType)
183       return rewriter.notifyMatchFailure(
184           op, llvm::formatv("unsupported type: {0}", op.getType()));
185 
186     unsigned newBitWidth = newType.getElementTypeBitWidth();
187     Attribute oldValue = op.getValueAttr();
188 
189     if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
190       auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
191       auto newAttr = DenseElementsAttr::get(newType, {low, high});
192       rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
193       return success();
194     }
195 
196     if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
197       auto [low, high] =
198           getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
199       int64_t numSplatElems = splatAttr.getNumElements();
200       SmallVector<APInt> values;
201       values.reserve(numSplatElems * 2);
202       for (int64_t i = 0; i < numSplatElems; ++i) {
203         values.push_back(low);
204         values.push_back(high);
205       }
206 
207       auto attr = DenseElementsAttr::get(newType, values);
208       rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
209       return success();
210     }
211 
212     if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
213       int64_t numElems = elemsAttr.getNumElements();
214       SmallVector<APInt> values;
215       values.reserve(numElems * 2);
216       for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
217         auto [low, high] = getHalves(origVal, newBitWidth);
218         values.push_back(std::move(low));
219         values.push_back(std::move(high));
220       }
221 
222       auto attr = DenseElementsAttr::get(newType, values);
223       rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
224       return success();
225     }
226 
227     return rewriter.notifyMatchFailure(op.getLoc(),
228                                        "unhandled constant attribute");
229   }
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // ConvertAddI
234 //===----------------------------------------------------------------------===//
235 
236 struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
237   using OpConversionPattern::OpConversionPattern;
238 
239   LogicalResult
240   matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
241                   ConversionPatternRewriter &rewriter) const override {
242     Location loc = op->getLoc();
243     auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
244     if (!newTy)
245       return rewriter.notifyMatchFailure(
246           loc, llvm::formatv("unsupported type: {0}", op.getType()));
247 
248     Type newElemTy = reduceInnermostDim(newTy);
249 
250     auto [lhsElem0, lhsElem1] =
251         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
252     auto [rhsElem0, rhsElem1] =
253         extractLastDimHalves(rewriter, loc, adaptor.getRhs());
254 
255     auto lowSum =
256         rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
257     Value overflowVal =
258         rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
259 
260     Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
261     Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
262 
263     Value resultVec =
264         constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
265     rewriter.replaceOp(op, resultVec);
266     return success();
267   }
268 };
269 
270 //===----------------------------------------------------------------------===//
271 // ConvertBitwiseBinary
272 //===----------------------------------------------------------------------===//
273 
274 /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
275 template <typename BinaryOp>
276 struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
277   using OpConversionPattern<BinaryOp>::OpConversionPattern;
278   using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
279 
280   LogicalResult
281   matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
282                   ConversionPatternRewriter &rewriter) const override {
283     Location loc = op->getLoc();
284     auto newTy = this->getTypeConverter()->template convertType<VectorType>(
285         op.getType());
286     if (!newTy)
287       return rewriter.notifyMatchFailure(
288           loc, llvm::formatv("unsupported type: {0}", op.getType()));
289 
290     auto [lhsElem0, lhsElem1] =
291         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
292     auto [rhsElem0, rhsElem1] =
293         extractLastDimHalves(rewriter, loc, adaptor.getRhs());
294 
295     Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
296     Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
297     Value resultVec =
298         constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
299     rewriter.replaceOp(op, resultVec);
300     return success();
301   }
302 };
303 
304 //===----------------------------------------------------------------------===//
305 // ConvertCmpI
306 //===----------------------------------------------------------------------===//
307 
308 /// Returns the matching unsigned version of the given predicate `pred`, or the
309 /// same predicate if `pred` is not a signed.
310 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
311   using P = arith::CmpIPredicate;
312   switch (pred) {
313   case P::sge:
314     return P::uge;
315   case P::sgt:
316     return P::ugt;
317   case P::sle:
318     return P::ule;
319   case P::slt:
320     return P::ult;
321   default:
322     return pred;
323   }
324 }
325 
326 struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
327   using OpConversionPattern::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     Location loc = op->getLoc();
333     auto inputTy =
334         getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
335     if (!inputTy)
336       return rewriter.notifyMatchFailure(
337           loc, llvm::formatv("unsupported type: {0}", op.getType()));
338 
339     arith::CmpIPredicate highPred = adaptor.getPredicate();
340     arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
341 
342     auto [lhsElem0, lhsElem1] =
343         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
344     auto [rhsElem0, rhsElem1] =
345         extractLastDimHalves(rewriter, loc, adaptor.getRhs());
346 
347     Value lowCmp =
348         rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
349     Value highCmp =
350         rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
351 
352     Value cmpResult{};
353     switch (highPred) {
354     case arith::CmpIPredicate::eq: {
355       cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
356       break;
357     }
358     case arith::CmpIPredicate::ne: {
359       cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
360       break;
361     }
362     default: {
363       // Handle inequality checks.
364       Value highEq = rewriter.create<arith::CmpIOp>(
365           loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
366       cmpResult =
367           rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
368       break;
369     }
370     }
371 
372     assert(cmpResult && "Unhandled case");
373     rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult));
374     return success();
375   }
376 };
377 
378 //===----------------------------------------------------------------------===//
379 // ConvertMulI
380 //===----------------------------------------------------------------------===//
381 
382 struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
383   using OpConversionPattern::OpConversionPattern;
384 
385   LogicalResult
386   matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
387                   ConversionPatternRewriter &rewriter) const override {
388     Location loc = op->getLoc();
389     auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
390     if (!newTy)
391       return rewriter.notifyMatchFailure(
392           loc, llvm::formatv("unsupported type: {0}", op.getType()));
393 
394     auto [lhsElem0, lhsElem1] =
395         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
396     auto [rhsElem0, rhsElem1] =
397         extractLastDimHalves(rewriter, loc, adaptor.getRhs());
398 
399     // The multiplication algorithm used is the standard (long) multiplication.
400     // Multiplying two i2N integers produces (at most) an i4N result, but
401     // because the calculation of top i2N is not necessary, we omit it.
402     auto mulLowLow =
403         rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
404     Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
405     Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
406 
407     Value resLow = mulLowLow.getLow();
408     Value resHi =
409         rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
410     resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
411 
412     Value resultVec =
413         constructResultVector(rewriter, loc, newTy, {resLow, resHi});
414     rewriter.replaceOp(op, resultVec);
415     return success();
416   }
417 };
418 
419 //===----------------------------------------------------------------------===//
420 // ConvertExtSI
421 //===----------------------------------------------------------------------===//
422 
423 struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
424   using OpConversionPattern::OpConversionPattern;
425 
426   LogicalResult
427   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
428                   ConversionPatternRewriter &rewriter) const override {
429     Location loc = op->getLoc();
430     auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
431     if (!newTy)
432       return rewriter.notifyMatchFailure(
433           loc, llvm::formatv("unsupported type: {0}", op.getType()));
434 
435     Type newResultComponentTy = reduceInnermostDim(newTy);
436 
437     // Sign-extend the input value to determine the low half of the result.
438     // Then, check if the low half is negative, and sign-extend the comparison
439     // result to get the high half.
440     Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
441     Value extended = rewriter.createOrFold<arith::ExtSIOp>(
442         loc, newResultComponentTy, newOperand);
443     Value operandZeroCst =
444         createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
445     Value signBit = rewriter.create<arith::CmpIOp>(
446         loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
447     Value signValue =
448         rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
449 
450     Value resultVec =
451         constructResultVector(rewriter, loc, newTy, {extended, signValue});
452     rewriter.replaceOp(op, resultVec);
453     return success();
454   }
455 };
456 
457 //===----------------------------------------------------------------------===//
458 // ConvertExtUI
459 //===----------------------------------------------------------------------===//
460 
461 struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
462   using OpConversionPattern::OpConversionPattern;
463 
464   LogicalResult
465   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
466                   ConversionPatternRewriter &rewriter) const override {
467     Location loc = op->getLoc();
468     auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
469     if (!newTy)
470       return rewriter.notifyMatchFailure(
471           loc, llvm::formatv("unsupported type: {0}", op.getType()));
472 
473     Type newResultComponentTy = reduceInnermostDim(newTy);
474 
475     // Zero-extend the input value to determine the low half of the result.
476     // The high half is always zero.
477     Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
478     Value extended = rewriter.createOrFold<arith::ExtUIOp>(
479         loc, newResultComponentTy, newOperand);
480     Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
481     Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
482     rewriter.replaceOp(op, newRes);
483     return success();
484   }
485 };
486 
487 //===----------------------------------------------------------------------===//
488 // ConvertMaxMin
489 //===----------------------------------------------------------------------===//
490 
491 template <typename SourceOp, arith::CmpIPredicate CmpPred>
492 struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
493   using OpConversionPattern<SourceOp>::OpConversionPattern;
494 
495   LogicalResult
496   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
497                   ConversionPatternRewriter &rewriter) const override {
498     Location loc = op->getLoc();
499 
500     Type oldTy = op.getType();
501     auto newTy = dyn_cast_or_null<VectorType>(
502         this->getTypeConverter()->convertType(oldTy));
503     if (!newTy)
504       return rewriter.notifyMatchFailure(
505           loc, llvm::formatv("unsupported type: {0}", op.getType()));
506 
507     // Rewrite Max*I/Min*I as compare and select over original operands. Let
508     // the CmpI and Select emulation patterns handle the final legalization.
509     Value cmp =
510         rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
511     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
512                                                  op.getRhs());
513     return success();
514   }
515 };
516 
517 // Convert IndexCast ops
518 //===----------------------------------------------------------------------===//
519 
520 /// Returns true iff the type is `index` or `vector<...index>`.
521 static bool isIndexOrIndexVector(Type type) {
522   if (isa<IndexType>(type))
523     return true;
524 
525   if (auto vectorTy = dyn_cast<VectorType>(type))
526     if (isa<IndexType>(vectorTy.getElementType()))
527       return true;
528 
529   return false;
530 }
531 
532 template <typename CastOp>
533 struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
534   using OpConversionPattern<CastOp>::OpConversionPattern;
535 
536   LogicalResult
537   matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
538                   ConversionPatternRewriter &rewriter) const override {
539     Type resultType = op.getType();
540     if (!isIndexOrIndexVector(resultType))
541       return failure();
542 
543     Location loc = op.getLoc();
544     Type inType = op.getIn().getType();
545     auto newInTy =
546         this->getTypeConverter()->template convertType<VectorType>(inType);
547     if (!newInTy)
548       return rewriter.notifyMatchFailure(
549           loc, llvm::formatv("unsupported type: {0}", inType));
550 
551     // Discard the high half of the input truncating the original value.
552     Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
553     extracted = dropTrailingX1Dim(rewriter, loc, extracted);
554     rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
555     return success();
556   }
557 };
558 
559 template <typename CastOp, typename ExtensionOp>
560 struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
561   using OpConversionPattern<CastOp>::OpConversionPattern;
562 
563   LogicalResult
564   matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
565                   ConversionPatternRewriter &rewriter) const override {
566     Type inType = op.getIn().getType();
567     if (!isIndexOrIndexVector(inType))
568       return failure();
569 
570     Location loc = op.getLoc();
571     auto *typeConverter =
572         this->template getTypeConverter<arith::WideIntEmulationConverter>();
573 
574     Type resultType = op.getType();
575     auto newTy = typeConverter->template convertType<VectorType>(resultType);
576     if (!newTy)
577       return rewriter.notifyMatchFailure(
578           loc, llvm::formatv("unsupported type: {0}", resultType));
579 
580     // Emit an index cast over the matching narrow type.
581     Type narrowTy =
582         rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
583     if (auto vecTy = dyn_cast<VectorType>(resultType))
584       narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
585 
586     // Sign or zero-extend the result. Let the matching conversion pattern
587     // legalize the extension op.
588     Value underlyingVal =
589         rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
590     rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
591     return success();
592   }
593 };
594 
595 //===----------------------------------------------------------------------===//
596 // ConvertSelect
597 //===----------------------------------------------------------------------===//
598 
599 struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
600   using OpConversionPattern::OpConversionPattern;
601 
602   LogicalResult
603   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
604                   ConversionPatternRewriter &rewriter) const override {
605     Location loc = op->getLoc();
606     auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
607     if (!newTy)
608       return rewriter.notifyMatchFailure(
609           loc, llvm::formatv("unsupported type: {0}", op.getType()));
610 
611     auto [trueElem0, trueElem1] =
612         extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
613     auto [falseElem0, falseElem1] =
614         extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
615     Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
616 
617     Value resElem0 =
618         rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
619     Value resElem1 =
620         rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
621     Value resultVec =
622         constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
623     rewriter.replaceOp(op, resultVec);
624     return success();
625   }
626 };
627 
628 //===----------------------------------------------------------------------===//
629 // ConvertShLI
630 //===----------------------------------------------------------------------===//
631 
632 struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
633   using OpConversionPattern::OpConversionPattern;
634 
635   LogicalResult
636   matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
637                   ConversionPatternRewriter &rewriter) const override {
638     Location loc = op->getLoc();
639 
640     Type oldTy = op.getType();
641     auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
642     if (!newTy)
643       return rewriter.notifyMatchFailure(
644           loc, llvm::formatv("unsupported type: {0}", op.getType()));
645 
646     Type newOperandTy = reduceInnermostDim(newTy);
647     // `oldBitWidth` == `2 * newBitWidth`
648     unsigned newBitWidth = newTy.getElementTypeBitWidth();
649 
650     auto [lhsElem0, lhsElem1] =
651         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
652     Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
653 
654     // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
655     // high halves of the results separately:
656     //   1. low := LHS.low shli RHS
657     //
658     //   2. high := a or b or c, where:
659     //     a) Bits from LHS.high, shifted by the RHS.
660     //     b) Bits from LHS.low, shifted right. These come into play when
661     //        RHS < newBitWidth, e.g.:
662     //         [0000][llll] shli 3 --> [0lll][l000]
663     //                                    ^
664     //                                    |
665     //                           [llll] shrui (4 - 3)
666     //     c) Bits from LHS.low, shifted left. These matter when
667     //        RHS > newBitWidth, e.g.:
668     //         [0000][llll] shli 7 --> [l000][0000]
669     //                                   ^
670     //                                   |
671     //                          [llll] shli (7 - 4)
672     //
673     // Because shifts by values >= newBitWidth are undefined, we ignore the high
674     // half of RHS, and introduce 'bounds checks' to account for
675     // RHS.low > newBitWidth.
676     //
677     // TODO: Explore possible optimizations.
678     Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
679     Value elemBitWidth =
680         createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
681 
682     Value illegalElemShift = rewriter.create<arith::CmpIOp>(
683         loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
684 
685     Value shiftedElem0 =
686         rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
687     Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
688                                                       zeroCst, shiftedElem0);
689 
690     Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
691         loc, illegalElemShift, elemBitWidth, rhsElem0);
692     Value rightShiftAmount =
693         rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
694     Value shiftedRight =
695         rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
696     Value overshotShiftAmount =
697         rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
698     Value shiftedLeft =
699         rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
700 
701     Value shiftedElem1 =
702         rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
703     Value resElem1High = rewriter.create<arith::SelectOp>(
704         loc, illegalElemShift, zeroCst, shiftedElem1);
705     Value resElem1Low = rewriter.create<arith::SelectOp>(
706         loc, illegalElemShift, shiftedLeft, shiftedRight);
707     Value resElem1 =
708         rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
709 
710     Value resultVec =
711         constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
712     rewriter.replaceOp(op, resultVec);
713     return success();
714   }
715 };
716 
717 //===----------------------------------------------------------------------===//
718 // ConvertShRUI
719 //===----------------------------------------------------------------------===//
720 
721 struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
722   using OpConversionPattern::OpConversionPattern;
723 
724   LogicalResult
725   matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
726                   ConversionPatternRewriter &rewriter) const override {
727     Location loc = op->getLoc();
728 
729     Type oldTy = op.getType();
730     auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
731     if (!newTy)
732       return rewriter.notifyMatchFailure(
733           loc, llvm::formatv("unsupported type: {0}", op.getType()));
734 
735     Type newOperandTy = reduceInnermostDim(newTy);
736     // `oldBitWidth` == `2 * newBitWidth`
737     unsigned newBitWidth = newTy.getElementTypeBitWidth();
738 
739     auto [lhsElem0, lhsElem1] =
740         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
741     Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
742 
743     // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
744     // high halves of the results separately:
745     //   1. low := a or b or c, where:
746     //     a) Bits from LHS.low, shifted by the RHS.
747     //     b) Bits from LHS.high, shifted left. These matter when
748     //        RHS < newBitWidth, e.g.:
749     //         [hhhh][0000] shrui 3 --> [000h][hhh0]
750     //                                          ^
751     //                                          |
752     //                                 [hhhh] shli (4 - 1)
753     //     c) Bits from LHS.high, shifted right. These come into play when
754     //        RHS > newBitWidth, e.g.:
755     //         [hhhh][0000] shrui 7 --> [0000][000h]
756     //                                          ^
757     //                                          |
758     //                                 [hhhh] shrui (7 - 4)
759     //
760     //   2. high := LHS.high shrui RHS
761     //
762     // Because shifts by values >= newBitWidth are undefined, we ignore the high
763     // half of RHS, and introduce 'bounds checks' to account for
764     // RHS.low > newBitWidth.
765     //
766     // TODO: Explore possible optimizations.
767     Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
768     Value elemBitWidth =
769         createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
770 
771     Value illegalElemShift = rewriter.create<arith::CmpIOp>(
772         loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
773 
774     Value shiftedElem0 =
775         rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
776     Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
777                                                          zeroCst, shiftedElem0);
778     Value shiftedElem1 =
779         rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
780     Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
781                                                       zeroCst, shiftedElem1);
782 
783     Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
784         loc, illegalElemShift, elemBitWidth, rhsElem0);
785     Value leftShiftAmount =
786         rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
787     Value shiftedLeft =
788         rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
789     Value overshotShiftAmount =
790         rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
791     Value shiftedRight =
792         rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
793 
794     Value resElem0High = rewriter.create<arith::SelectOp>(
795         loc, illegalElemShift, shiftedRight, shiftedLeft);
796     Value resElem0 =
797         rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
798 
799     Value resultVec =
800         constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
801     rewriter.replaceOp(op, resultVec);
802     return success();
803   }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // ConvertShRSI
808 //===----------------------------------------------------------------------===//
809 
810 struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
811   using OpConversionPattern::OpConversionPattern;
812 
813   LogicalResult
814   matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
815                   ConversionPatternRewriter &rewriter) const override {
816     Location loc = op->getLoc();
817 
818     Type oldTy = op.getType();
819     auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
820     if (!newTy)
821       return rewriter.notifyMatchFailure(
822           loc, llvm::formatv("unsupported type: {0}", op.getType()));
823 
824     Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
825     Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
826 
827     Type narrowTy = rhsElem0.getType();
828     int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
829 
830     // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
831     // Perform as many ops over the narrow integer type as possible and let the
832     // other emulation patterns convert the rest.
833     Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
834     Value signBit = rewriter.create<arith::CmpIOp>(
835         loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
836     signBit = dropTrailingX1Dim(rewriter, loc, signBit);
837 
838     // Create a bit pattern of either all ones or all zeros. Then shift it left
839     // to calculate the sign extension bits created by shifting the original
840     // sign bit right.
841     Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
842     Value maxShift =
843         createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
844     Value numNonSignExtBits =
845         rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
846     numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
847     numNonSignExtBits =
848         rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
849     Value signBits =
850         rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
851 
852     // Use original arguments to create the right shift.
853     Value shrui =
854         rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
855     Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
856 
857     // Handle shifting by zero. This is necessary when the `signBits` shift is
858     // invalid.
859     Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
860                                                   rhsElem0, elemZero);
861     isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
862     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
863                                                  shrsi);
864 
865     return success();
866   }
867 };
868 
869 //===----------------------------------------------------------------------===//
870 // ConvertSIToFP
871 //===----------------------------------------------------------------------===//
872 
873 struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
874   using OpConversionPattern::OpConversionPattern;
875 
876   LogicalResult
877   matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
878                   ConversionPatternRewriter &rewriter) const override {
879     Location loc = op.getLoc();
880 
881     Value in = op.getIn();
882     Type oldTy = in.getType();
883     auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
884     if (!newTy)
885       return rewriter.notifyMatchFailure(
886           loc, llvm::formatv("unsupported type: {0}", oldTy));
887 
888     unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
889     Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
890     Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
891     Value allOnesCst = createScalarOrSplatConstant(
892         rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
893 
894     // To avoid operating on very large unsigned numbers, perform the
895     // conversion on the absolute value. Then, decide whether to negate the
896     // result or not based on that sign bit. We assume two's complement and
897     // implement negation by flipping all bits and adding 1.
898     // Note that this relies on the the other conversion patterns to legalize
899     // created ops and narrow the bit widths.
900     Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
901                                                  in, zeroCst);
902     Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
903     Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
904     Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
905 
906     Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
907     Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
908     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
909                                                  absResult);
910     return success();
911   }
912 };
913 
914 //===----------------------------------------------------------------------===//
915 // ConvertUIToFP
916 //===----------------------------------------------------------------------===//
917 
918 struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
919   using OpConversionPattern::OpConversionPattern;
920 
921   LogicalResult
922   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
923                   ConversionPatternRewriter &rewriter) const override {
924     Location loc = op.getLoc();
925 
926     Type oldTy = op.getIn().getType();
927     auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
928     if (!newTy)
929       return rewriter.notifyMatchFailure(
930           loc, llvm::formatv("unsupported type: {0}", oldTy));
931     unsigned newBitWidth = newTy.getElementTypeBitWidth();
932 
933     auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
934     Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
935     Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
936     Value zeroCst =
937         createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0);
938 
939     // The final result has the following form:
940     //   if (hi == 0) return uitofp(low)
941     //   else         return uitofp(low) + uitofp(hi) * 2^BW
942     //
943     // where `BW` is the bitwidth of the narrowed integer type. We emit a
944     // select to make it easier to fold-away the `hi` part calculation when it
945     // is known to be zero.
946     //
947     // Note 1: The emulation is precise only for input values that have exact
948     // integer representation in the result floating point type, and may lead
949     // loss of precision otherwise.
950     //
951     // Note 2: We do not strictly need the `hi == 0`, case, but it makes
952     // constant folding easier.
953     Value hiEqZero = rewriter.create<arith::CmpIOp>(
954         loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
955 
956     Type resultTy = op.getType();
957     Type resultElemTy = getElementTypeOrSelf(resultTy);
958     Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
959     Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
960 
961     int64_t pow2Int = int64_t(1) << newBitWidth;
962     TypedAttr pow2Attr =
963         rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
964     if (auto vecTy = dyn_cast<VectorType>(resultTy))
965       pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
966 
967     Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
968 
969     Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
970     Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
971 
972     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
973     return success();
974   }
975 };
976 
977 //===----------------------------------------------------------------------===//
978 // ConvertTruncI
979 //===----------------------------------------------------------------------===//
980 
981 struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
982   using OpConversionPattern::OpConversionPattern;
983 
984   LogicalResult
985   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
986                   ConversionPatternRewriter &rewriter) const override {
987     Location loc = op.getLoc();
988     // Check if the result type is legal for this target. Currently, we do not
989     // support truncation to types wider than supported by the target.
990     if (!getTypeConverter()->isLegal(op.getType()))
991       return rewriter.notifyMatchFailure(
992           loc, llvm::formatv("unsupported truncation result type: {0}",
993                              op.getType()));
994 
995     // Discard the high half of the input. Truncate the low half, if
996     // necessary.
997     Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
998     extracted = dropTrailingX1Dim(rewriter, loc, extracted);
999     Value truncated =
1000         rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1001     rewriter.replaceOp(op, truncated);
1002     return success();
1003   }
1004 };
1005 
1006 //===----------------------------------------------------------------------===//
1007 // ConvertVectorPrint
1008 //===----------------------------------------------------------------------===//
1009 
1010 struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1011   using OpConversionPattern::OpConversionPattern;
1012 
1013   LogicalResult
1014   matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1015                   ConversionPatternRewriter &rewriter) const override {
1016     rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1017     return success();
1018   }
1019 };
1020 
1021 //===----------------------------------------------------------------------===//
1022 // Pass Definition
1023 //===----------------------------------------------------------------------===//
1024 
1025 struct EmulateWideIntPass final
1026     : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1027   using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1028 
1029   void runOnOperation() override {
1030     if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1031       signalPassFailure();
1032       return;
1033     }
1034 
1035     Operation *op = getOperation();
1036     MLIRContext *ctx = op->getContext();
1037 
1038     arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1039     ConversionTarget target(*ctx);
1040     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1041       return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1042     });
1043     auto opLegalCallback = [&typeConverter](Operation *op) {
1044       return typeConverter.isLegal(op);
1045     };
1046     target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1047     target
1048         .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
1049             opLegalCallback);
1050 
1051     RewritePatternSet patterns(ctx);
1052     arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
1053 
1054     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1055       signalPassFailure();
1056   }
1057 };
1058 } // end anonymous namespace
1059 
1060 //===----------------------------------------------------------------------===//
1061 // Public Interface Definition
1062 //===----------------------------------------------------------------------===//
1063 
1064 arith::WideIntEmulationConverter::WideIntEmulationConverter(
1065     unsigned widestIntSupportedByTarget)
1066     : maxIntWidth(widestIntSupportedByTarget) {
1067   assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1068          "Only power-of-two integers with are supported");
1069   assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1070 
1071   // Allow unknown types.
1072   addConversion([](Type ty) -> std::optional<Type> { return ty; });
1073 
1074   // Scalar case.
1075   addConversion([this](IntegerType ty) -> std::optional<Type> {
1076     unsigned width = ty.getWidth();
1077     if (width <= maxIntWidth)
1078       return ty;
1079 
1080     // i2N --> vector<2xiN>
1081     if (width == 2 * maxIntWidth)
1082       return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1083 
1084     return nullptr;
1085   });
1086 
1087   // Vector case.
1088   addConversion([this](VectorType ty) -> std::optional<Type> {
1089     auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1090     if (!intTy)
1091       return ty;
1092 
1093     unsigned width = intTy.getWidth();
1094     if (width <= maxIntWidth)
1095       return ty;
1096 
1097     // vector<...xi2N> --> vector<...x2xiN>
1098     if (width == 2 * maxIntWidth) {
1099       auto newShape = to_vector(ty.getShape());
1100       newShape.push_back(2);
1101       return VectorType::get(newShape,
1102                              IntegerType::get(ty.getContext(), maxIntWidth));
1103     }
1104 
1105     return nullptr;
1106   });
1107 
1108   // Function case.
1109   addConversion([this](FunctionType ty) -> std::optional<Type> {
1110     // Convert inputs and results, e.g.:
1111     //   (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1112     SmallVector<Type> inputs;
1113     if (failed(convertTypes(ty.getInputs(), inputs)))
1114       return nullptr;
1115 
1116     SmallVector<Type> results;
1117     if (failed(convertTypes(ty.getResults(), results)))
1118       return nullptr;
1119 
1120     return FunctionType::get(ty.getContext(), inputs, results);
1121   });
1122 }
1123 
1124 void arith::populateArithWideIntEmulationPatterns(
1125     const WideIntEmulationConverter &typeConverter,
1126     RewritePatternSet &patterns) {
1127   // Populate `func.*` conversion patterns.
1128   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1129                                                                  typeConverter);
1130   populateCallOpTypeConversionPattern(patterns, typeConverter);
1131   populateReturnOpTypeConversionPattern(patterns, typeConverter);
1132 
1133   // Populate `arith.*` conversion patterns.
1134   patterns.add<
1135       // Misc ops.
1136       ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1137       // Binary ops.
1138       ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1139       ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1140       ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1141       ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1142       ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1143       // Bitwise binary ops.
1144       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1145       ConvertBitwiseBinary<arith::XOrIOp>,
1146       // Extension and truncation ops.
1147       ConvertExtSI, ConvertExtUI, ConvertTruncI,
1148       // Cast ops.
1149       ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1150       ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1151       ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1152       ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1153       ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
1154 }
1155