1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" 10 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/DialectResourceBlobManager.h" 21 #include "llvm/ADT/APInt.h" 22 #include "llvm/ADT/ArrayRef.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/MathExtras.h" 26 #include <cassert> 27 #include <memory> 28 29 namespace mlir { 30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV 31 #include "mlir/Conversion/Passes.h.inc" 32 } // namespace mlir 33 34 #define DEBUG_TYPE "arith-to-spirv-pattern" 35 36 using namespace mlir; 37 38 //===----------------------------------------------------------------------===// 39 // Conversion Helpers 40 //===----------------------------------------------------------------------===// 41 42 /// Converts the given `srcAttr` into a boolean attribute if it holds an 43 /// integral value. Returns null attribute if conversion fails. 44 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr)) 46 return boolAttr; 47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr)) 48 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 49 return {}; 50 } 51 52 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 53 /// Returns null attribute if conversion fails. 54 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 55 Builder builder) { 56 // If the source number uses less active bits than the target bitwidth, then 57 // it should be safe to convert. 58 if (srcAttr.getValue().isIntN(dstType.getWidth())) 59 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 60 61 // XXX: Try again by interpreting the source number as a signed value. 62 // Although integers in the standard dialect are signless, they can represent 63 // a signed number. It's the operation decides how to interpret. This is 64 // dangerous, but it seems there is no good way of handling this if we still 65 // want to change the bitwidth. Emit a message at least. 66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 69 << dstAttr << "' for type '" << dstType << "'\n"); 70 return dstAttr; 71 } 72 73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 74 << "' illegal: cannot fit into target type '" 75 << dstType << "'\n"); 76 return {}; 77 } 78 79 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 80 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 81 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 82 Builder builder) { 83 // Only support converting to float for now. 84 if (!dstType.isF32()) 85 return FloatAttr(); 86 87 // Try to convert the source floating-point number to single precision. 88 APFloat dstVal = srcAttr.getValue(); 89 bool losesInfo = false; 90 APFloat::opStatus status = 91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 92 if (status != APFloat::opOK || losesInfo) { 93 LLVM_DEBUG(llvm::dbgs() 94 << srcAttr << " illegal: cannot fit into converted type '" 95 << dstType << "'\n"); 96 return FloatAttr(); 97 } 98 99 return builder.getF32FloatAttr(dstVal.convertToFloat()); 100 } 101 102 /// Returns true if the given `type` is a boolean scalar or vector type. 103 static bool isBoolScalarOrVector(Type type) { 104 assert(type && "Not a valid type"); 105 if (type.isInteger(1)) 106 return true; 107 108 if (auto vecType = dyn_cast<VectorType>(type)) 109 return vecType.getElementType().isInteger(1); 110 111 return false; 112 } 113 114 /// Creates a scalar/vector integer constant. 115 static Value getScalarOrVectorConstInt(Type type, uint64_t value, 116 OpBuilder &builder, Location loc) { 117 if (auto vectorType = dyn_cast<VectorType>(type)) { 118 Attribute element = IntegerAttr::get(vectorType.getElementType(), value); 119 auto attr = SplatElementsAttr::get(vectorType, element); 120 return builder.create<spirv::ConstantOp>(loc, vectorType, attr); 121 } 122 123 if (auto intType = dyn_cast<IntegerType>(type)) 124 return builder.create<spirv::ConstantOp>( 125 loc, type, builder.getIntegerAttr(type, value)); 126 127 return nullptr; 128 } 129 130 /// Returns true if scalar/vector type `a` and `b` have the same number of 131 /// bitwidth. 132 static bool hasSameBitwidth(Type a, Type b) { 133 auto getNumBitwidth = [](Type type) { 134 unsigned bw = 0; 135 if (type.isIntOrFloat()) 136 bw = type.getIntOrFloatBitWidth(); 137 else if (auto vecType = dyn_cast<VectorType>(type)) 138 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); 139 return bw; 140 }; 141 unsigned aBW = getNumBitwidth(a); 142 unsigned bBW = getNumBitwidth(b); 143 return aBW != 0 && bBW != 0 && aBW == bBW; 144 } 145 146 /// Returns a source type conversion failure for `srcType` and operation `op`. 147 static LogicalResult 148 getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, 149 Type srcType) { 150 return rewriter.notifyMatchFailure( 151 op->getLoc(), 152 llvm::formatv("failed to convert source type '{0}'", srcType)); 153 } 154 155 /// Returns a source type conversion failure for the result type of `op`. 156 static LogicalResult 157 getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { 158 assert(op->getNumResults() == 1); 159 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front()); 160 } 161 162 // TODO: Move to some common place? 163 static std::string getDecorationString(spirv::Decoration decor) { 164 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor)); 165 } 166 167 namespace { 168 169 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V 170 /// operations. Op can potentially support overflow flags. 171 template <typename Op, typename SPIRVOp> 172 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> { 173 using OpConversionPattern<Op>::OpConversionPattern; 174 175 LogicalResult 176 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 177 ConversionPatternRewriter &rewriter) const override { 178 assert(adaptor.getOperands().size() <= 3); 179 auto converter = this->template getTypeConverter<SPIRVTypeConverter>(); 180 Type dstType = converter->convertType(op.getType()); 181 if (!dstType) { 182 return rewriter.notifyMatchFailure( 183 op->getLoc(), 184 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType())); 185 } 186 187 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && 188 !getElementTypeOrSelf(op.getType()).isIndex() && 189 dstType != op.getType()) { 190 return op.emitError("bitwidth emulation is not implemented yet on " 191 "unsigned op pattern version"); 192 } 193 194 auto overflowFlags = arith::IntegerOverflowFlags::none; 195 if (auto overflowIface = 196 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) { 197 if (converter->getTargetEnv().allows( 198 spirv::Extension::SPV_KHR_no_integer_wrap_decoration)) 199 overflowFlags = overflowIface.getOverflowAttr().getValue(); 200 } 201 202 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>( 203 op, dstType, adaptor.getOperands()); 204 205 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw)) 206 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap), 207 rewriter.getUnitAttr()); 208 209 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw)) 210 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap), 211 rewriter.getUnitAttr()); 212 213 return success(); 214 } 215 }; 216 217 //===----------------------------------------------------------------------===// 218 // ConstantOp 219 //===----------------------------------------------------------------------===// 220 221 /// Converts composite arith.constant operation to spirv.Constant. 222 struct ConstantCompositeOpPattern final 223 : public OpConversionPattern<arith::ConstantOp> { 224 using OpConversionPattern::OpConversionPattern; 225 226 LogicalResult 227 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, 228 ConversionPatternRewriter &rewriter) const override { 229 auto srcType = dyn_cast<ShapedType>(constOp.getType()); 230 if (!srcType || srcType.getNumElements() == 1) 231 return failure(); 232 233 // arith.constant should only have vector or tensor types. This is a MLIR 234 // wide problem at the moment. 235 if (!isa<VectorType, RankedTensorType>(srcType)) 236 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType"); 237 238 Type dstType = getTypeConverter()->convertType(srcType); 239 if (!dstType) 240 return failure(); 241 242 // Import the resource into the IR to make use of the special handling of 243 // element types later on. 244 mlir::DenseElementsAttr dstElementsAttr; 245 if (auto denseElementsAttr = 246 dyn_cast<DenseElementsAttr>(constOp.getValue())) { 247 dstElementsAttr = denseElementsAttr; 248 } else if (auto resourceAttr = 249 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) { 250 251 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob(); 252 if (!blob) 253 return constOp->emitError("could not find resource blob"); 254 255 ArrayRef<char> ptr = blob->getData(); 256 257 // Check that the buffer meets the requirements to get converted to a 258 // DenseElementsAttr 259 bool detectedSplat = false; 260 if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat)) 261 return constOp->emitError("resource is not a valid buffer"); 262 263 dstElementsAttr = 264 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr); 265 } else { 266 return constOp->emitError("unsupported elements attribute"); 267 } 268 269 ShapedType dstAttrType = dstElementsAttr.getType(); 270 271 // If the composite type has more than one dimensions, perform 272 // linearization. 273 if (srcType.getRank() > 1) { 274 if (isa<RankedTensorType>(srcType)) { 275 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 276 srcType.getElementType()); 277 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 278 } else { 279 // TODO: add support for large vectors. 280 return failure(); 281 } 282 } 283 284 Type srcElemType = srcType.getElementType(); 285 Type dstElemType; 286 // Tensor types are converted to SPIR-V array types; vector types are 287 // converted to SPIR-V vector/array types. 288 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType)) 289 dstElemType = arrayType.getElementType(); 290 else 291 dstElemType = cast<VectorType>(dstType).getElementType(); 292 293 // If the source and destination element types are different, perform 294 // attribute conversion. 295 if (srcElemType != dstElemType) { 296 SmallVector<Attribute, 8> elements; 297 if (isa<FloatType>(srcElemType)) { 298 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 299 FloatAttr dstAttr = 300 convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); 301 if (!dstAttr) 302 return failure(); 303 elements.push_back(dstAttr); 304 } 305 } else if (srcElemType.isInteger(1)) { 306 return failure(); 307 } else { 308 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 309 IntegerAttr dstAttr = convertIntegerAttr( 310 srcAttr, cast<IntegerType>(dstElemType), rewriter); 311 if (!dstAttr) 312 return failure(); 313 elements.push_back(dstAttr); 314 } 315 } 316 317 // Unfortunately, we cannot use dialect-specific types for element 318 // attributes; element attributes only works with builtin types. So we 319 // need to prepare another converted builtin types for the destination 320 // elements attribute. 321 if (isa<RankedTensorType>(dstAttrType)) 322 dstAttrType = 323 RankedTensorType::get(dstAttrType.getShape(), dstElemType); 324 else 325 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 326 327 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 328 } 329 330 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 331 dstElementsAttr); 332 return success(); 333 } 334 }; 335 336 /// Converts scalar arith.constant operation to spirv.Constant. 337 struct ConstantScalarOpPattern final 338 : public OpConversionPattern<arith::ConstantOp> { 339 using OpConversionPattern::OpConversionPattern; 340 341 LogicalResult 342 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, 343 ConversionPatternRewriter &rewriter) const override { 344 Type srcType = constOp.getType(); 345 if (auto shapedType = dyn_cast<ShapedType>(srcType)) { 346 if (shapedType.getNumElements() != 1) 347 return failure(); 348 srcType = shapedType.getElementType(); 349 } 350 if (!srcType.isIntOrIndexOrFloat()) 351 return failure(); 352 353 Attribute cstAttr = constOp.getValue(); 354 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr)) 355 cstAttr = elementsAttr.getSplatValue<Attribute>(); 356 357 Type dstType = getTypeConverter()->convertType(srcType); 358 if (!dstType) 359 return failure(); 360 361 // Floating-point types. 362 if (isa<FloatType>(srcType)) { 363 auto srcAttr = cast<FloatAttr>(cstAttr); 364 auto dstAttr = srcAttr; 365 366 // Floating-point types not supported in the target environment are all 367 // converted to float type. 368 if (srcType != dstType) { 369 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); 370 if (!dstAttr) 371 return failure(); 372 } 373 374 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 375 return success(); 376 } 377 378 // Bool type. 379 if (srcType.isInteger(1)) { 380 // arith.constant can use 0/1 instead of true/false for i1 values. We need 381 // to handle that here. 382 auto dstAttr = convertBoolAttr(cstAttr, rewriter); 383 if (!dstAttr) 384 return failure(); 385 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 386 return success(); 387 } 388 389 // IndexType or IntegerType. Index values are converted to 32-bit integer 390 // values when converting to SPIR-V. 391 auto srcAttr = cast<IntegerAttr>(cstAttr); 392 IntegerAttr dstAttr = 393 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter); 394 if (!dstAttr) 395 return failure(); 396 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 397 return success(); 398 } 399 }; 400 401 //===----------------------------------------------------------------------===// 402 // RemSIOp 403 //===----------------------------------------------------------------------===// 404 405 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 406 /// the sign of `signOperand`. 407 /// 408 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 409 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 410 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod 411 /// if either operand can be negative. Emulate it via spirv.UMod. 412 template <typename SignedAbsOp> 413 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 414 Value signOperand, OpBuilder &builder) { 415 assert(lhs.getType() == rhs.getType()); 416 assert(lhs == signOperand || rhs == signOperand); 417 418 Type type = lhs.getType(); 419 420 // Calculate the remainder with spirv.UMod. 421 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 422 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 423 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 424 425 // Fix the sign. 426 Value isPositive; 427 if (lhs == signOperand) 428 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 429 else 430 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 431 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 432 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 433 } 434 435 /// Converts arith.remsi to GLSL SPIR-V ops. 436 /// 437 /// This cannot be merged into the template unary/binary pattern due to Vulkan 438 /// restrictions over spirv.SRem and spirv.SMod. 439 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> { 440 using OpConversionPattern::OpConversionPattern; 441 442 LogicalResult 443 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 444 ConversionPatternRewriter &rewriter) const override { 445 Value result = emulateSignedRemainder<spirv::CLSAbsOp>( 446 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 447 adaptor.getOperands()[0], rewriter); 448 rewriter.replaceOp(op, result); 449 450 return success(); 451 } 452 }; 453 454 /// Converts arith.remsi to OpenCL SPIR-V ops. 455 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> { 456 using OpConversionPattern::OpConversionPattern; 457 458 LogicalResult 459 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 460 ConversionPatternRewriter &rewriter) const override { 461 Value result = emulateSignedRemainder<spirv::GLSAbsOp>( 462 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 463 adaptor.getOperands()[0], rewriter); 464 rewriter.replaceOp(op, result); 465 466 return success(); 467 } 468 }; 469 470 //===----------------------------------------------------------------------===// 471 // BitwiseOp 472 //===----------------------------------------------------------------------===// 473 474 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 475 /// other than the BinaryOpPatternPattern because if the operands are boolean 476 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 477 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 478 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 479 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 480 using OpConversionPattern<Op>::OpConversionPattern; 481 482 LogicalResult 483 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 484 ConversionPatternRewriter &rewriter) const override { 485 assert(adaptor.getOperands().size() == 2); 486 Type dstType = this->getTypeConverter()->convertType(op.getType()); 487 if (!dstType) 488 return getTypeConversionFailure(rewriter, op); 489 490 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 491 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>( 492 op, dstType, adaptor.getOperands()); 493 } else { 494 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>( 495 op, dstType, adaptor.getOperands()); 496 } 497 return success(); 498 } 499 }; 500 501 //===----------------------------------------------------------------------===// 502 // XOrIOp 503 //===----------------------------------------------------------------------===// 504 505 /// Converts arith.xori to SPIR-V operations. 506 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 507 using OpConversionPattern::OpConversionPattern; 508 509 LogicalResult 510 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 511 ConversionPatternRewriter &rewriter) const override { 512 assert(adaptor.getOperands().size() == 2); 513 514 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 515 return failure(); 516 517 Type dstType = getTypeConverter()->convertType(op.getType()); 518 if (!dstType) 519 return getTypeConversionFailure(rewriter, op); 520 521 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 522 adaptor.getOperands()); 523 524 return success(); 525 } 526 }; 527 528 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 529 /// vector of i1. 530 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 531 using OpConversionPattern::OpConversionPattern; 532 533 LogicalResult 534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 535 ConversionPatternRewriter &rewriter) const override { 536 assert(adaptor.getOperands().size() == 2); 537 538 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 539 return failure(); 540 541 Type dstType = getTypeConverter()->convertType(op.getType()); 542 if (!dstType) 543 return getTypeConversionFailure(rewriter, op); 544 545 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>( 546 op, dstType, adaptor.getOperands()); 547 return success(); 548 } 549 }; 550 551 //===----------------------------------------------------------------------===// 552 // UIToFPOp 553 //===----------------------------------------------------------------------===// 554 555 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector 556 /// of i1. 557 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 558 using OpConversionPattern::OpConversionPattern; 559 560 LogicalResult 561 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 562 ConversionPatternRewriter &rewriter) const override { 563 Type srcType = adaptor.getOperands().front().getType(); 564 if (!isBoolScalarOrVector(srcType)) 565 return failure(); 566 567 Type dstType = getTypeConverter()->convertType(op.getType()); 568 if (!dstType) 569 return getTypeConversionFailure(rewriter, op); 570 571 Location loc = op.getLoc(); 572 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 573 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 574 rewriter.replaceOpWithNewOp<spirv::SelectOp>( 575 op, dstType, adaptor.getOperands().front(), one, zero); 576 return success(); 577 } 578 }; 579 580 //===----------------------------------------------------------------------===// 581 // ExtSIOp 582 //===----------------------------------------------------------------------===// 583 584 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector 585 /// of i1. 586 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { 587 using OpConversionPattern::OpConversionPattern; 588 589 LogicalResult 590 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, 591 ConversionPatternRewriter &rewriter) const override { 592 Value operand = adaptor.getIn(); 593 if (!isBoolScalarOrVector(operand.getType())) 594 return failure(); 595 596 Location loc = op.getLoc(); 597 Type dstType = getTypeConverter()->convertType(op.getType()); 598 if (!dstType) 599 return getTypeConversionFailure(rewriter, op); 600 601 Value allOnes; 602 if (auto intTy = dyn_cast<IntegerType>(dstType)) { 603 unsigned componentBitwidth = intTy.getWidth(); 604 allOnes = rewriter.create<spirv::ConstantOp>( 605 loc, intTy, 606 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); 607 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) { 608 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); 609 allOnes = rewriter.create<spirv::ConstantOp>( 610 loc, vectorTy, 611 SplatElementsAttr::get(vectorTy, 612 APInt::getAllOnes(componentBitwidth))); 613 } else { 614 return rewriter.notifyMatchFailure( 615 loc, llvm::formatv("unhandled type: {0}", dstType)); 616 } 617 618 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 619 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes, 620 zero); 621 return success(); 622 } 623 }; 624 625 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor 626 /// vector of i1. 627 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { 628 using OpConversionPattern::OpConversionPattern; 629 630 LogicalResult 631 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, 632 ConversionPatternRewriter &rewriter) const override { 633 Type srcType = adaptor.getIn().getType(); 634 if (isBoolScalarOrVector(srcType)) 635 return failure(); 636 637 Type dstType = getTypeConverter()->convertType(op.getType()); 638 if (!dstType) 639 return getTypeConversionFailure(rewriter, op); 640 641 if (dstType == srcType) { 642 // We can have the same source and destination type due to type emulation. 643 // Perform bit shifting to make sure we have the proper leading set bits. 644 645 unsigned srcBW = 646 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); 647 unsigned dstBW = 648 getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); 649 assert(srcBW < dstBW); 650 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW, 651 rewriter, op.getLoc()); 652 653 // First shift left to sequeeze out all leading bits beyond the original 654 // bitwidth. Here we need to use the original source and result type's 655 // bitwidth. 656 auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>( 657 op.getLoc(), dstType, adaptor.getIn(), shiftSize); 658 659 // Then we perform arithmetic right shift to make sure we have the right 660 // sign bits for negative values. 661 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>( 662 op, dstType, shiftLOp, shiftSize); 663 } else { 664 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType, 665 adaptor.getOperands()); 666 } 667 668 return success(); 669 } 670 }; 671 672 //===----------------------------------------------------------------------===// 673 // ExtUIOp 674 //===----------------------------------------------------------------------===// 675 676 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector 677 /// of i1. 678 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 679 using OpConversionPattern::OpConversionPattern; 680 681 LogicalResult 682 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 683 ConversionPatternRewriter &rewriter) const override { 684 Type srcType = adaptor.getOperands().front().getType(); 685 if (!isBoolScalarOrVector(srcType)) 686 return failure(); 687 688 Type dstType = getTypeConverter()->convertType(op.getType()); 689 if (!dstType) 690 return getTypeConversionFailure(rewriter, op); 691 692 Location loc = op.getLoc(); 693 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 695 rewriter.replaceOpWithNewOp<spirv::SelectOp>( 696 op, dstType, adaptor.getOperands().front(), one, zero); 697 return success(); 698 } 699 }; 700 701 /// Converts arith.extui for cases where the type of source is neither i1 nor 702 /// vector of i1. 703 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> { 704 using OpConversionPattern::OpConversionPattern; 705 706 LogicalResult 707 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 708 ConversionPatternRewriter &rewriter) const override { 709 Type srcType = adaptor.getIn().getType(); 710 if (isBoolScalarOrVector(srcType)) 711 return failure(); 712 713 Type dstType = getTypeConverter()->convertType(op.getType()); 714 if (!dstType) 715 return getTypeConversionFailure(rewriter, op); 716 717 if (dstType == srcType) { 718 // We can have the same source and destination type due to type emulation. 719 // Perform bit masking to make sure we don't pollute downstream consumers 720 // with unwanted bits. Here we need to use the original source type's 721 // bitwidth. 722 unsigned bitwidth = 723 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); 724 Value mask = getScalarOrVectorConstInt( 725 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter, 726 op.getLoc()); 727 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, 728 adaptor.getIn(), mask); 729 } else { 730 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType, 731 adaptor.getOperands()); 732 } 733 return success(); 734 } 735 }; 736 737 //===----------------------------------------------------------------------===// 738 // TruncIOp 739 //===----------------------------------------------------------------------===// 740 741 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector 742 /// of i1. 743 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 744 using OpConversionPattern::OpConversionPattern; 745 746 LogicalResult 747 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 748 ConversionPatternRewriter &rewriter) const override { 749 Type dstType = getTypeConverter()->convertType(op.getType()); 750 if (!dstType) 751 return getTypeConversionFailure(rewriter, op); 752 753 if (!isBoolScalarOrVector(dstType)) 754 return failure(); 755 756 Location loc = op.getLoc(); 757 auto srcType = adaptor.getOperands().front().getType(); 758 // Check if (x & 1) == 1. 759 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 760 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 761 loc, srcType, adaptor.getOperands()[0], mask); 762 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 763 764 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 765 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 766 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 767 return success(); 768 } 769 }; 770 771 /// Converts arith.trunci for cases where the type of result is neither i1 772 /// nor vector of i1. 773 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> { 774 using OpConversionPattern::OpConversionPattern; 775 776 LogicalResult 777 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 778 ConversionPatternRewriter &rewriter) const override { 779 Type srcType = adaptor.getIn().getType(); 780 Type dstType = getTypeConverter()->convertType(op.getType()); 781 if (!dstType) 782 return getTypeConversionFailure(rewriter, op); 783 784 if (isBoolScalarOrVector(dstType)) 785 return failure(); 786 787 if (dstType == srcType) { 788 // We can have the same source and destination type due to type emulation. 789 // Perform bit masking to make sure we don't pollute downstream consumers 790 // with unwanted bits. Here we need to use the original result type's 791 // bitwidth. 792 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); 793 Value mask = getScalarOrVectorConstInt( 794 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc()); 795 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, 796 adaptor.getIn(), mask); 797 } else { 798 // Given this is truncation, either SConvertOp or UConvertOp works. 799 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType, 800 adaptor.getOperands()); 801 } 802 return success(); 803 } 804 }; 805 806 //===----------------------------------------------------------------------===// 807 // TypeCastingOp 808 //===----------------------------------------------------------------------===// 809 810 static std::optional<spirv::FPRoundingMode> 811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) { 812 switch (roundingMode) { 813 case arith::RoundingMode::downward: 814 return spirv::FPRoundingMode::RTN; 815 case arith::RoundingMode::to_nearest_even: 816 return spirv::FPRoundingMode::RTE; 817 case arith::RoundingMode::toward_zero: 818 return spirv::FPRoundingMode::RTZ; 819 case arith::RoundingMode::upward: 820 return spirv::FPRoundingMode::RTP; 821 case arith::RoundingMode::to_nearest_away: 822 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode 823 // (as of SPIR-V 1.6) 824 return std::nullopt; 825 } 826 llvm_unreachable("Unhandled rounding mode"); 827 } 828 829 /// Converts type-casting standard operations to SPIR-V operations. 830 template <typename Op, typename SPIRVOp> 831 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 832 using OpConversionPattern<Op>::OpConversionPattern; 833 834 LogicalResult 835 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 836 ConversionPatternRewriter &rewriter) const override { 837 assert(adaptor.getOperands().size() == 1); 838 Type srcType = adaptor.getOperands().front().getType(); 839 Type dstType = this->getTypeConverter()->convertType(op.getType()); 840 if (!dstType) 841 return getTypeConversionFailure(rewriter, op); 842 843 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 844 return failure(); 845 846 if (dstType == srcType) { 847 // Due to type conversion, we are seeing the same source and target type. 848 // Then we can just erase this operation by forwarding its operand. 849 rewriter.replaceOp(op, adaptor.getOperands().front()); 850 } else { 851 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>( 852 op, dstType, adaptor.getOperands()); 853 if (auto roundingModeOp = 854 dyn_cast<arith::ArithRoundingModeInterface>(*op)) { 855 if (arith::RoundingModeAttr roundingMode = 856 roundingModeOp.getRoundingModeAttr()) { 857 if (auto rm = 858 convertArithRoundingModeToSPIRV(roundingMode.getValue())) { 859 newOp->setAttr( 860 getDecorationString(spirv::Decoration::FPRoundingMode), 861 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm)); 862 } else { 863 return rewriter.notifyMatchFailure( 864 op->getLoc(), 865 llvm::formatv("unsupported rounding mode '{0}'", roundingMode)); 866 } 867 } 868 } 869 } 870 return success(); 871 } 872 }; 873 874 //===----------------------------------------------------------------------===// 875 // CmpIOp 876 //===----------------------------------------------------------------------===// 877 878 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 879 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 880 public: 881 using OpConversionPattern::OpConversionPattern; 882 883 LogicalResult 884 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 885 ConversionPatternRewriter &rewriter) const override { 886 Type srcType = op.getLhs().getType(); 887 if (!isBoolScalarOrVector(srcType)) 888 return failure(); 889 Type dstType = getTypeConverter()->convertType(srcType); 890 if (!dstType) 891 return getTypeConversionFailure(rewriter, op, srcType); 892 893 switch (op.getPredicate()) { 894 case arith::CmpIPredicate::eq: { 895 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(), 896 adaptor.getRhs()); 897 return success(); 898 } 899 case arith::CmpIPredicate::ne: { 900 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>( 901 op, adaptor.getLhs(), adaptor.getRhs()); 902 return success(); 903 } 904 case arith::CmpIPredicate::uge: 905 case arith::CmpIPredicate::ugt: 906 case arith::CmpIPredicate::ule: 907 case arith::CmpIPredicate::ult: { 908 // There are no direct corresponding instructions in SPIR-V for such 909 // cases. Extend them to 32-bit and do comparision then. 910 Type type = rewriter.getI32Type(); 911 if (auto vectorType = dyn_cast<VectorType>(dstType)) 912 type = VectorType::get(vectorType.getShape(), type); 913 Value extLhs = 914 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); 915 Value extRhs = 916 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs()); 917 918 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, 919 extRhs); 920 return success(); 921 } 922 default: 923 break; 924 } 925 return failure(); 926 } 927 }; 928 929 /// Converts integer compare operation to SPIR-V ops. 930 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 931 public: 932 using OpConversionPattern::OpConversionPattern; 933 934 LogicalResult 935 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 936 ConversionPatternRewriter &rewriter) const override { 937 Type srcType = op.getLhs().getType(); 938 if (isBoolScalarOrVector(srcType)) 939 return failure(); 940 Type dstType = getTypeConverter()->convertType(srcType); 941 if (!dstType) 942 return getTypeConversionFailure(rewriter, op, srcType); 943 944 switch (op.getPredicate()) { 945 #define DISPATCH(cmpPredicate, spirvOp) \ 946 case cmpPredicate: \ 947 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 948 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \ 949 !hasSameBitwidth(srcType, dstType)) { \ 950 return op.emitError( \ 951 "bitwidth emulation is not implemented yet on unsigned op"); \ 952 } \ 953 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 954 adaptor.getRhs()); \ 955 return success(); 956 957 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 958 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 959 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 960 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 961 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 962 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 963 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 964 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 965 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 966 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 967 968 #undef DISPATCH 969 } 970 return failure(); 971 } 972 }; 973 974 //===----------------------------------------------------------------------===// 975 // CmpFOpPattern 976 //===----------------------------------------------------------------------===// 977 978 /// Converts floating-point comparison operations to SPIR-V ops. 979 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 980 public: 981 using OpConversionPattern::OpConversionPattern; 982 983 LogicalResult 984 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 985 ConversionPatternRewriter &rewriter) const override { 986 switch (op.getPredicate()) { 987 #define DISPATCH(cmpPredicate, spirvOp) \ 988 case cmpPredicate: \ 989 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 990 adaptor.getRhs()); \ 991 return success(); 992 993 // Ordered. 994 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 995 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 996 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 997 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 998 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 999 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 1000 // Unordered. 1001 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); 1002 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); 1003 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); 1004 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); 1005 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); 1006 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); 1007 1008 #undef DISPATCH 1009 1010 default: 1011 break; 1012 } 1013 return failure(); 1014 } 1015 }; 1016 1017 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 1018 /// Kernel capability. 1019 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 1020 public: 1021 using OpConversionPattern::OpConversionPattern; 1022 1023 LogicalResult 1024 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 1025 ConversionPatternRewriter &rewriter) const override { 1026 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 1027 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), 1028 adaptor.getRhs()); 1029 return success(); 1030 } 1031 1032 if (op.getPredicate() == arith::CmpFPredicate::UNO) { 1033 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), 1034 adaptor.getRhs()); 1035 return success(); 1036 } 1037 1038 return failure(); 1039 } 1040 }; 1041 1042 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 1043 /// require additional capability. 1044 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 1045 public: 1046 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 1047 1048 LogicalResult 1049 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 1050 ConversionPatternRewriter &rewriter) const override { 1051 if (op.getPredicate() != arith::CmpFPredicate::ORD && 1052 op.getPredicate() != arith::CmpFPredicate::UNO) 1053 return failure(); 1054 1055 Location loc = op.getLoc(); 1056 1057 Value replace; 1058 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { 1059 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 1060 // Ordered comparsion checks if neither operand is NaN. 1061 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter); 1062 } else { 1063 // Unordered comparsion checks if either operand is NaN. 1064 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); 1065 } 1066 } else { 1067 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 1068 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 1069 1070 replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); 1071 if (op.getPredicate() == arith::CmpFPredicate::ORD) 1072 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); 1073 } 1074 1075 rewriter.replaceOp(op, replace); 1076 return success(); 1077 } 1078 }; 1079 1080 //===----------------------------------------------------------------------===// 1081 // AddUIExtendedOp 1082 //===----------------------------------------------------------------------===// 1083 1084 /// Converts arith.addui_extended to spirv.IAddCarry. 1085 class AddUIExtendedOpPattern final 1086 : public OpConversionPattern<arith::AddUIExtendedOp> { 1087 public: 1088 using OpConversionPattern::OpConversionPattern; 1089 LogicalResult 1090 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, 1091 ConversionPatternRewriter &rewriter) const override { 1092 Type dstElemTy = adaptor.getLhs().getType(); 1093 Location loc = op->getLoc(); 1094 Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(), 1095 adaptor.getRhs()); 1096 1097 Value sumResult = rewriter.create<spirv::CompositeExtractOp>( 1098 loc, result, llvm::ArrayRef(0)); 1099 Value carryValue = rewriter.create<spirv::CompositeExtractOp>( 1100 loc, result, llvm::ArrayRef(1)); 1101 1102 // Convert the carry value to boolean. 1103 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); 1104 Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one); 1105 1106 rewriter.replaceOp(op, {sumResult, carryResult}); 1107 return success(); 1108 } 1109 }; 1110 1111 //===----------------------------------------------------------------------===// 1112 // MulIExtendedOp 1113 //===----------------------------------------------------------------------===// 1114 1115 /// Converts arith.mul*i_extended to spirv.*MulExtended. 1116 template <typename ArithMulOp, typename SPIRVMulOp> 1117 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> { 1118 public: 1119 using OpConversionPattern<ArithMulOp>::OpConversionPattern; 1120 LogicalResult 1121 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, 1122 ConversionPatternRewriter &rewriter) const override { 1123 Location loc = op->getLoc(); 1124 Value result = 1125 rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs()); 1126 1127 Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result, 1128 llvm::ArrayRef(0)); 1129 Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result, 1130 llvm::ArrayRef(1)); 1131 1132 rewriter.replaceOp(op, {low, high}); 1133 return success(); 1134 } 1135 }; 1136 1137 //===----------------------------------------------------------------------===// 1138 // SelectOp 1139 //===----------------------------------------------------------------------===// 1140 1141 /// Converts arith.select to spirv.Select. 1142 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { 1143 public: 1144 using OpConversionPattern::OpConversionPattern; 1145 LogicalResult 1146 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 1147 ConversionPatternRewriter &rewriter) const override { 1148 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(), 1149 adaptor.getTrueValue(), 1150 adaptor.getFalseValue()); 1151 return success(); 1152 } 1153 }; 1154 1155 //===----------------------------------------------------------------------===// 1156 // MinimumFOp, MaximumFOp 1157 //===----------------------------------------------------------------------===// 1158 1159 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or 1160 /// spirv.CL.fmax/fmin. 1161 template <typename Op, typename SPIRVOp> 1162 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> { 1163 public: 1164 using OpConversionPattern<Op>::OpConversionPattern; 1165 LogicalResult 1166 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 1167 ConversionPatternRewriter &rewriter) const override { 1168 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>(); 1169 Type dstType = converter->convertType(op.getType()); 1170 if (!dstType) 1171 return getTypeConversionFailure(rewriter, op); 1172 1173 // arith.maximumf/minimumf: 1174 // "if one of the arguments is NaN, then the result is also NaN." 1175 // spirv.GL.FMax/FMin 1176 // "which operand is the result is undefined if one of the operands 1177 // is a NaN." 1178 // spirv.CL.fmax/fmin: 1179 // "If one argument is a NaN, Fmin returns the other argument." 1180 1181 Location loc = op.getLoc(); 1182 Value spirvOp = 1183 rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); 1184 1185 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { 1186 rewriter.replaceOp(op, spirvOp); 1187 return success(); 1188 } 1189 1190 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 1191 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 1192 1193 Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, 1194 adaptor.getLhs(), spirvOp); 1195 Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, 1196 adaptor.getRhs(), select1); 1197 1198 rewriter.replaceOp(op, select2); 1199 return success(); 1200 } 1201 }; 1202 1203 //===----------------------------------------------------------------------===// 1204 // MinNumFOp, MaxNumFOp 1205 //===----------------------------------------------------------------------===// 1206 1207 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or 1208 /// spirv.CL.fmax/fmin. 1209 template <typename Op, typename SPIRVOp> 1210 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> { 1211 template <typename TargetOp> 1212 constexpr bool shouldInsertNanGuards() const { 1213 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value; 1214 } 1215 1216 public: 1217 using OpConversionPattern<Op>::OpConversionPattern; 1218 LogicalResult 1219 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 1220 ConversionPatternRewriter &rewriter) const override { 1221 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>(); 1222 Type dstType = converter->convertType(op.getType()); 1223 if (!dstType) 1224 return getTypeConversionFailure(rewriter, op); 1225 1226 // arith.maxnumf/minnumf: 1227 // "If one of the arguments is NaN, then the result is the other 1228 // argument." 1229 // spirv.GL.FMax/FMin 1230 // "which operand is the result is undefined if one of the operands 1231 // is a NaN." 1232 // spirv.CL.fmax/fmin: 1233 // "If one argument is a NaN, Fmin returns the other argument." 1234 1235 Location loc = op.getLoc(); 1236 Value spirvOp = 1237 rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); 1238 1239 if (!shouldInsertNanGuards<SPIRVOp>() || 1240 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { 1241 rewriter.replaceOp(op, spirvOp); 1242 return success(); 1243 } 1244 1245 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 1246 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 1247 1248 Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, 1249 adaptor.getRhs(), spirvOp); 1250 Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, 1251 adaptor.getLhs(), select1); 1252 1253 rewriter.replaceOp(op, select2); 1254 return success(); 1255 } 1256 }; 1257 1258 } // namespace 1259 1260 //===----------------------------------------------------------------------===// 1261 // Pattern Population 1262 //===----------------------------------------------------------------------===// 1263 1264 void mlir::arith::populateArithToSPIRVPatterns( 1265 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 1266 // clang-format off 1267 patterns.add< 1268 ConstantCompositeOpPattern, 1269 ConstantScalarOpPattern, 1270 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>, 1271 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>, 1272 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>, 1273 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>, 1274 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>, 1275 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>, 1276 RemSIOpGLPattern, RemSIOpCLPattern, 1277 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, 1278 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, 1279 XOrIOpLogicalPattern, XOrIOpBooleanPattern, 1280 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, 1281 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, 1282 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, 1283 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>, 1284 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>, 1285 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>, 1286 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>, 1287 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>, 1288 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>, 1289 ExtUIPattern, ExtUII1Pattern, 1290 ExtSIPattern, ExtSII1Pattern, 1291 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, 1292 TruncIPattern, TruncII1Pattern, 1293 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, 1294 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, 1295 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, 1296 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>, 1297 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, 1298 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, 1299 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>, 1300 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, 1301 CmpIOpBooleanPattern, CmpIOpPattern, 1302 CmpFOpNanNonePattern, CmpFOpPattern, 1303 AddUIExtendedOpPattern, 1304 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>, 1305 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>, 1306 SelectOpPattern, 1307 1308 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>, 1309 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>, 1310 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>, 1311 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>, 1312 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>, 1313 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>, 1314 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>, 1315 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>, 1316 1317 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>, 1318 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>, 1319 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>, 1320 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>, 1321 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>, 1322 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>, 1323 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>, 1324 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp> 1325 >(typeConverter, patterns.getContext()); 1326 // clang-format on 1327 1328 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel 1329 // capability is available. 1330 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), 1331 /*benefit=*/2); 1332 } 1333 1334 //===----------------------------------------------------------------------===// 1335 // Pass Definition 1336 //===----------------------------------------------------------------------===// 1337 1338 namespace { 1339 struct ConvertArithToSPIRVPass 1340 : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> { 1341 void runOnOperation() override { 1342 Operation *op = getOperation(); 1343 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); 1344 std::unique_ptr<SPIRVConversionTarget> target = 1345 SPIRVConversionTarget::get(targetAttr); 1346 1347 SPIRVConversionOptions options; 1348 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; 1349 SPIRVTypeConverter typeConverter(targetAttr, options); 1350 1351 // Use UnrealizedConversionCast as the bridge so that we don't need to pull 1352 // in patterns for other dialects. 1353 target->addLegalOp<UnrealizedConversionCastOp>(); 1354 1355 // Fail hard when there are any remaining 'arith' ops. 1356 target->addIllegalDialect<arith::ArithDialect>(); 1357 1358 RewritePatternSet patterns(&getContext()); 1359 arith::populateArithToSPIRVPatterns(typeConverter, patterns); 1360 1361 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 1362 signalPassFailure(); 1363 } 1364 }; 1365 } // namespace 1366 1367 std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() { 1368 return std::make_unique<ConvertArithToSPIRVPass>(); 1369 } 1370