1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert the Arith dialect to the EmitC 10 // dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/EmitC/IR/EmitC.h" 18 #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" 19 #include "mlir/IR/BuiltinAttributes.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 //===----------------------------------------------------------------------===// 26 // Conversion Patterns 27 //===----------------------------------------------------------------------===// 28 29 namespace { 30 class ArithConstantOpConversionPattern 31 : public OpConversionPattern<arith::ConstantOp> { 32 public: 33 using OpConversionPattern::OpConversionPattern; 34 35 LogicalResult 36 matchAndRewrite(arith::ConstantOp arithConst, 37 arith::ConstantOp::Adaptor adaptor, 38 ConversionPatternRewriter &rewriter) const override { 39 Type newTy = this->getTypeConverter()->convertType(arithConst.getType()); 40 if (!newTy) 41 return rewriter.notifyMatchFailure(arithConst, "type conversion failed"); 42 rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy, 43 adaptor.getValue()); 44 return success(); 45 } 46 }; 47 48 /// Get the signed or unsigned type corresponding to \p ty. 49 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) { 50 if (isa<IntegerType>(ty)) { 51 if (ty.isUnsignedInteger() != needsUnsigned) { 52 auto signedness = needsUnsigned 53 ? IntegerType::SignednessSemantics::Unsigned 54 : IntegerType::SignednessSemantics::Signed; 55 return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(), 56 signedness); 57 } 58 } else if (emitc::isPointerWideType(ty)) { 59 if (isa<emitc::SizeTType>(ty) != needsUnsigned) { 60 if (needsUnsigned) 61 return emitc::SizeTType::get(ty.getContext()); 62 return emitc::PtrDiffTType::get(ty.getContext()); 63 } 64 } 65 return ty; 66 } 67 68 /// Insert a cast operation to type \p ty if \p val does not have this type. 69 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) { 70 return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val); 71 } 72 73 class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> { 74 public: 75 using OpConversionPattern::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override { 80 81 if (!isa<FloatType>(adaptor.getRhs().getType())) { 82 return rewriter.notifyMatchFailure(op.getLoc(), 83 "cmpf currently only supported on " 84 "floats, not tensors/vectors thereof"); 85 } 86 87 bool unordered = false; 88 emitc::CmpPredicate predicate; 89 switch (op.getPredicate()) { 90 case arith::CmpFPredicate::AlwaysFalse: { 91 auto constant = rewriter.create<emitc::ConstantOp>( 92 op.getLoc(), rewriter.getI1Type(), 93 rewriter.getBoolAttr(/*value=*/false)); 94 rewriter.replaceOp(op, constant); 95 return success(); 96 } 97 case arith::CmpFPredicate::OEQ: 98 unordered = false; 99 predicate = emitc::CmpPredicate::eq; 100 break; 101 case arith::CmpFPredicate::OGT: 102 unordered = false; 103 predicate = emitc::CmpPredicate::gt; 104 break; 105 case arith::CmpFPredicate::OGE: 106 unordered = false; 107 predicate = emitc::CmpPredicate::ge; 108 break; 109 case arith::CmpFPredicate::OLT: 110 unordered = false; 111 predicate = emitc::CmpPredicate::lt; 112 break; 113 case arith::CmpFPredicate::OLE: 114 unordered = false; 115 predicate = emitc::CmpPredicate::le; 116 break; 117 case arith::CmpFPredicate::ONE: 118 unordered = false; 119 predicate = emitc::CmpPredicate::ne; 120 break; 121 case arith::CmpFPredicate::ORD: { 122 // ordered, i.e. none of the operands is NaN 123 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(), 124 adaptor.getRhs()); 125 rewriter.replaceOp(op, cmp); 126 return success(); 127 } 128 case arith::CmpFPredicate::UEQ: 129 unordered = true; 130 predicate = emitc::CmpPredicate::eq; 131 break; 132 case arith::CmpFPredicate::UGT: 133 unordered = true; 134 predicate = emitc::CmpPredicate::gt; 135 break; 136 case arith::CmpFPredicate::UGE: 137 unordered = true; 138 predicate = emitc::CmpPredicate::ge; 139 break; 140 case arith::CmpFPredicate::ULT: 141 unordered = true; 142 predicate = emitc::CmpPredicate::lt; 143 break; 144 case arith::CmpFPredicate::ULE: 145 unordered = true; 146 predicate = emitc::CmpPredicate::le; 147 break; 148 case arith::CmpFPredicate::UNE: 149 unordered = true; 150 predicate = emitc::CmpPredicate::ne; 151 break; 152 case arith::CmpFPredicate::UNO: { 153 // unordered, i.e. either operand is nan 154 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(), 155 adaptor.getRhs()); 156 rewriter.replaceOp(op, cmp); 157 return success(); 158 } 159 case arith::CmpFPredicate::AlwaysTrue: { 160 auto constant = rewriter.create<emitc::ConstantOp>( 161 op.getLoc(), rewriter.getI1Type(), 162 rewriter.getBoolAttr(/*value=*/true)); 163 rewriter.replaceOp(op, constant); 164 return success(); 165 } 166 } 167 168 // Compare the values naively 169 auto cmpResult = 170 rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate, 171 adaptor.getLhs(), adaptor.getRhs()); 172 173 // Adjust the results for unordered/ordered semantics 174 if (unordered) { 175 auto isUnordered = createCheckIsUnordered( 176 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); 177 rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(), 178 isUnordered, cmpResult); 179 return success(); 180 } 181 182 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(), 183 adaptor.getLhs(), adaptor.getRhs()); 184 rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(), 185 isOrdered, cmpResult); 186 return success(); 187 } 188 189 private: 190 /// Return a value that is true if \p operand is NaN. 191 Value isNaN(ConversionPatternRewriter &rewriter, Location loc, 192 Value operand) const { 193 // A value is NaN exactly when it compares unequal to itself. 194 return rewriter.create<emitc::CmpOp>( 195 loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); 196 } 197 198 /// Return a value that is true if \p operand is not NaN. 199 Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, 200 Value operand) const { 201 // A value is not NaN exactly when it compares equal to itself. 202 return rewriter.create<emitc::CmpOp>( 203 loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand); 204 } 205 206 /// Return a value that is true if the operands \p first and \p second are 207 /// unordered (i.e., at least one of them is NaN). 208 Value createCheckIsUnordered(ConversionPatternRewriter &rewriter, 209 Location loc, Value first, Value second) const { 210 auto firstIsNaN = isNaN(rewriter, loc, first); 211 auto secondIsNaN = isNaN(rewriter, loc, second); 212 return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(), 213 firstIsNaN, secondIsNaN); 214 } 215 216 /// Return a value that is true if the operands \p first and \p second are 217 /// both ordered (i.e., none one of them is NaN). 218 Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc, 219 Value first, Value second) const { 220 auto firstIsNotNaN = isNotNaN(rewriter, loc, first); 221 auto secondIsNotNaN = isNotNaN(rewriter, loc, second); 222 return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(), 223 firstIsNotNaN, secondIsNotNaN); 224 } 225 }; 226 227 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> { 228 public: 229 using OpConversionPattern::OpConversionPattern; 230 231 bool needsUnsignedCmp(arith::CmpIPredicate pred) const { 232 switch (pred) { 233 case arith::CmpIPredicate::eq: 234 case arith::CmpIPredicate::ne: 235 case arith::CmpIPredicate::slt: 236 case arith::CmpIPredicate::sle: 237 case arith::CmpIPredicate::sgt: 238 case arith::CmpIPredicate::sge: 239 return false; 240 case arith::CmpIPredicate::ult: 241 case arith::CmpIPredicate::ule: 242 case arith::CmpIPredicate::ugt: 243 case arith::CmpIPredicate::uge: 244 return true; 245 } 246 llvm_unreachable("unknown cmpi predicate kind"); 247 } 248 249 emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const { 250 switch (pred) { 251 case arith::CmpIPredicate::eq: 252 return emitc::CmpPredicate::eq; 253 case arith::CmpIPredicate::ne: 254 return emitc::CmpPredicate::ne; 255 case arith::CmpIPredicate::slt: 256 case arith::CmpIPredicate::ult: 257 return emitc::CmpPredicate::lt; 258 case arith::CmpIPredicate::sle: 259 case arith::CmpIPredicate::ule: 260 return emitc::CmpPredicate::le; 261 case arith::CmpIPredicate::sgt: 262 case arith::CmpIPredicate::ugt: 263 return emitc::CmpPredicate::gt; 264 case arith::CmpIPredicate::sge: 265 case arith::CmpIPredicate::uge: 266 return emitc::CmpPredicate::ge; 267 } 268 llvm_unreachable("unknown cmpi predicate kind"); 269 } 270 271 LogicalResult 272 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 273 ConversionPatternRewriter &rewriter) const override { 274 275 Type type = adaptor.getLhs().getType(); 276 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) { 277 return rewriter.notifyMatchFailure( 278 op, "expected integer or size_t/ssize_t/ptrdiff_t type"); 279 } 280 281 bool needsUnsigned = needsUnsignedCmp(op.getPredicate()); 282 emitc::CmpPredicate pred = toEmitCPred(op.getPredicate()); 283 284 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned); 285 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); 286 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); 287 288 rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs); 289 return success(); 290 } 291 }; 292 293 class NegFOpConversion : public OpConversionPattern<arith::NegFOp> { 294 public: 295 using OpConversionPattern::OpConversionPattern; 296 297 LogicalResult 298 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, 299 ConversionPatternRewriter &rewriter) const override { 300 301 auto adaptedOp = adaptor.getOperand(); 302 auto adaptedOpType = adaptedOp.getType(); 303 304 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) { 305 return rewriter.notifyMatchFailure( 306 op.getLoc(), 307 "negf currently only supports scalar types, not vectors or tensors"); 308 } 309 310 if (!emitc::isSupportedFloatType(adaptedOpType)) { 311 return rewriter.notifyMatchFailure( 312 op.getLoc(), "floating-point type is not supported by EmitC"); 313 } 314 315 rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType, 316 adaptedOp); 317 return success(); 318 } 319 }; 320 321 template <typename ArithOp, bool castToUnsigned> 322 class CastConversion : public OpConversionPattern<ArithOp> { 323 public: 324 using OpConversionPattern<ArithOp>::OpConversionPattern; 325 326 LogicalResult 327 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, 328 ConversionPatternRewriter &rewriter) const override { 329 330 Type opReturnType = this->getTypeConverter()->convertType(op.getType()); 331 if (!opReturnType || !(isa<IntegerType>(opReturnType) || 332 emitc::isPointerWideType(opReturnType))) 333 return rewriter.notifyMatchFailure( 334 op, "expected integer or size_t/ssize_t/ptrdiff_t result type"); 335 336 if (adaptor.getOperands().size() != 1) { 337 return rewriter.notifyMatchFailure( 338 op, "CastConversion only supports unary ops"); 339 } 340 341 Type operandType = adaptor.getIn().getType(); 342 if (!operandType || !(isa<IntegerType>(operandType) || 343 emitc::isPointerWideType(operandType))) 344 return rewriter.notifyMatchFailure( 345 op, "expected integer or size_t/ssize_t/ptrdiff_t operand type"); 346 347 // Signed (sign-extending) casts from i1 are not supported. 348 if (operandType.isInteger(1) && !castToUnsigned) 349 return rewriter.notifyMatchFailure(op, 350 "operation not supported on i1 type"); 351 352 // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is 353 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives 354 // truncation. 355 if (opReturnType.isInteger(1)) { 356 Type attrType = (emitc::isPointerWideType(operandType)) 357 ? rewriter.getIndexType() 358 : operandType; 359 auto constOne = rewriter.create<emitc::ConstantOp>( 360 op.getLoc(), operandType, rewriter.getOneAttr(attrType)); 361 auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>( 362 op.getLoc(), operandType, adaptor.getIn(), constOne); 363 rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType, 364 oneAndOperand); 365 return success(); 366 } 367 368 bool isTruncation = 369 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) && 370 operandType.getIntOrFloatBitWidth() > 371 opReturnType.getIntOrFloatBitWidth()); 372 bool doUnsigned = castToUnsigned || isTruncation; 373 374 // Adapt the signedness of the result (bitwidth-preserving cast) 375 // This is needed e.g., if the return type is signless. 376 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned); 377 378 // Adapt the signedness of the operand (bitwidth-preserving cast) 379 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned); 380 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); 381 382 // Actual cast (may change bitwidth) 383 auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(), 384 castDestType, actualOp); 385 386 // Cast to the expected output type 387 auto result = adaptValueType(cast, rewriter, opReturnType); 388 389 rewriter.replaceOp(op, result); 390 return success(); 391 } 392 }; 393 394 template <typename ArithOp> 395 class UnsignedCastConversion : public CastConversion<ArithOp, true> { 396 using CastConversion<ArithOp, true>::CastConversion; 397 }; 398 399 template <typename ArithOp> 400 class SignedCastConversion : public CastConversion<ArithOp, false> { 401 using CastConversion<ArithOp, false>::CastConversion; 402 }; 403 404 template <typename ArithOp, typename EmitCOp> 405 class ArithOpConversion final : public OpConversionPattern<ArithOp> { 406 public: 407 using OpConversionPattern<ArithOp>::OpConversionPattern; 408 409 LogicalResult 410 matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, 411 ConversionPatternRewriter &rewriter) const override { 412 413 Type newTy = this->getTypeConverter()->convertType(arithOp.getType()); 414 if (!newTy) 415 return rewriter.notifyMatchFailure(arithOp, 416 "converting result type failed"); 417 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy, 418 adaptor.getOperands()); 419 420 return success(); 421 } 422 }; 423 424 template <class ArithOp, class EmitCOp> 425 class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> { 426 public: 427 using OpConversionPattern<ArithOp>::OpConversionPattern; 428 429 LogicalResult 430 matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor, 431 ConversionPatternRewriter &rewriter) const override { 432 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType()); 433 if (!newRetTy) 434 return rewriter.notifyMatchFailure(uiBinOp, 435 "converting result type failed"); 436 if (!isa<IntegerType>(newRetTy)) { 437 return rewriter.notifyMatchFailure(uiBinOp, "expected integer type"); 438 } 439 Type unsignedType = 440 adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true); 441 if (!unsignedType) 442 return rewriter.notifyMatchFailure(uiBinOp, 443 "converting result type failed"); 444 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType); 445 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType); 446 447 auto newDivOp = 448 rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType, 449 ArrayRef<Value>{lhsAdapted, rhsAdapted}); 450 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy); 451 rewriter.replaceOp(uiBinOp, resultAdapted); 452 return success(); 453 } 454 }; 455 456 template <typename ArithOp, typename EmitCOp> 457 class IntegerOpConversion final : public OpConversionPattern<ArithOp> { 458 public: 459 using OpConversionPattern<ArithOp>::OpConversionPattern; 460 461 LogicalResult 462 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, 463 ConversionPatternRewriter &rewriter) const override { 464 465 Type type = this->getTypeConverter()->convertType(op.getType()); 466 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) { 467 return rewriter.notifyMatchFailure( 468 op, "expected integer or size_t/ssize_t/ptrdiff_t type"); 469 } 470 471 if (type.isInteger(1)) { 472 // arith expects wrap-around arithmethic, which doesn't happen on `bool`. 473 return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); 474 } 475 476 Type arithmeticType = type; 477 if ((type.isSignlessInteger() || type.isSignedInteger()) && 478 !bitEnumContainsAll(op.getOverflowFlags(), 479 arith::IntegerOverflowFlags::nsw)) { 480 // If the C type is signed and the op doesn't guarantee "No Signed Wrap", 481 // we compute in unsigned integers to avoid UB. 482 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(), 483 /*isSigned=*/false); 484 } 485 486 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); 487 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); 488 489 Value arithmeticResult = rewriter.template create<EmitCOp>( 490 op.getLoc(), arithmeticType, lhs, rhs); 491 492 Value result = adaptValueType(arithmeticResult, rewriter, type); 493 494 rewriter.replaceOp(op, result); 495 return success(); 496 } 497 }; 498 499 template <typename ArithOp, typename EmitCOp> 500 class BitwiseOpConversion : public OpConversionPattern<ArithOp> { 501 public: 502 using OpConversionPattern<ArithOp>::OpConversionPattern; 503 504 LogicalResult 505 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, 506 ConversionPatternRewriter &rewriter) const override { 507 508 Type type = this->getTypeConverter()->convertType(op.getType()); 509 if (!isa_and_nonnull<IntegerType>(type)) { 510 return rewriter.notifyMatchFailure( 511 op, 512 "expected integer type, vector/tensor support not yet implemented"); 513 } 514 515 // Bitwise ops can be performed directly on booleans 516 if (type.isInteger(1)) { 517 rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(), 518 adaptor.getRhs()); 519 return success(); 520 } 521 522 // Bitwise ops are defined by the C standard on unsigned operands. 523 Type arithmeticType = 524 adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true); 525 526 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); 527 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); 528 529 Value arithmeticResult = rewriter.template create<EmitCOp>( 530 op.getLoc(), arithmeticType, lhs, rhs); 531 532 Value result = adaptValueType(arithmeticResult, rewriter, type); 533 534 rewriter.replaceOp(op, result); 535 return success(); 536 } 537 }; 538 539 template <typename ArithOp, typename EmitCOp, bool isUnsignedOp> 540 class ShiftOpConversion : public OpConversionPattern<ArithOp> { 541 public: 542 using OpConversionPattern<ArithOp>::OpConversionPattern; 543 544 LogicalResult 545 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, 546 ConversionPatternRewriter &rewriter) const override { 547 548 Type type = this->getTypeConverter()->convertType(op.getType()); 549 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) { 550 return rewriter.notifyMatchFailure( 551 op, "expected integer or size_t/ssize_t/ptrdiff_t type"); 552 } 553 554 if (type.isInteger(1)) { 555 return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); 556 } 557 558 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp); 559 560 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); 561 // Shift amount interpreted as unsigned per Arith dialect spec. 562 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(), 563 /*needsUnsigned=*/true); 564 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType); 565 566 // Add a runtime check for overflow 567 Value width; 568 if (emitc::isPointerWideType(type)) { 569 Value eight = rewriter.create<emitc::ConstantOp>( 570 op.getLoc(), rhsType, rewriter.getIndexAttr(8)); 571 emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>( 572 op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight}); 573 width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight, 574 sizeOfCall.getResult(0)); 575 } else { 576 width = rewriter.create<emitc::ConstantOp>( 577 op.getLoc(), rhsType, 578 rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); 579 } 580 581 Value excessCheck = rewriter.create<emitc::CmpOp>( 582 op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); 583 584 // Any concrete value is a valid refinement of poison. 585 Value poison = rewriter.create<emitc::ConstantOp>( 586 op.getLoc(), arithmeticType, 587 (isa<IntegerType>(arithmeticType) 588 ? rewriter.getIntegerAttr(arithmeticType, 0) 589 : rewriter.getIndexAttr(0))); 590 591 emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>( 592 op.getLoc(), arithmeticType, /*do_not_inline=*/false); 593 Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); 594 auto currentPoint = rewriter.getInsertionPoint(); 595 rewriter.setInsertionPointToStart(&bodyBlock); 596 Value arithmeticResult = 597 rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs); 598 Value resultOrPoison = rewriter.create<emitc::ConditionalOp>( 599 op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); 600 rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison); 601 rewriter.setInsertionPoint(op->getBlock(), currentPoint); 602 603 Value result = adaptValueType(ternary, rewriter, type); 604 605 rewriter.replaceOp(op, result); 606 return success(); 607 } 608 }; 609 610 template <typename ArithOp, typename EmitCOp> 611 class SignedShiftOpConversion final 612 : public ShiftOpConversion<ArithOp, EmitCOp, false> { 613 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion; 614 }; 615 616 template <typename ArithOp, typename EmitCOp> 617 class UnsignedShiftOpConversion final 618 : public ShiftOpConversion<ArithOp, EmitCOp, true> { 619 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion; 620 }; 621 622 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> { 623 public: 624 using OpConversionPattern<arith::SelectOp>::OpConversionPattern; 625 626 LogicalResult 627 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor, 628 ConversionPatternRewriter &rewriter) const override { 629 630 Type dstType = getTypeConverter()->convertType(selectOp.getType()); 631 if (!dstType) 632 return rewriter.notifyMatchFailure(selectOp, "type conversion failed"); 633 634 if (!adaptor.getCondition().getType().isInteger(1)) 635 return rewriter.notifyMatchFailure( 636 selectOp, 637 "can only be converted if condition is a scalar of type i1"); 638 639 rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType, 640 adaptor.getOperands()); 641 642 return success(); 643 } 644 }; 645 646 // Floating-point to integer conversions. 647 template <typename CastOp> 648 class FtoICastOpConversion : public OpConversionPattern<CastOp> { 649 public: 650 FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) 651 : OpConversionPattern<CastOp>(typeConverter, context) {} 652 653 LogicalResult 654 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, 655 ConversionPatternRewriter &rewriter) const override { 656 657 Type operandType = adaptor.getIn().getType(); 658 if (!emitc::isSupportedFloatType(operandType)) 659 return rewriter.notifyMatchFailure(castOp, 660 "unsupported cast source type"); 661 662 Type dstType = this->getTypeConverter()->convertType(castOp.getType()); 663 if (!dstType) 664 return rewriter.notifyMatchFailure(castOp, "type conversion failed"); 665 666 // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be 667 // truncated to 0, whereas a boolean conversion would return true. 668 if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1)) 669 return rewriter.notifyMatchFailure(castOp, 670 "unsupported cast destination type"); 671 672 // Convert to unsigned if it's the "ui" variant 673 // Signless is interpreted as signed, so no need to cast for "si" 674 Type actualResultType = dstType; 675 if (isa<arith::FPToUIOp>(castOp)) { 676 actualResultType = 677 rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(), 678 /*isSigned=*/false); 679 } 680 681 Value result = rewriter.create<emitc::CastOp>( 682 castOp.getLoc(), actualResultType, adaptor.getOperands()); 683 684 if (isa<arith::FPToUIOp>(castOp)) { 685 result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result); 686 } 687 rewriter.replaceOp(castOp, result); 688 689 return success(); 690 } 691 }; 692 693 // Integer to floating-point conversions. 694 template <typename CastOp> 695 class ItoFCastOpConversion : public OpConversionPattern<CastOp> { 696 public: 697 ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) 698 : OpConversionPattern<CastOp>(typeConverter, context) {} 699 700 LogicalResult 701 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, 702 ConversionPatternRewriter &rewriter) const override { 703 // Vectors in particular are not supported 704 Type operandType = adaptor.getIn().getType(); 705 if (!emitc::isSupportedIntegerType(operandType)) 706 return rewriter.notifyMatchFailure(castOp, 707 "unsupported cast source type"); 708 709 Type dstType = this->getTypeConverter()->convertType(castOp.getType()); 710 if (!dstType) 711 return rewriter.notifyMatchFailure(castOp, "type conversion failed"); 712 713 if (!emitc::isSupportedFloatType(dstType)) 714 return rewriter.notifyMatchFailure(castOp, 715 "unsupported cast destination type"); 716 717 // Convert to unsigned if it's the "ui" variant 718 // Signless is interpreted as signed, so no need to cast for "si" 719 Type actualOperandType = operandType; 720 if (isa<arith::UIToFPOp>(castOp)) { 721 actualOperandType = 722 rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), 723 /*isSigned=*/false); 724 } 725 Value fpCastOperand = adaptor.getIn(); 726 if (actualOperandType != operandType) { 727 fpCastOperand = rewriter.template create<emitc::CastOp>( 728 castOp.getLoc(), actualOperandType, fpCastOperand); 729 } 730 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); 731 732 return success(); 733 } 734 }; 735 736 // Floating-point to floating-point conversions. 737 template <typename CastOp> 738 class FpCastOpConversion : public OpConversionPattern<CastOp> { 739 public: 740 FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) 741 : OpConversionPattern<CastOp>(typeConverter, context) {} 742 743 LogicalResult 744 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, 745 ConversionPatternRewriter &rewriter) const override { 746 // Vectors in particular are not supported. 747 Type operandType = adaptor.getIn().getType(); 748 if (!emitc::isSupportedFloatType(operandType)) 749 return rewriter.notifyMatchFailure(castOp, 750 "unsupported cast source type"); 751 if (auto roundingModeOp = 752 dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) { 753 // Only supporting default rounding mode as of now. 754 if (roundingModeOp.getRoundingModeAttr()) 755 return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode"); 756 } 757 758 Type dstType = this->getTypeConverter()->convertType(castOp.getType()); 759 if (!dstType) 760 return rewriter.notifyMatchFailure(castOp, "type conversion failed"); 761 762 if (!emitc::isSupportedFloatType(dstType)) 763 return rewriter.notifyMatchFailure(castOp, 764 "unsupported cast destination type"); 765 766 Value fpCastOperand = adaptor.getIn(); 767 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); 768 769 return success(); 770 } 771 }; 772 773 } // namespace 774 775 //===----------------------------------------------------------------------===// 776 // Pattern population 777 //===----------------------------------------------------------------------===// 778 779 void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, 780 RewritePatternSet &patterns) { 781 MLIRContext *ctx = patterns.getContext(); 782 783 mlir::populateEmitCSizeTTypeConversions(typeConverter); 784 785 // clang-format off 786 patterns.add< 787 ArithConstantOpConversionPattern, 788 ArithOpConversion<arith::AddFOp, emitc::AddOp>, 789 ArithOpConversion<arith::DivFOp, emitc::DivOp>, 790 ArithOpConversion<arith::DivSIOp, emitc::DivOp>, 791 ArithOpConversion<arith::MulFOp, emitc::MulOp>, 792 ArithOpConversion<arith::RemSIOp, emitc::RemOp>, 793 ArithOpConversion<arith::SubFOp, emitc::SubOp>, 794 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>, 795 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>, 796 IntegerOpConversion<arith::AddIOp, emitc::AddOp>, 797 IntegerOpConversion<arith::MulIOp, emitc::MulOp>, 798 IntegerOpConversion<arith::SubIOp, emitc::SubOp>, 799 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>, 800 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>, 801 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>, 802 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>, 803 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>, 804 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>, 805 CmpFOpConversion, 806 CmpIOpConversion, 807 NegFOpConversion, 808 SelectOpConversion, 809 // Truncation is guaranteed for unsigned types. 810 UnsignedCastConversion<arith::TruncIOp>, 811 SignedCastConversion<arith::ExtSIOp>, 812 UnsignedCastConversion<arith::ExtUIOp>, 813 SignedCastConversion<arith::IndexCastOp>, 814 UnsignedCastConversion<arith::IndexCastUIOp>, 815 ItoFCastOpConversion<arith::SIToFPOp>, 816 ItoFCastOpConversion<arith::UIToFPOp>, 817 FtoICastOpConversion<arith::FPToSIOp>, 818 FtoICastOpConversion<arith::FPToUIOp>, 819 FpCastOpConversion<arith::ExtFOp>, 820 FpCastOpConversion<arith::TruncFOp> 821 >(typeConverter, ctx); 822 // clang-format on 823 } 824