1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" 13 #include "mlir/Dialect/Arith/Utils/Utils.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 16 #include "mlir/Dialect/Vector/IR/VectorOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/ADT/APInt.h" 21 #include "llvm/Support/FormatVariadic.h" 22 #include "llvm/Support/MathExtras.h" 23 #include <cassert> 24 25 namespace mlir::arith { 26 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT 27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" 28 } // namespace mlir::arith 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Common Helper Functions 34 //===----------------------------------------------------------------------===// 35 36 /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`. 37 /// Treats `value` as a 2*N bits-wide integer. 38 /// The bottom bits are returned in the first pair element, while the top bits 39 /// in the second one. 40 static std::pair<APInt, APInt> getHalves(const APInt &value, 41 unsigned newBitWidth) { 42 APInt low = value.extractBits(newBitWidth, 0); 43 APInt high = value.extractBits(newBitWidth, newBitWidth); 44 return {std::move(low), std::move(high)}; 45 } 46 47 /// Returns the type with the last (innermost) dimension reduced to x1. 48 /// Scalarizes 1D vector inputs to match how we extract/insert vector values, 49 /// e.g.: 50 /// - vector<3x2xi16> --> vector<3x1xi16> 51 /// - vector<2xi16> --> i16 52 static Type reduceInnermostDim(VectorType type) { 53 if (type.getShape().size() == 1) 54 return type.getElementType(); 55 56 auto newShape = to_vector(type.getShape()); 57 newShape.back() = 1; 58 return VectorType::get(newShape, type.getElementType()); 59 } 60 61 /// Extracts the `input` vector slice with elements at the last dimension offset 62 /// by `lastOffset`. Returns a value of vector type with the last dimension 63 /// reduced to x1 or fully scalarized, e.g.: 64 /// - vector<3x2xi16> --> vector<3x1xi16> 65 /// - vector<2xi16> --> i16 66 static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, 67 Location loc, Value input, 68 int64_t lastOffset) { 69 ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape(); 70 assert(lastOffset < shape.back() && "Offset out of bounds"); 71 72 // Scalarize the result in case of 1D vectors. 73 if (shape.size() == 1) 74 return rewriter.create<vector::ExtractOp>(loc, input, lastOffset); 75 76 SmallVector<int64_t> offsets(shape.size(), 0); 77 offsets.back() = lastOffset; 78 auto sizes = llvm::to_vector(shape); 79 sizes.back() = 1; 80 SmallVector<int64_t> strides(shape.size(), 1); 81 82 return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets, 83 sizes, strides); 84 } 85 86 /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`, 87 /// with the first element at offset 0 and the second element at offset 1. 88 static std::pair<Value, Value> 89 extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, 90 Value input) { 91 return {extractLastDimSlice(rewriter, loc, input, 0), 92 extractLastDimSlice(rewriter, loc, input, 1)}; 93 } 94 95 // Performs a vector shape cast to drop the trailing x1 dimension. If the 96 // `input` is a scalar, this is a noop. 97 static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, 98 Location loc, Value input) { 99 auto vecTy = dyn_cast<VectorType>(input.getType()); 100 if (!vecTy) 101 return input; 102 103 // Shape cast to drop the last x1 dimension. 104 ArrayRef<int64_t> shape = vecTy.getShape(); 105 assert(shape.size() >= 2 && "Expected vector with at list two dims"); 106 assert(shape.back() == 1 && "Expected the last vector dim to be x1"); 107 108 auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType()); 109 return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input); 110 } 111 112 /// Performs a vector shape cast to append an x1 dimension. If the 113 /// `input` is a scalar, this is a noop. 114 static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, 115 Value input) { 116 auto vecTy = dyn_cast<VectorType>(input.getType()); 117 if (!vecTy) 118 return input; 119 120 // Add a trailing x1 dim. 121 auto newShape = llvm::to_vector(vecTy.getShape()); 122 newShape.push_back(1); 123 auto newTy = VectorType::get(newShape, vecTy.getElementType()); 124 return rewriter.create<vector::ShapeCastOp>(loc, newTy, input); 125 } 126 127 /// Inserts the `source` vector slice into the `dest` vector at offset 128 /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is 129 /// a 1D vector. 130 static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, 131 Location loc, Value source, Value dest, 132 int64_t lastOffset) { 133 ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape(); 134 assert(lastOffset < shape.back() && "Offset out of bounds"); 135 136 // Handle scalar source. 137 if (isa<IntegerType>(source.getType())) 138 return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset); 139 140 SmallVector<int64_t> offsets(shape.size(), 0); 141 offsets.back() = lastOffset; 142 SmallVector<int64_t> strides(shape.size(), 1); 143 return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest, 144 offsets, strides); 145 } 146 147 /// Constructs a new vector of type `resultType` by creating a series of 148 /// insertions of `resultComponents`, each at the next offset of the last vector 149 /// dimension. 150 /// When all `resultComponents` are scalars, the result type is `vector<NxT>`; 151 /// when `resultComponents` are `vector<...x1xT>`s, the result type is 152 /// `vector<...xNxT>`, where `N` is the number of `resultComponents`. 153 static Value constructResultVector(ConversionPatternRewriter &rewriter, 154 Location loc, VectorType resultType, 155 ValueRange resultComponents) { 156 llvm::ArrayRef<int64_t> resultShape = resultType.getShape(); 157 (void)resultShape; 158 assert(!resultShape.empty() && "Result expected to have dimensions"); 159 assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) && 160 "Wrong number of result components"); 161 162 Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0); 163 for (auto [i, component] : llvm::enumerate(resultComponents)) 164 resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i); 165 166 return resultVec; 167 } 168 169 namespace { 170 //===----------------------------------------------------------------------===// 171 // ConvertConstant 172 //===----------------------------------------------------------------------===// 173 174 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> { 175 using OpConversionPattern::OpConversionPattern; 176 177 LogicalResult 178 matchAndRewrite(arith::ConstantOp op, OpAdaptor, 179 ConversionPatternRewriter &rewriter) const override { 180 Type oldType = op.getType(); 181 auto newType = getTypeConverter()->convertType<VectorType>(oldType); 182 if (!newType) 183 return rewriter.notifyMatchFailure( 184 op, llvm::formatv("unsupported type: {0}", op.getType())); 185 186 unsigned newBitWidth = newType.getElementTypeBitWidth(); 187 Attribute oldValue = op.getValueAttr(); 188 189 if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) { 190 auto [low, high] = getHalves(intAttr.getValue(), newBitWidth); 191 auto newAttr = DenseElementsAttr::get(newType, {low, high}); 192 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr); 193 return success(); 194 } 195 196 if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) { 197 auto [low, high] = 198 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth); 199 int64_t numSplatElems = splatAttr.getNumElements(); 200 SmallVector<APInt> values; 201 values.reserve(numSplatElems * 2); 202 for (int64_t i = 0; i < numSplatElems; ++i) { 203 values.push_back(low); 204 values.push_back(high); 205 } 206 207 auto attr = DenseElementsAttr::get(newType, values); 208 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr); 209 return success(); 210 } 211 212 if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) { 213 int64_t numElems = elemsAttr.getNumElements(); 214 SmallVector<APInt> values; 215 values.reserve(numElems * 2); 216 for (const APInt &origVal : elemsAttr.getValues<APInt>()) { 217 auto [low, high] = getHalves(origVal, newBitWidth); 218 values.push_back(std::move(low)); 219 values.push_back(std::move(high)); 220 } 221 222 auto attr = DenseElementsAttr::get(newType, values); 223 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr); 224 return success(); 225 } 226 227 return rewriter.notifyMatchFailure(op.getLoc(), 228 "unhandled constant attribute"); 229 } 230 }; 231 232 //===----------------------------------------------------------------------===// 233 // ConvertAddI 234 //===----------------------------------------------------------------------===// 235 236 struct ConvertAddI final : OpConversionPattern<arith::AddIOp> { 237 using OpConversionPattern::OpConversionPattern; 238 239 LogicalResult 240 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, 241 ConversionPatternRewriter &rewriter) const override { 242 Location loc = op->getLoc(); 243 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType()); 244 if (!newTy) 245 return rewriter.notifyMatchFailure( 246 loc, llvm::formatv("unsupported type: {0}", op.getType())); 247 248 Type newElemTy = reduceInnermostDim(newTy); 249 250 auto [lhsElem0, lhsElem1] = 251 extractLastDimHalves(rewriter, loc, adaptor.getLhs()); 252 auto [rhsElem0, rhsElem1] = 253 extractLastDimHalves(rewriter, loc, adaptor.getRhs()); 254 255 auto lowSum = 256 rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0); 257 Value overflowVal = 258 rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow()); 259 260 Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1); 261 Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1); 262 263 Value resultVec = 264 constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high}); 265 rewriter.replaceOp(op, resultVec); 266 return success(); 267 } 268 }; 269 270 //===----------------------------------------------------------------------===// 271 // ConvertBitwiseBinary 272 //===----------------------------------------------------------------------===// 273 274 /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`. 275 template <typename BinaryOp> 276 struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> { 277 using OpConversionPattern<BinaryOp>::OpConversionPattern; 278 using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor; 279 280 LogicalResult 281 matchAndRewrite(BinaryOp op, OpAdaptor adaptor, 282 ConversionPatternRewriter &rewriter) const override { 283 Location loc = op->getLoc(); 284 auto newTy = this->getTypeConverter()->template convertType<VectorType>( 285 op.getType()); 286 if (!newTy) 287 return rewriter.notifyMatchFailure( 288 loc, llvm::formatv("unsupported type: {0}", op.getType())); 289 290 auto [lhsElem0, lhsElem1] = 291 extractLastDimHalves(rewriter, loc, adaptor.getLhs()); 292 auto [rhsElem0, rhsElem1] = 293 extractLastDimHalves(rewriter, loc, adaptor.getRhs()); 294 295 Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0); 296 Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1); 297 Value resultVec = 298 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); 299 rewriter.replaceOp(op, resultVec); 300 return success(); 301 } 302 }; 303 304 //===----------------------------------------------------------------------===// 305 // ConvertCmpI 306 //===----------------------------------------------------------------------===// 307 308 /// Returns the matching unsigned version of the given predicate `pred`, or the 309 /// same predicate if `pred` is not a signed. 310 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) { 311 using P = arith::CmpIPredicate; 312 switch (pred) { 313 case P::sge: 314 return P::uge; 315 case P::sgt: 316 return P::ugt; 317 case P::sle: 318 return P::ule; 319 case P::slt: 320 return P::ult; 321 default: 322 return pred; 323 } 324 } 325 326 struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> { 327 using OpConversionPattern::OpConversionPattern; 328 329 LogicalResult 330 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 331 ConversionPatternRewriter &rewriter) const override { 332 Location loc = op->getLoc(); 333 auto inputTy = 334 getTypeConverter()->convertType<VectorType>(op.getLhs().getType()); 335 if (!inputTy) 336 return rewriter.notifyMatchFailure( 337 loc, llvm::formatv("unsupported type: {0}", op.getType())); 338 339 arith::CmpIPredicate highPred = adaptor.getPredicate(); 340 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred); 341 342 auto [lhsElem0, lhsElem1] = 343 extractLastDimHalves(rewriter, loc, adaptor.getLhs()); 344 auto [rhsElem0, rhsElem1] = 345 extractLastDimHalves(rewriter, loc, adaptor.getRhs()); 346 347 Value lowCmp = 348 rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0); 349 Value highCmp = 350 rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1); 351 352 Value cmpResult{}; 353 switch (highPred) { 354 case arith::CmpIPredicate::eq: { 355 cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp); 356 break; 357 } 358 case arith::CmpIPredicate::ne: { 359 cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp); 360 break; 361 } 362 default: { 363 // Handle inequality checks. 364 Value highEq = rewriter.create<arith::CmpIOp>( 365 loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1); 366 cmpResult = 367 rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp); 368 break; 369 } 370 } 371 372 assert(cmpResult && "Unhandled case"); 373 rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult)); 374 return success(); 375 } 376 }; 377 378 //===----------------------------------------------------------------------===// 379 // ConvertMulI 380 //===----------------------------------------------------------------------===// 381 382 struct ConvertMulI final : OpConversionPattern<arith::MulIOp> { 383 using OpConversionPattern::OpConversionPattern; 384 385 LogicalResult 386 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, 387 ConversionPatternRewriter &rewriter) const override { 388 Location loc = op->getLoc(); 389 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType()); 390 if (!newTy) 391 return rewriter.notifyMatchFailure( 392 loc, llvm::formatv("unsupported type: {0}", op.getType())); 393 394 auto [lhsElem0, lhsElem1] = 395 extractLastDimHalves(rewriter, loc, adaptor.getLhs()); 396 auto [rhsElem0, rhsElem1] = 397 extractLastDimHalves(rewriter, loc, adaptor.getRhs()); 398 399 // The multiplication algorithm used is the standard (long) multiplication. 400 // Multiplying two i2N integers produces (at most) an i4N result, but 401 // because the calculation of top i2N is not necessary, we omit it. 402 auto mulLowLow = 403 rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0); 404 Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1); 405 Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0); 406 407 Value resLow = mulLowLow.getLow(); 408 Value resHi = 409 rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi); 410 resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow); 411 412 Value resultVec = 413 constructResultVector(rewriter, loc, newTy, {resLow, resHi}); 414 rewriter.replaceOp(op, resultVec); 415 return success(); 416 } 417 }; 418 419 //===----------------------------------------------------------------------===// 420 // ConvertExtSI 421 //===----------------------------------------------------------------------===// 422 423 struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> { 424 using OpConversionPattern::OpConversionPattern; 425 426 LogicalResult 427 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, 428 ConversionPatternRewriter &rewriter) const override { 429 Location loc = op->getLoc(); 430 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType()); 431 if (!newTy) 432 return rewriter.notifyMatchFailure( 433 loc, llvm::formatv("unsupported type: {0}", op.getType())); 434 435 Type newResultComponentTy = reduceInnermostDim(newTy); 436 437 // Sign-extend the input value to determine the low half of the result. 438 // Then, check if the low half is negative, and sign-extend the comparison 439 // result to get the high half. 440 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn()); 441 Value extended = rewriter.createOrFold<arith::ExtSIOp>( 442 loc, newResultComponentTy, newOperand); 443 Value operandZeroCst = 444 createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0); 445 Value signBit = rewriter.create<arith::CmpIOp>( 446 loc, arith::CmpIPredicate::slt, extended, operandZeroCst); 447 Value signValue = 448 rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit); 449 450 Value resultVec = 451 constructResultVector(rewriter, loc, newTy, {extended, signValue}); 452 rewriter.replaceOp(op, resultVec); 453 return success(); 454 } 455 }; 456 457 //===----------------------------------------------------------------------===// 458 // ConvertExtUI 459 //===----------------------------------------------------------------------===// 460 461 struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> { 462 using OpConversionPattern::OpConversionPattern; 463 464 LogicalResult 465 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 466 ConversionPatternRewriter &rewriter) const override { 467 Location loc = op->getLoc(); 468 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType()); 469 if (!newTy) 470 return rewriter.notifyMatchFailure( 471 loc, llvm::formatv("unsupported type: {0}", op.getType())); 472 473 Type newResultComponentTy = reduceInnermostDim(newTy); 474 475 // Zero-extend the input value to determine the low half of the result. 476 // The high half is always zero. 477 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn()); 478 Value extended = rewriter.createOrFold<arith::ExtUIOp>( 479 loc, newResultComponentTy, newOperand); 480 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0); 481 Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0); 482 rewriter.replaceOp(op, newRes); 483 return success(); 484 } 485 }; 486 487 //===----------------------------------------------------------------------===// 488 // ConvertMaxMin 489 //===----------------------------------------------------------------------===// 490 491 template <typename SourceOp, arith::CmpIPredicate CmpPred> 492 struct ConvertMaxMin final : OpConversionPattern<SourceOp> { 493 using OpConversionPattern<SourceOp>::OpConversionPattern; 494 495 LogicalResult 496 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 497 ConversionPatternRewriter &rewriter) const override { 498 Location loc = op->getLoc(); 499 500 Type oldTy = op.getType(); 501 auto newTy = dyn_cast_or_null<VectorType>( 502 this->getTypeConverter()->convertType(oldTy)); 503 if (!newTy) 504 return rewriter.notifyMatchFailure( 505 loc, llvm::formatv("unsupported type: {0}", op.getType())); 506 507 // Rewrite Max*I/Min*I as compare and select over original operands. Let 508 // the CmpI and Select emulation patterns handle the final legalization. 509 Value cmp = 510 rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs()); 511 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(), 512 op.getRhs()); 513 return success(); 514 } 515 }; 516 517 // Convert IndexCast ops 518 //===----------------------------------------------------------------------===// 519 520 /// Returns true iff the type is `index` or `vector<...index>`. 521 static bool isIndexOrIndexVector(Type type) { 522 if (isa<IndexType>(type)) 523 return true; 524 525 if (auto vectorTy = dyn_cast<VectorType>(type)) 526 if (isa<IndexType>(vectorTy.getElementType())) 527 return true; 528 529 return false; 530 } 531 532 template <typename CastOp> 533 struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> { 534 using OpConversionPattern<CastOp>::OpConversionPattern; 535 536 LogicalResult 537 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, 538 ConversionPatternRewriter &rewriter) const override { 539 Type resultType = op.getType(); 540 if (!isIndexOrIndexVector(resultType)) 541 return failure(); 542 543 Location loc = op.getLoc(); 544 Type inType = op.getIn().getType(); 545 auto newInTy = 546 this->getTypeConverter()->template convertType<VectorType>(inType); 547 if (!newInTy) 548 return rewriter.notifyMatchFailure( 549 loc, llvm::formatv("unsupported type: {0}", inType)); 550 551 // Discard the high half of the input truncating the original value. 552 Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0); 553 extracted = dropTrailingX1Dim(rewriter, loc, extracted); 554 rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted); 555 return success(); 556 } 557 }; 558 559 template <typename CastOp, typename ExtensionOp> 560 struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> { 561 using OpConversionPattern<CastOp>::OpConversionPattern; 562 563 LogicalResult 564 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, 565 ConversionPatternRewriter &rewriter) const override { 566 Type inType = op.getIn().getType(); 567 if (!isIndexOrIndexVector(inType)) 568 return failure(); 569 570 Location loc = op.getLoc(); 571 auto *typeConverter = 572 this->template getTypeConverter<arith::WideIntEmulationConverter>(); 573 574 Type resultType = op.getType(); 575 auto newTy = typeConverter->template convertType<VectorType>(resultType); 576 if (!newTy) 577 return rewriter.notifyMatchFailure( 578 loc, llvm::formatv("unsupported type: {0}", resultType)); 579 580 // Emit an index cast over the matching narrow type. 581 Type narrowTy = 582 rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth()); 583 if (auto vecTy = dyn_cast<VectorType>(resultType)) 584 narrowTy = VectorType::get(vecTy.getShape(), narrowTy); 585 586 // Sign or zero-extend the result. Let the matching conversion pattern 587 // legalize the extension op. 588 Value underlyingVal = 589 rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn()); 590 rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal); 591 return success(); 592 } 593 }; 594 595 //===----------------------------------------------------------------------===// 596 // ConvertSelect 597 //===----------------------------------------------------------------------===// 598 599 struct ConvertSelect final : OpConversionPattern<arith::SelectOp> { 600 using OpConversionPattern::OpConversionPattern; 601 602 LogicalResult 603 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 604 ConversionPatternRewriter &rewriter) const override { 605 Location loc = op->getLoc(); 606 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType()); 607 if (!newTy) 608 return rewriter.notifyMatchFailure( 609 loc, llvm::formatv("unsupported type: {0}", op.getType())); 610 611 auto [trueElem0, trueElem1] = 612 extractLastDimHalves(rewriter, loc, adaptor.getTrueValue()); 613 auto [falseElem0, falseElem1] = 614 extractLastDimHalves(rewriter, loc, adaptor.getFalseValue()); 615 Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition()); 616 617 Value resElem0 = 618 rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0); 619 Value resElem1 = 620 rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1); 621 Value resultVec = 622 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); 623 rewriter.replaceOp(op, resultVec); 624 return success(); 625 } 626 }; 627 628 //===----------------------------------------------------------------------===// 629 // ConvertShLI 630 //===----------------------------------------------------------------------===// 631 632 struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> { 633 using OpConversionPattern::OpConversionPattern; 634 635 LogicalResult 636 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, 637 ConversionPatternRewriter &rewriter) const override { 638 Location loc = op->getLoc(); 639 640 Type oldTy = op.getType(); 641 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy); 642 if (!newTy) 643 return rewriter.notifyMatchFailure( 644 loc, llvm::formatv("unsupported type: {0}", op.getType())); 645 646 Type newOperandTy = reduceInnermostDim(newTy); 647 // `oldBitWidth` == `2 * newBitWidth` 648 unsigned newBitWidth = newTy.getElementTypeBitWidth(); 649 650 auto [lhsElem0, lhsElem1] = 651 extractLastDimHalves(rewriter, loc, adaptor.getLhs()); 652 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0); 653 654 // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and 655 // high halves of the results separately: 656 // 1. low := LHS.low shli RHS 657 // 658 // 2. high := a or b or c, where: 659 // a) Bits from LHS.high, shifted by the RHS. 660 // b) Bits from LHS.low, shifted right. These come into play when 661 // RHS < newBitWidth, e.g.: 662 // [0000][llll] shli 3 --> [0lll][l000] 663 // ^ 664 // | 665 // [llll] shrui (4 - 3) 666 // c) Bits from LHS.low, shifted left. These matter when 667 // RHS > newBitWidth, e.g.: 668 // [0000][llll] shli 7 --> [l000][0000] 669 // ^ 670 // | 671 // [llll] shli (7 - 4) 672 // 673 // Because shifts by values >= newBitWidth are undefined, we ignore the high 674 // half of RHS, and introduce 'bounds checks' to account for 675 // RHS.low > newBitWidth. 676 // 677 // TODO: Explore possible optimizations. 678 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0); 679 Value elemBitWidth = 680 createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); 681 682 Value illegalElemShift = rewriter.create<arith::CmpIOp>( 683 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); 684 685 Value shiftedElem0 = 686 rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0); 687 Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift, 688 zeroCst, shiftedElem0); 689 690 Value cappedShiftAmount = rewriter.create<arith::SelectOp>( 691 loc, illegalElemShift, elemBitWidth, rhsElem0); 692 Value rightShiftAmount = 693 rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount); 694 Value shiftedRight = 695 rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount); 696 Value overshotShiftAmount = 697 rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth); 698 Value shiftedLeft = 699 rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount); 700 701 Value shiftedElem1 = 702 rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0); 703 Value resElem1High = rewriter.create<arith::SelectOp>( 704 loc, illegalElemShift, zeroCst, shiftedElem1); 705 Value resElem1Low = rewriter.create<arith::SelectOp>( 706 loc, illegalElemShift, shiftedLeft, shiftedRight); 707 Value resElem1 = 708 rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High); 709 710 Value resultVec = 711 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); 712 rewriter.replaceOp(op, resultVec); 713 return success(); 714 } 715 }; 716 717 //===----------------------------------------------------------------------===// 718 // ConvertShRUI 719 //===----------------------------------------------------------------------===// 720 721 struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> { 722 using OpConversionPattern::OpConversionPattern; 723 724 LogicalResult 725 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, 726 ConversionPatternRewriter &rewriter) const override { 727 Location loc = op->getLoc(); 728 729 Type oldTy = op.getType(); 730 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy); 731 if (!newTy) 732 return rewriter.notifyMatchFailure( 733 loc, llvm::formatv("unsupported type: {0}", op.getType())); 734 735 Type newOperandTy = reduceInnermostDim(newTy); 736 // `oldBitWidth` == `2 * newBitWidth` 737 unsigned newBitWidth = newTy.getElementTypeBitWidth(); 738 739 auto [lhsElem0, lhsElem1] = 740 extractLastDimHalves(rewriter, loc, adaptor.getLhs()); 741 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0); 742 743 // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and 744 // high halves of the results separately: 745 // 1. low := a or b or c, where: 746 // a) Bits from LHS.low, shifted by the RHS. 747 // b) Bits from LHS.high, shifted left. These matter when 748 // RHS < newBitWidth, e.g.: 749 // [hhhh][0000] shrui 3 --> [000h][hhh0] 750 // ^ 751 // | 752 // [hhhh] shli (4 - 1) 753 // c) Bits from LHS.high, shifted right. These come into play when 754 // RHS > newBitWidth, e.g.: 755 // [hhhh][0000] shrui 7 --> [0000][000h] 756 // ^ 757 // | 758 // [hhhh] shrui (7 - 4) 759 // 760 // 2. high := LHS.high shrui RHS 761 // 762 // Because shifts by values >= newBitWidth are undefined, we ignore the high 763 // half of RHS, and introduce 'bounds checks' to account for 764 // RHS.low > newBitWidth. 765 // 766 // TODO: Explore possible optimizations. 767 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0); 768 Value elemBitWidth = 769 createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); 770 771 Value illegalElemShift = rewriter.create<arith::CmpIOp>( 772 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); 773 774 Value shiftedElem0 = 775 rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0); 776 Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift, 777 zeroCst, shiftedElem0); 778 Value shiftedElem1 = 779 rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0); 780 Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift, 781 zeroCst, shiftedElem1); 782 783 Value cappedShiftAmount = rewriter.create<arith::SelectOp>( 784 loc, illegalElemShift, elemBitWidth, rhsElem0); 785 Value leftShiftAmount = 786 rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount); 787 Value shiftedLeft = 788 rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount); 789 Value overshotShiftAmount = 790 rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth); 791 Value shiftedRight = 792 rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount); 793 794 Value resElem0High = rewriter.create<arith::SelectOp>( 795 loc, illegalElemShift, shiftedRight, shiftedLeft); 796 Value resElem0 = 797 rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High); 798 799 Value resultVec = 800 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); 801 rewriter.replaceOp(op, resultVec); 802 return success(); 803 } 804 }; 805 806 //===----------------------------------------------------------------------===// 807 // ConvertShRSI 808 //===----------------------------------------------------------------------===// 809 810 struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> { 811 using OpConversionPattern::OpConversionPattern; 812 813 LogicalResult 814 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, 815 ConversionPatternRewriter &rewriter) const override { 816 Location loc = op->getLoc(); 817 818 Type oldTy = op.getType(); 819 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy); 820 if (!newTy) 821 return rewriter.notifyMatchFailure( 822 loc, llvm::formatv("unsupported type: {0}", op.getType())); 823 824 Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1); 825 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0); 826 827 Type narrowTy = rhsElem0.getType(); 828 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2; 829 830 // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits. 831 // Perform as many ops over the narrow integer type as possible and let the 832 // other emulation patterns convert the rest. 833 Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0); 834 Value signBit = rewriter.create<arith::CmpIOp>( 835 loc, arith::CmpIPredicate::slt, lhsElem1, elemZero); 836 signBit = dropTrailingX1Dim(rewriter, loc, signBit); 837 838 // Create a bit pattern of either all ones or all zeros. Then shift it left 839 // to calculate the sign extension bits created by shifting the original 840 // sign bit right. 841 Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit); 842 Value maxShift = 843 createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth); 844 Value numNonSignExtBits = 845 rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0); 846 numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits); 847 numNonSignExtBits = 848 rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits); 849 Value signBits = 850 rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits); 851 852 // Use original arguments to create the right shift. 853 Value shrui = 854 rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs()); 855 Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits); 856 857 // Handle shifting by zero. This is necessary when the `signBits` shift is 858 // invalid. 859 Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 860 rhsElem0, elemZero); 861 isNoop = dropTrailingX1Dim(rewriter, loc, isNoop); 862 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(), 863 shrsi); 864 865 return success(); 866 } 867 }; 868 869 //===----------------------------------------------------------------------===// 870 // ConvertSIToFP 871 //===----------------------------------------------------------------------===// 872 873 struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> { 874 using OpConversionPattern::OpConversionPattern; 875 876 LogicalResult 877 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor, 878 ConversionPatternRewriter &rewriter) const override { 879 Location loc = op.getLoc(); 880 881 Value in = op.getIn(); 882 Type oldTy = in.getType(); 883 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy); 884 if (!newTy) 885 return rewriter.notifyMatchFailure( 886 loc, llvm::formatv("unsupported type: {0}", oldTy)); 887 888 unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth(); 889 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0); 890 Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1); 891 Value allOnesCst = createScalarOrSplatConstant( 892 rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth)); 893 894 // To avoid operating on very large unsigned numbers, perform the 895 // conversion on the absolute value. Then, decide whether to negate the 896 // result or not based on that sign bit. We assume two's complement and 897 // implement negation by flipping all bits and adding 1. 898 // Note that this relies on the the other conversion patterns to legalize 899 // created ops and narrow the bit widths. 900 Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 901 in, zeroCst); 902 Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst); 903 Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst); 904 Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in); 905 906 Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs); 907 Value negResult = rewriter.create<arith::NegFOp>(loc, absResult); 908 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult, 909 absResult); 910 return success(); 911 } 912 }; 913 914 //===----------------------------------------------------------------------===// 915 // ConvertUIToFP 916 //===----------------------------------------------------------------------===// 917 918 struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> { 919 using OpConversionPattern::OpConversionPattern; 920 921 LogicalResult 922 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 923 ConversionPatternRewriter &rewriter) const override { 924 Location loc = op.getLoc(); 925 926 Type oldTy = op.getIn().getType(); 927 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy); 928 if (!newTy) 929 return rewriter.notifyMatchFailure( 930 loc, llvm::formatv("unsupported type: {0}", oldTy)); 931 unsigned newBitWidth = newTy.getElementTypeBitWidth(); 932 933 auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn()); 934 Value lowInt = dropTrailingX1Dim(rewriter, loc, low); 935 Value hiInt = dropTrailingX1Dim(rewriter, loc, hi); 936 Value zeroCst = 937 createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0); 938 939 // The final result has the following form: 940 // if (hi == 0) return uitofp(low) 941 // else return uitofp(low) + uitofp(hi) * 2^BW 942 // 943 // where `BW` is the bitwidth of the narrowed integer type. We emit a 944 // select to make it easier to fold-away the `hi` part calculation when it 945 // is known to be zero. 946 // 947 // Note 1: The emulation is precise only for input values that have exact 948 // integer representation in the result floating point type, and may lead 949 // loss of precision otherwise. 950 // 951 // Note 2: We do not strictly need the `hi == 0`, case, but it makes 952 // constant folding easier. 953 Value hiEqZero = rewriter.create<arith::CmpIOp>( 954 loc, arith::CmpIPredicate::eq, hiInt, zeroCst); 955 956 Type resultTy = op.getType(); 957 Type resultElemTy = getElementTypeOrSelf(resultTy); 958 Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt); 959 Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt); 960 961 int64_t pow2Int = int64_t(1) << newBitWidth; 962 TypedAttr pow2Attr = 963 rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int)); 964 if (auto vecTy = dyn_cast<VectorType>(resultTy)) 965 pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr); 966 967 Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr); 968 969 Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val); 970 Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal); 971 972 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result); 973 return success(); 974 } 975 }; 976 977 //===----------------------------------------------------------------------===// 978 // ConvertTruncI 979 //===----------------------------------------------------------------------===// 980 981 struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> { 982 using OpConversionPattern::OpConversionPattern; 983 984 LogicalResult 985 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 986 ConversionPatternRewriter &rewriter) const override { 987 Location loc = op.getLoc(); 988 // Check if the result type is legal for this target. Currently, we do not 989 // support truncation to types wider than supported by the target. 990 if (!getTypeConverter()->isLegal(op.getType())) 991 return rewriter.notifyMatchFailure( 992 loc, llvm::formatv("unsupported truncation result type: {0}", 993 op.getType())); 994 995 // Discard the high half of the input. Truncate the low half, if 996 // necessary. 997 Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0); 998 extracted = dropTrailingX1Dim(rewriter, loc, extracted); 999 Value truncated = 1000 rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted); 1001 rewriter.replaceOp(op, truncated); 1002 return success(); 1003 } 1004 }; 1005 1006 //===----------------------------------------------------------------------===// 1007 // ConvertVectorPrint 1008 //===----------------------------------------------------------------------===// 1009 1010 struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> { 1011 using OpConversionPattern::OpConversionPattern; 1012 1013 LogicalResult 1014 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor, 1015 ConversionPatternRewriter &rewriter) const override { 1016 rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource()); 1017 return success(); 1018 } 1019 }; 1020 1021 //===----------------------------------------------------------------------===// 1022 // Pass Definition 1023 //===----------------------------------------------------------------------===// 1024 1025 struct EmulateWideIntPass final 1026 : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> { 1027 using ArithEmulateWideIntBase::ArithEmulateWideIntBase; 1028 1029 void runOnOperation() override { 1030 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { 1031 signalPassFailure(); 1032 return; 1033 } 1034 1035 Operation *op = getOperation(); 1036 MLIRContext *ctx = op->getContext(); 1037 1038 arith::WideIntEmulationConverter typeConverter(widestIntSupported); 1039 ConversionTarget target(*ctx); 1040 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) { 1041 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType()); 1042 }); 1043 auto opLegalCallback = [&typeConverter](Operation *op) { 1044 return typeConverter.isLegal(op); 1045 }; 1046 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback); 1047 target 1048 .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>( 1049 opLegalCallback); 1050 1051 RewritePatternSet patterns(ctx); 1052 arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); 1053 1054 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 1055 signalPassFailure(); 1056 } 1057 }; 1058 } // end anonymous namespace 1059 1060 //===----------------------------------------------------------------------===// 1061 // Public Interface Definition 1062 //===----------------------------------------------------------------------===// 1063 1064 arith::WideIntEmulationConverter::WideIntEmulationConverter( 1065 unsigned widestIntSupportedByTarget) 1066 : maxIntWidth(widestIntSupportedByTarget) { 1067 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) && 1068 "Only power-of-two integers with are supported"); 1069 assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow"); 1070 1071 // Allow unknown types. 1072 addConversion([](Type ty) -> std::optional<Type> { return ty; }); 1073 1074 // Scalar case. 1075 addConversion([this](IntegerType ty) -> std::optional<Type> { 1076 unsigned width = ty.getWidth(); 1077 if (width <= maxIntWidth) 1078 return ty; 1079 1080 // i2N --> vector<2xiN> 1081 if (width == 2 * maxIntWidth) 1082 return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth)); 1083 1084 return nullptr; 1085 }); 1086 1087 // Vector case. 1088 addConversion([this](VectorType ty) -> std::optional<Type> { 1089 auto intTy = dyn_cast<IntegerType>(ty.getElementType()); 1090 if (!intTy) 1091 return ty; 1092 1093 unsigned width = intTy.getWidth(); 1094 if (width <= maxIntWidth) 1095 return ty; 1096 1097 // vector<...xi2N> --> vector<...x2xiN> 1098 if (width == 2 * maxIntWidth) { 1099 auto newShape = to_vector(ty.getShape()); 1100 newShape.push_back(2); 1101 return VectorType::get(newShape, 1102 IntegerType::get(ty.getContext(), maxIntWidth)); 1103 } 1104 1105 return nullptr; 1106 }); 1107 1108 // Function case. 1109 addConversion([this](FunctionType ty) -> std::optional<Type> { 1110 // Convert inputs and results, e.g.: 1111 // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN> 1112 SmallVector<Type> inputs; 1113 if (failed(convertTypes(ty.getInputs(), inputs))) 1114 return nullptr; 1115 1116 SmallVector<Type> results; 1117 if (failed(convertTypes(ty.getResults(), results))) 1118 return nullptr; 1119 1120 return FunctionType::get(ty.getContext(), inputs, results); 1121 }); 1122 } 1123 1124 void arith::populateArithWideIntEmulationPatterns( 1125 const WideIntEmulationConverter &typeConverter, 1126 RewritePatternSet &patterns) { 1127 // Populate `func.*` conversion patterns. 1128 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 1129 typeConverter); 1130 populateCallOpTypeConversionPattern(patterns, typeConverter); 1131 populateReturnOpTypeConversionPattern(patterns, typeConverter); 1132 1133 // Populate `arith.*` conversion patterns. 1134 patterns.add< 1135 // Misc ops. 1136 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint, 1137 // Binary ops. 1138 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI, 1139 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>, 1140 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>, 1141 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>, 1142 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, 1143 // Bitwise binary ops. 1144 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>, 1145 ConvertBitwiseBinary<arith::XOrIOp>, 1146 // Extension and truncation ops. 1147 ConvertExtSI, ConvertExtUI, ConvertTruncI, 1148 // Cast ops. 1149 ConvertIndexCastIntToIndex<arith::IndexCastOp>, 1150 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>, 1151 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>, 1152 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>, 1153 ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext()); 1154 } 1155