xref: /llvm-project/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (revision f9a80062470daf94e07f65f9dd23df6a4f2946a2)
1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Arith dialect to the EmitC
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/EmitC/IR/EmitC.h"
18 #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Conversion Patterns
27 //===----------------------------------------------------------------------===//
28 
29 namespace {
30 class ArithConstantOpConversionPattern
31     : public OpConversionPattern<arith::ConstantOp> {
32 public:
33   using OpConversionPattern::OpConversionPattern;
34 
35   LogicalResult
36   matchAndRewrite(arith::ConstantOp arithConst,
37                   arith::ConstantOp::Adaptor adaptor,
38                   ConversionPatternRewriter &rewriter) const override {
39     Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
40     if (!newTy)
41       return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
42     rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
43                                                    adaptor.getValue());
44     return success();
45   }
46 };
47 
48 /// Get the signed or unsigned type corresponding to \p ty.
49 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
50   if (isa<IntegerType>(ty)) {
51     if (ty.isUnsignedInteger() != needsUnsigned) {
52       auto signedness = needsUnsigned
53                             ? IntegerType::SignednessSemantics::Unsigned
54                             : IntegerType::SignednessSemantics::Signed;
55       return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
56                               signedness);
57     }
58   } else if (emitc::isPointerWideType(ty)) {
59     if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
60       if (needsUnsigned)
61         return emitc::SizeTType::get(ty.getContext());
62       return emitc::PtrDiffTType::get(ty.getContext());
63     }
64   }
65   return ty;
66 }
67 
68 /// Insert a cast operation to type \p ty if \p val does not have this type.
69 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
70   return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
71 }
72 
73 class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
74 public:
75   using OpConversionPattern::OpConversionPattern;
76 
77   LogicalResult
78   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
79                   ConversionPatternRewriter &rewriter) const override {
80 
81     if (!isa<FloatType>(adaptor.getRhs().getType())) {
82       return rewriter.notifyMatchFailure(op.getLoc(),
83                                          "cmpf currently only supported on "
84                                          "floats, not tensors/vectors thereof");
85     }
86 
87     bool unordered = false;
88     emitc::CmpPredicate predicate;
89     switch (op.getPredicate()) {
90     case arith::CmpFPredicate::AlwaysFalse: {
91       auto constant = rewriter.create<emitc::ConstantOp>(
92           op.getLoc(), rewriter.getI1Type(),
93           rewriter.getBoolAttr(/*value=*/false));
94       rewriter.replaceOp(op, constant);
95       return success();
96     }
97     case arith::CmpFPredicate::OEQ:
98       unordered = false;
99       predicate = emitc::CmpPredicate::eq;
100       break;
101     case arith::CmpFPredicate::OGT:
102       unordered = false;
103       predicate = emitc::CmpPredicate::gt;
104       break;
105     case arith::CmpFPredicate::OGE:
106       unordered = false;
107       predicate = emitc::CmpPredicate::ge;
108       break;
109     case arith::CmpFPredicate::OLT:
110       unordered = false;
111       predicate = emitc::CmpPredicate::lt;
112       break;
113     case arith::CmpFPredicate::OLE:
114       unordered = false;
115       predicate = emitc::CmpPredicate::le;
116       break;
117     case arith::CmpFPredicate::ONE:
118       unordered = false;
119       predicate = emitc::CmpPredicate::ne;
120       break;
121     case arith::CmpFPredicate::ORD: {
122       // ordered, i.e. none of the operands is NaN
123       auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
124                                       adaptor.getRhs());
125       rewriter.replaceOp(op, cmp);
126       return success();
127     }
128     case arith::CmpFPredicate::UEQ:
129       unordered = true;
130       predicate = emitc::CmpPredicate::eq;
131       break;
132     case arith::CmpFPredicate::UGT:
133       unordered = true;
134       predicate = emitc::CmpPredicate::gt;
135       break;
136     case arith::CmpFPredicate::UGE:
137       unordered = true;
138       predicate = emitc::CmpPredicate::ge;
139       break;
140     case arith::CmpFPredicate::ULT:
141       unordered = true;
142       predicate = emitc::CmpPredicate::lt;
143       break;
144     case arith::CmpFPredicate::ULE:
145       unordered = true;
146       predicate = emitc::CmpPredicate::le;
147       break;
148     case arith::CmpFPredicate::UNE:
149       unordered = true;
150       predicate = emitc::CmpPredicate::ne;
151       break;
152     case arith::CmpFPredicate::UNO: {
153       // unordered, i.e. either operand is nan
154       auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
155                                         adaptor.getRhs());
156       rewriter.replaceOp(op, cmp);
157       return success();
158     }
159     case arith::CmpFPredicate::AlwaysTrue: {
160       auto constant = rewriter.create<emitc::ConstantOp>(
161           op.getLoc(), rewriter.getI1Type(),
162           rewriter.getBoolAttr(/*value=*/true));
163       rewriter.replaceOp(op, constant);
164       return success();
165     }
166     }
167 
168     // Compare the values naively
169     auto cmpResult =
170         rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
171                                       adaptor.getLhs(), adaptor.getRhs());
172 
173     // Adjust the results for unordered/ordered semantics
174     if (unordered) {
175       auto isUnordered = createCheckIsUnordered(
176           rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
177       rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
178                                                       isUnordered, cmpResult);
179       return success();
180     }
181 
182     auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
183                                           adaptor.getLhs(), adaptor.getRhs());
184     rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
185                                                      isOrdered, cmpResult);
186     return success();
187   }
188 
189 private:
190   /// Return a value that is true if \p operand is NaN.
191   Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
192               Value operand) const {
193     // A value is NaN exactly when it compares unequal to itself.
194     return rewriter.create<emitc::CmpOp>(
195         loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
196   }
197 
198   /// Return a value that is true if \p operand is not NaN.
199   Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
200                  Value operand) const {
201     // A value is not NaN exactly when it compares equal to itself.
202     return rewriter.create<emitc::CmpOp>(
203         loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
204   }
205 
206   /// Return a value that is true if the operands \p first and \p second are
207   /// unordered (i.e., at least one of them is NaN).
208   Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
209                                Location loc, Value first, Value second) const {
210     auto firstIsNaN = isNaN(rewriter, loc, first);
211     auto secondIsNaN = isNaN(rewriter, loc, second);
212     return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
213                                                firstIsNaN, secondIsNaN);
214   }
215 
216   /// Return a value that is true if the operands \p first and \p second are
217   /// both ordered (i.e., none one of them is NaN).
218   Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
219                              Value first, Value second) const {
220     auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
221     auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
222     return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
223                                                 firstIsNotNaN, secondIsNotNaN);
224   }
225 };
226 
227 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
228 public:
229   using OpConversionPattern::OpConversionPattern;
230 
231   bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
232     switch (pred) {
233     case arith::CmpIPredicate::eq:
234     case arith::CmpIPredicate::ne:
235     case arith::CmpIPredicate::slt:
236     case arith::CmpIPredicate::sle:
237     case arith::CmpIPredicate::sgt:
238     case arith::CmpIPredicate::sge:
239       return false;
240     case arith::CmpIPredicate::ult:
241     case arith::CmpIPredicate::ule:
242     case arith::CmpIPredicate::ugt:
243     case arith::CmpIPredicate::uge:
244       return true;
245     }
246     llvm_unreachable("unknown cmpi predicate kind");
247   }
248 
249   emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
250     switch (pred) {
251     case arith::CmpIPredicate::eq:
252       return emitc::CmpPredicate::eq;
253     case arith::CmpIPredicate::ne:
254       return emitc::CmpPredicate::ne;
255     case arith::CmpIPredicate::slt:
256     case arith::CmpIPredicate::ult:
257       return emitc::CmpPredicate::lt;
258     case arith::CmpIPredicate::sle:
259     case arith::CmpIPredicate::ule:
260       return emitc::CmpPredicate::le;
261     case arith::CmpIPredicate::sgt:
262     case arith::CmpIPredicate::ugt:
263       return emitc::CmpPredicate::gt;
264     case arith::CmpIPredicate::sge:
265     case arith::CmpIPredicate::uge:
266       return emitc::CmpPredicate::ge;
267     }
268     llvm_unreachable("unknown cmpi predicate kind");
269   }
270 
271   LogicalResult
272   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
273                   ConversionPatternRewriter &rewriter) const override {
274 
275     Type type = adaptor.getLhs().getType();
276     if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
277       return rewriter.notifyMatchFailure(
278           op, "expected integer or size_t/ssize_t/ptrdiff_t type");
279     }
280 
281     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
282     emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
283 
284     Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
285     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
286     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
287 
288     rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
289     return success();
290   }
291 };
292 
293 class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
294 public:
295   using OpConversionPattern::OpConversionPattern;
296 
297   LogicalResult
298   matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
299                   ConversionPatternRewriter &rewriter) const override {
300 
301     auto adaptedOp = adaptor.getOperand();
302     auto adaptedOpType = adaptedOp.getType();
303 
304     if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
305       return rewriter.notifyMatchFailure(
306           op.getLoc(),
307           "negf currently only supports scalar types, not vectors or tensors");
308     }
309 
310     if (!emitc::isSupportedFloatType(adaptedOpType)) {
311       return rewriter.notifyMatchFailure(
312           op.getLoc(), "floating-point type is not supported by EmitC");
313     }
314 
315     rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
316                                                      adaptedOp);
317     return success();
318   }
319 };
320 
321 template <typename ArithOp, bool castToUnsigned>
322 class CastConversion : public OpConversionPattern<ArithOp> {
323 public:
324   using OpConversionPattern<ArithOp>::OpConversionPattern;
325 
326   LogicalResult
327   matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
328                   ConversionPatternRewriter &rewriter) const override {
329 
330     Type opReturnType = this->getTypeConverter()->convertType(op.getType());
331     if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
332                            emitc::isPointerWideType(opReturnType)))
333       return rewriter.notifyMatchFailure(
334           op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
335 
336     if (adaptor.getOperands().size() != 1) {
337       return rewriter.notifyMatchFailure(
338           op, "CastConversion only supports unary ops");
339     }
340 
341     Type operandType = adaptor.getIn().getType();
342     if (!operandType || !(isa<IntegerType>(operandType) ||
343                           emitc::isPointerWideType(operandType)))
344       return rewriter.notifyMatchFailure(
345           op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
346 
347     // Signed (sign-extending) casts from i1 are not supported.
348     if (operandType.isInteger(1) && !castToUnsigned)
349       return rewriter.notifyMatchFailure(op,
350                                          "operation not supported on i1 type");
351 
352     // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
353     // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
354     // truncation.
355     if (opReturnType.isInteger(1)) {
356       Type attrType = (emitc::isPointerWideType(operandType))
357                           ? rewriter.getIndexType()
358                           : operandType;
359       auto constOne = rewriter.create<emitc::ConstantOp>(
360           op.getLoc(), operandType, rewriter.getOneAttr(attrType));
361       auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
362           op.getLoc(), operandType, adaptor.getIn(), constOne);
363       rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
364                                                  oneAndOperand);
365       return success();
366     }
367 
368     bool isTruncation =
369         (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
370          operandType.getIntOrFloatBitWidth() >
371              opReturnType.getIntOrFloatBitWidth());
372     bool doUnsigned = castToUnsigned || isTruncation;
373 
374     // Adapt the signedness of the result (bitwidth-preserving cast)
375     // This is needed e.g., if the return type is signless.
376     Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
377 
378     // Adapt the signedness of the operand (bitwidth-preserving cast)
379     Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
380     Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
381 
382     // Actual cast (may change bitwidth)
383     auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
384                                                         castDestType, actualOp);
385 
386     // Cast to the expected output type
387     auto result = adaptValueType(cast, rewriter, opReturnType);
388 
389     rewriter.replaceOp(op, result);
390     return success();
391   }
392 };
393 
394 template <typename ArithOp>
395 class UnsignedCastConversion : public CastConversion<ArithOp, true> {
396   using CastConversion<ArithOp, true>::CastConversion;
397 };
398 
399 template <typename ArithOp>
400 class SignedCastConversion : public CastConversion<ArithOp, false> {
401   using CastConversion<ArithOp, false>::CastConversion;
402 };
403 
404 template <typename ArithOp, typename EmitCOp>
405 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
406 public:
407   using OpConversionPattern<ArithOp>::OpConversionPattern;
408 
409   LogicalResult
410   matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
411                   ConversionPatternRewriter &rewriter) const override {
412 
413     Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
414     if (!newTy)
415       return rewriter.notifyMatchFailure(arithOp,
416                                          "converting result type failed");
417     rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
418                                                   adaptor.getOperands());
419 
420     return success();
421   }
422 };
423 
424 template <class ArithOp, class EmitCOp>
425 class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
426 public:
427   using OpConversionPattern<ArithOp>::OpConversionPattern;
428 
429   LogicalResult
430   matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
431                   ConversionPatternRewriter &rewriter) const override {
432     Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
433     if (!newRetTy)
434       return rewriter.notifyMatchFailure(uiBinOp,
435                                          "converting result type failed");
436     if (!isa<IntegerType>(newRetTy)) {
437       return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
438     }
439     Type unsignedType =
440         adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
441     if (!unsignedType)
442       return rewriter.notifyMatchFailure(uiBinOp,
443                                          "converting result type failed");
444     Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
445     Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
446 
447     auto newDivOp =
448         rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
449                                  ArrayRef<Value>{lhsAdapted, rhsAdapted});
450     Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
451     rewriter.replaceOp(uiBinOp, resultAdapted);
452     return success();
453   }
454 };
455 
456 template <typename ArithOp, typename EmitCOp>
457 class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
458 public:
459   using OpConversionPattern<ArithOp>::OpConversionPattern;
460 
461   LogicalResult
462   matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
463                   ConversionPatternRewriter &rewriter) const override {
464 
465     Type type = this->getTypeConverter()->convertType(op.getType());
466     if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
467       return rewriter.notifyMatchFailure(
468           op, "expected integer or size_t/ssize_t/ptrdiff_t type");
469     }
470 
471     if (type.isInteger(1)) {
472       // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
473       return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
474     }
475 
476     Type arithmeticType = type;
477     if ((type.isSignlessInteger() || type.isSignedInteger()) &&
478         !bitEnumContainsAll(op.getOverflowFlags(),
479                             arith::IntegerOverflowFlags::nsw)) {
480       // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
481       // we compute in unsigned integers to avoid UB.
482       arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
483                                                /*isSigned=*/false);
484     }
485 
486     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
487     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
488 
489     Value arithmeticResult = rewriter.template create<EmitCOp>(
490         op.getLoc(), arithmeticType, lhs, rhs);
491 
492     Value result = adaptValueType(arithmeticResult, rewriter, type);
493 
494     rewriter.replaceOp(op, result);
495     return success();
496   }
497 };
498 
499 template <typename ArithOp, typename EmitCOp>
500 class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
501 public:
502   using OpConversionPattern<ArithOp>::OpConversionPattern;
503 
504   LogicalResult
505   matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
506                   ConversionPatternRewriter &rewriter) const override {
507 
508     Type type = this->getTypeConverter()->convertType(op.getType());
509     if (!isa_and_nonnull<IntegerType>(type)) {
510       return rewriter.notifyMatchFailure(
511           op,
512           "expected integer type, vector/tensor support not yet implemented");
513     }
514 
515     // Bitwise ops can be performed directly on booleans
516     if (type.isInteger(1)) {
517       rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
518                                            adaptor.getRhs());
519       return success();
520     }
521 
522     // Bitwise ops are defined by the C standard on unsigned operands.
523     Type arithmeticType =
524         adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
525 
526     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
527     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
528 
529     Value arithmeticResult = rewriter.template create<EmitCOp>(
530         op.getLoc(), arithmeticType, lhs, rhs);
531 
532     Value result = adaptValueType(arithmeticResult, rewriter, type);
533 
534     rewriter.replaceOp(op, result);
535     return success();
536   }
537 };
538 
539 template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
540 class ShiftOpConversion : public OpConversionPattern<ArithOp> {
541 public:
542   using OpConversionPattern<ArithOp>::OpConversionPattern;
543 
544   LogicalResult
545   matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
546                   ConversionPatternRewriter &rewriter) const override {
547 
548     Type type = this->getTypeConverter()->convertType(op.getType());
549     if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
550       return rewriter.notifyMatchFailure(
551           op, "expected integer or size_t/ssize_t/ptrdiff_t type");
552     }
553 
554     if (type.isInteger(1)) {
555       return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
556     }
557 
558     Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
559 
560     Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
561     // Shift amount interpreted as unsigned per Arith dialect spec.
562     Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
563                                                /*needsUnsigned=*/true);
564     Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
565 
566     // Add a runtime check for overflow
567     Value width;
568     if (emitc::isPointerWideType(type)) {
569       Value eight = rewriter.create<emitc::ConstantOp>(
570           op.getLoc(), rhsType, rewriter.getIndexAttr(8));
571       emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
572           op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
573       width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
574                                             sizeOfCall.getResult(0));
575     } else {
576       width = rewriter.create<emitc::ConstantOp>(
577           op.getLoc(), rhsType,
578           rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
579     }
580 
581     Value excessCheck = rewriter.create<emitc::CmpOp>(
582         op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
583 
584     // Any concrete value is a valid refinement of poison.
585     Value poison = rewriter.create<emitc::ConstantOp>(
586         op.getLoc(), arithmeticType,
587         (isa<IntegerType>(arithmeticType)
588              ? rewriter.getIntegerAttr(arithmeticType, 0)
589              : rewriter.getIndexAttr(0)));
590 
591     emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
592         op.getLoc(), arithmeticType, /*do_not_inline=*/false);
593     Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
594     auto currentPoint = rewriter.getInsertionPoint();
595     rewriter.setInsertionPointToStart(&bodyBlock);
596     Value arithmeticResult =
597         rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
598     Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
599         op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
600     rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
601     rewriter.setInsertionPoint(op->getBlock(), currentPoint);
602 
603     Value result = adaptValueType(ternary, rewriter, type);
604 
605     rewriter.replaceOp(op, result);
606     return success();
607   }
608 };
609 
610 template <typename ArithOp, typename EmitCOp>
611 class SignedShiftOpConversion final
612     : public ShiftOpConversion<ArithOp, EmitCOp, false> {
613   using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
614 };
615 
616 template <typename ArithOp, typename EmitCOp>
617 class UnsignedShiftOpConversion final
618     : public ShiftOpConversion<ArithOp, EmitCOp, true> {
619   using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
620 };
621 
622 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
623 public:
624   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
625 
626   LogicalResult
627   matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
628                   ConversionPatternRewriter &rewriter) const override {
629 
630     Type dstType = getTypeConverter()->convertType(selectOp.getType());
631     if (!dstType)
632       return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
633 
634     if (!adaptor.getCondition().getType().isInteger(1))
635       return rewriter.notifyMatchFailure(
636           selectOp,
637           "can only be converted if condition is a scalar of type i1");
638 
639     rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
640                                                       adaptor.getOperands());
641 
642     return success();
643   }
644 };
645 
646 // Floating-point to integer conversions.
647 template <typename CastOp>
648 class FtoICastOpConversion : public OpConversionPattern<CastOp> {
649 public:
650   FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
651       : OpConversionPattern<CastOp>(typeConverter, context) {}
652 
653   LogicalResult
654   matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
655                   ConversionPatternRewriter &rewriter) const override {
656 
657     Type operandType = adaptor.getIn().getType();
658     if (!emitc::isSupportedFloatType(operandType))
659       return rewriter.notifyMatchFailure(castOp,
660                                          "unsupported cast source type");
661 
662     Type dstType = this->getTypeConverter()->convertType(castOp.getType());
663     if (!dstType)
664       return rewriter.notifyMatchFailure(castOp, "type conversion failed");
665 
666     // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
667     // truncated to 0, whereas a boolean conversion would return true.
668     if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
669       return rewriter.notifyMatchFailure(castOp,
670                                          "unsupported cast destination type");
671 
672     // Convert to unsigned if it's the "ui" variant
673     // Signless is interpreted as signed, so no need to cast for "si"
674     Type actualResultType = dstType;
675     if (isa<arith::FPToUIOp>(castOp)) {
676       actualResultType =
677           rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
678                                   /*isSigned=*/false);
679     }
680 
681     Value result = rewriter.create<emitc::CastOp>(
682         castOp.getLoc(), actualResultType, adaptor.getOperands());
683 
684     if (isa<arith::FPToUIOp>(castOp)) {
685       result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
686     }
687     rewriter.replaceOp(castOp, result);
688 
689     return success();
690   }
691 };
692 
693 // Integer to floating-point conversions.
694 template <typename CastOp>
695 class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
696 public:
697   ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
698       : OpConversionPattern<CastOp>(typeConverter, context) {}
699 
700   LogicalResult
701   matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
702                   ConversionPatternRewriter &rewriter) const override {
703     // Vectors in particular are not supported
704     Type operandType = adaptor.getIn().getType();
705     if (!emitc::isSupportedIntegerType(operandType))
706       return rewriter.notifyMatchFailure(castOp,
707                                          "unsupported cast source type");
708 
709     Type dstType = this->getTypeConverter()->convertType(castOp.getType());
710     if (!dstType)
711       return rewriter.notifyMatchFailure(castOp, "type conversion failed");
712 
713     if (!emitc::isSupportedFloatType(dstType))
714       return rewriter.notifyMatchFailure(castOp,
715                                          "unsupported cast destination type");
716 
717     // Convert to unsigned if it's the "ui" variant
718     // Signless is interpreted as signed, so no need to cast for "si"
719     Type actualOperandType = operandType;
720     if (isa<arith::UIToFPOp>(castOp)) {
721       actualOperandType =
722           rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
723                                   /*isSigned=*/false);
724     }
725     Value fpCastOperand = adaptor.getIn();
726     if (actualOperandType != operandType) {
727       fpCastOperand = rewriter.template create<emitc::CastOp>(
728           castOp.getLoc(), actualOperandType, fpCastOperand);
729     }
730     rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
731 
732     return success();
733   }
734 };
735 
736 // Floating-point to floating-point conversions.
737 template <typename CastOp>
738 class FpCastOpConversion : public OpConversionPattern<CastOp> {
739 public:
740   FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
741       : OpConversionPattern<CastOp>(typeConverter, context) {}
742 
743   LogicalResult
744   matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
745                   ConversionPatternRewriter &rewriter) const override {
746     // Vectors in particular are not supported.
747     Type operandType = adaptor.getIn().getType();
748     if (!emitc::isSupportedFloatType(operandType))
749       return rewriter.notifyMatchFailure(castOp,
750                                          "unsupported cast source type");
751     if (auto roundingModeOp =
752             dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
753       // Only supporting default rounding mode as of now.
754       if (roundingModeOp.getRoundingModeAttr())
755         return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
756     }
757 
758     Type dstType = this->getTypeConverter()->convertType(castOp.getType());
759     if (!dstType)
760       return rewriter.notifyMatchFailure(castOp, "type conversion failed");
761 
762     if (!emitc::isSupportedFloatType(dstType))
763       return rewriter.notifyMatchFailure(castOp,
764                                          "unsupported cast destination type");
765 
766     Value fpCastOperand = adaptor.getIn();
767     rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
768 
769     return success();
770   }
771 };
772 
773 } // namespace
774 
775 //===----------------------------------------------------------------------===//
776 // Pattern population
777 //===----------------------------------------------------------------------===//
778 
779 void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
780                                         RewritePatternSet &patterns) {
781   MLIRContext *ctx = patterns.getContext();
782 
783   mlir::populateEmitCSizeTTypeConversions(typeConverter);
784 
785   // clang-format off
786   patterns.add<
787     ArithConstantOpConversionPattern,
788     ArithOpConversion<arith::AddFOp, emitc::AddOp>,
789     ArithOpConversion<arith::DivFOp, emitc::DivOp>,
790     ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
791     ArithOpConversion<arith::MulFOp, emitc::MulOp>,
792     ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
793     ArithOpConversion<arith::SubFOp, emitc::SubOp>,
794     BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
795     BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
796     IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
797     IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
798     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
799     BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
800     BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
801     BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
802     UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
803     SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
804     UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
805     CmpFOpConversion,
806     CmpIOpConversion,
807     NegFOpConversion,
808     SelectOpConversion,
809     // Truncation is guaranteed for unsigned types.
810     UnsignedCastConversion<arith::TruncIOp>,
811     SignedCastConversion<arith::ExtSIOp>,
812     UnsignedCastConversion<arith::ExtUIOp>,
813     SignedCastConversion<arith::IndexCastOp>,
814     UnsignedCastConversion<arith::IndexCastUIOp>,
815     ItoFCastOpConversion<arith::SIToFPOp>,
816     ItoFCastOpConversion<arith::UIToFPOp>,
817     FtoICastOpConversion<arith::FPToSIOp>,
818     FtoICastOpConversion<arith::FPToUIOp>,
819     FpCastOpConversion<arith::ExtFOp>,
820     FpCastOpConversion<arith::TruncFOp>
821   >(typeConverter, ctx);
822   // clang-format on
823 }
824