xref: /llvm-project/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
10 
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/DialectResourceBlobManager.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
26 #include <cassert>
27 #include <memory>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
35 
36 using namespace mlir;
37 
38 //===----------------------------------------------------------------------===//
39 // Conversion Helpers
40 //===----------------------------------------------------------------------===//
41 
42 /// Converts the given `srcAttr` into a boolean attribute if it holds an
43 /// integral value. Returns null attribute if conversion fails.
44 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45   if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46     return boolAttr;
47   if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49   return {};
50 }
51 
52 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53 /// Returns null attribute if conversion fails.
54 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55                                       Builder builder) {
56   // If the source number uses less active bits than the target bitwidth, then
57   // it should be safe to convert.
58   if (srcAttr.getValue().isIntN(dstType.getWidth()))
59     return builder.getIntegerAttr(dstType, srcAttr.getInt());
60 
61   // XXX: Try again by interpreting the source number as a signed value.
62   // Although integers in the standard dialect are signless, they can represent
63   // a signed number. It's the operation decides how to interpret. This is
64   // dangerous, but it seems there is no good way of handling this if we still
65   // want to change the bitwidth. Emit a message at least.
66   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69                             << dstAttr << "' for type '" << dstType << "'\n");
70     return dstAttr;
71   }
72 
73   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74                           << "' illegal: cannot fit into target type '"
75                           << dstType << "'\n");
76   return {};
77 }
78 
79 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82                                   Builder builder) {
83   // Only support converting to float for now.
84   if (!dstType.isF32())
85     return FloatAttr();
86 
87   // Try to convert the source floating-point number to single precision.
88   APFloat dstVal = srcAttr.getValue();
89   bool losesInfo = false;
90   APFloat::opStatus status =
91       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92   if (status != APFloat::opOK || losesInfo) {
93     LLVM_DEBUG(llvm::dbgs()
94                << srcAttr << " illegal: cannot fit into converted type '"
95                << dstType << "'\n");
96     return FloatAttr();
97   }
98 
99   return builder.getF32FloatAttr(dstVal.convertToFloat());
100 }
101 
102 /// Returns true if the given `type` is a boolean scalar or vector type.
103 static bool isBoolScalarOrVector(Type type) {
104   assert(type && "Not a valid type");
105   if (type.isInteger(1))
106     return true;
107 
108   if (auto vecType = dyn_cast<VectorType>(type))
109     return vecType.getElementType().isInteger(1);
110 
111   return false;
112 }
113 
114 /// Creates a scalar/vector integer constant.
115 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
116                                        OpBuilder &builder, Location loc) {
117   if (auto vectorType = dyn_cast<VectorType>(type)) {
118     Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
119     auto attr = SplatElementsAttr::get(vectorType, element);
120     return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
121   }
122 
123   if (auto intType = dyn_cast<IntegerType>(type))
124     return builder.create<spirv::ConstantOp>(
125         loc, type, builder.getIntegerAttr(type, value));
126 
127   return nullptr;
128 }
129 
130 /// Returns true if scalar/vector type `a` and `b` have the same number of
131 /// bitwidth.
132 static bool hasSameBitwidth(Type a, Type b) {
133   auto getNumBitwidth = [](Type type) {
134     unsigned bw = 0;
135     if (type.isIntOrFloat())
136       bw = type.getIntOrFloatBitWidth();
137     else if (auto vecType = dyn_cast<VectorType>(type))
138       bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
139     return bw;
140   };
141   unsigned aBW = getNumBitwidth(a);
142   unsigned bBW = getNumBitwidth(b);
143   return aBW != 0 && bBW != 0 && aBW == bBW;
144 }
145 
146 /// Returns a source type conversion failure for `srcType` and operation `op`.
147 static LogicalResult
148 getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
149                          Type srcType) {
150   return rewriter.notifyMatchFailure(
151       op->getLoc(),
152       llvm::formatv("failed to convert source type '{0}'", srcType));
153 }
154 
155 /// Returns a source type conversion failure for the result type of `op`.
156 static LogicalResult
157 getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
158   assert(op->getNumResults() == 1);
159   return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
160 }
161 
162 // TODO: Move to some common place?
163 static std::string getDecorationString(spirv::Decoration decor) {
164   return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
165 }
166 
167 namespace {
168 
169 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
170 /// operations. Op can potentially support overflow flags.
171 template <typename Op, typename SPIRVOp>
172 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
173   using OpConversionPattern<Op>::OpConversionPattern;
174 
175   LogicalResult
176   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
177                   ConversionPatternRewriter &rewriter) const override {
178     assert(adaptor.getOperands().size() <= 3);
179     auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
180     Type dstType = converter->convertType(op.getType());
181     if (!dstType) {
182       return rewriter.notifyMatchFailure(
183           op->getLoc(),
184           llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
185     }
186 
187     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
188         !getElementTypeOrSelf(op.getType()).isIndex() &&
189         dstType != op.getType()) {
190       return op.emitError("bitwidth emulation is not implemented yet on "
191                           "unsigned op pattern version");
192     }
193 
194     auto overflowFlags = arith::IntegerOverflowFlags::none;
195     if (auto overflowIface =
196             dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197       if (converter->getTargetEnv().allows(
198               spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199         overflowFlags = overflowIface.getOverflowAttr().getValue();
200     }
201 
202     auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203         op, dstType, adaptor.getOperands());
204 
205     if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
206       newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
207                      rewriter.getUnitAttr());
208 
209     if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
210       newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
211                      rewriter.getUnitAttr());
212 
213     return success();
214   }
215 };
216 
217 //===----------------------------------------------------------------------===//
218 // ConstantOp
219 //===----------------------------------------------------------------------===//
220 
221 /// Converts composite arith.constant operation to spirv.Constant.
222 struct ConstantCompositeOpPattern final
223     : public OpConversionPattern<arith::ConstantOp> {
224   using OpConversionPattern::OpConversionPattern;
225 
226   LogicalResult
227   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
228                   ConversionPatternRewriter &rewriter) const override {
229     auto srcType = dyn_cast<ShapedType>(constOp.getType());
230     if (!srcType || srcType.getNumElements() == 1)
231       return failure();
232 
233     // arith.constant should only have vector or tensor types. This is a MLIR
234     // wide problem at the moment.
235     if (!isa<VectorType, RankedTensorType>(srcType))
236       return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
237 
238     Type dstType = getTypeConverter()->convertType(srcType);
239     if (!dstType)
240       return failure();
241 
242     // Import the resource into the IR to make use of the special handling of
243     // element types later on.
244     mlir::DenseElementsAttr dstElementsAttr;
245     if (auto denseElementsAttr =
246             dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247       dstElementsAttr = denseElementsAttr;
248     } else if (auto resourceAttr =
249                    dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
250 
251       AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
252       if (!blob)
253         return constOp->emitError("could not find resource blob");
254 
255       ArrayRef<char> ptr = blob->getData();
256 
257       // Check that the buffer meets the requirements to get converted to a
258       // DenseElementsAttr
259       bool detectedSplat = false;
260       if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
261         return constOp->emitError("resource is not a valid buffer");
262 
263       dstElementsAttr =
264           DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
265     } else {
266       return constOp->emitError("unsupported elements attribute");
267     }
268 
269     ShapedType dstAttrType = dstElementsAttr.getType();
270 
271     // If the composite type has more than one dimensions, perform
272     // linearization.
273     if (srcType.getRank() > 1) {
274       if (isa<RankedTensorType>(srcType)) {
275         dstAttrType = RankedTensorType::get(srcType.getNumElements(),
276                                             srcType.getElementType());
277         dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
278       } else {
279         // TODO: add support for large vectors.
280         return failure();
281       }
282     }
283 
284     Type srcElemType = srcType.getElementType();
285     Type dstElemType;
286     // Tensor types are converted to SPIR-V array types; vector types are
287     // converted to SPIR-V vector/array types.
288     if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289       dstElemType = arrayType.getElementType();
290     else
291       dstElemType = cast<VectorType>(dstType).getElementType();
292 
293     // If the source and destination element types are different, perform
294     // attribute conversion.
295     if (srcElemType != dstElemType) {
296       SmallVector<Attribute, 8> elements;
297       if (isa<FloatType>(srcElemType)) {
298         for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299           FloatAttr dstAttr =
300               convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
301           if (!dstAttr)
302             return failure();
303           elements.push_back(dstAttr);
304         }
305       } else if (srcElemType.isInteger(1)) {
306         return failure();
307       } else {
308         for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
309           IntegerAttr dstAttr = convertIntegerAttr(
310               srcAttr, cast<IntegerType>(dstElemType), rewriter);
311           if (!dstAttr)
312             return failure();
313           elements.push_back(dstAttr);
314         }
315       }
316 
317       // Unfortunately, we cannot use dialect-specific types for element
318       // attributes; element attributes only works with builtin types. So we
319       // need to prepare another converted builtin types for the destination
320       // elements attribute.
321       if (isa<RankedTensorType>(dstAttrType))
322         dstAttrType =
323             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
324       else
325         dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
326 
327       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
328     }
329 
330     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
331                                                    dstElementsAttr);
332     return success();
333   }
334 };
335 
336 /// Converts scalar arith.constant operation to spirv.Constant.
337 struct ConstantScalarOpPattern final
338     : public OpConversionPattern<arith::ConstantOp> {
339   using OpConversionPattern::OpConversionPattern;
340 
341   LogicalResult
342   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
343                   ConversionPatternRewriter &rewriter) const override {
344     Type srcType = constOp.getType();
345     if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
346       if (shapedType.getNumElements() != 1)
347         return failure();
348       srcType = shapedType.getElementType();
349     }
350     if (!srcType.isIntOrIndexOrFloat())
351       return failure();
352 
353     Attribute cstAttr = constOp.getValue();
354     if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355       cstAttr = elementsAttr.getSplatValue<Attribute>();
356 
357     Type dstType = getTypeConverter()->convertType(srcType);
358     if (!dstType)
359       return failure();
360 
361     // Floating-point types.
362     if (isa<FloatType>(srcType)) {
363       auto srcAttr = cast<FloatAttr>(cstAttr);
364       auto dstAttr = srcAttr;
365 
366       // Floating-point types not supported in the target environment are all
367       // converted to float type.
368       if (srcType != dstType) {
369         dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370         if (!dstAttr)
371           return failure();
372       }
373 
374       rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
375       return success();
376     }
377 
378     // Bool type.
379     if (srcType.isInteger(1)) {
380       // arith.constant can use 0/1 instead of true/false for i1 values. We need
381       // to handle that here.
382       auto dstAttr = convertBoolAttr(cstAttr, rewriter);
383       if (!dstAttr)
384         return failure();
385       rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
386       return success();
387     }
388 
389     // IndexType or IntegerType. Index values are converted to 32-bit integer
390     // values when converting to SPIR-V.
391     auto srcAttr = cast<IntegerAttr>(cstAttr);
392     IntegerAttr dstAttr =
393         convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
394     if (!dstAttr)
395       return failure();
396     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
397     return success();
398   }
399 };
400 
401 //===----------------------------------------------------------------------===//
402 // RemSIOp
403 //===----------------------------------------------------------------------===//
404 
405 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
406 /// the sign of `signOperand`.
407 ///
408 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
409 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
410 /// the result is undefined."  So we cannot directly use spirv.SRem/spirv.SMod
411 /// if either operand can be negative. Emulate it via spirv.UMod.
412 template <typename SignedAbsOp>
413 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
414                                     Value signOperand, OpBuilder &builder) {
415   assert(lhs.getType() == rhs.getType());
416   assert(lhs == signOperand || rhs == signOperand);
417 
418   Type type = lhs.getType();
419 
420   // Calculate the remainder with spirv.UMod.
421   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
422   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
423   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
424 
425   // Fix the sign.
426   Value isPositive;
427   if (lhs == signOperand)
428     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
429   else
430     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
432   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
433 }
434 
435 /// Converts arith.remsi to GLSL SPIR-V ops.
436 ///
437 /// This cannot be merged into the template unary/binary pattern due to Vulkan
438 /// restrictions over spirv.SRem and spirv.SMod.
439 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
440   using OpConversionPattern::OpConversionPattern;
441 
442   LogicalResult
443   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
444                   ConversionPatternRewriter &rewriter) const override {
445     Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446         op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447         adaptor.getOperands()[0], rewriter);
448     rewriter.replaceOp(op, result);
449 
450     return success();
451   }
452 };
453 
454 /// Converts arith.remsi to OpenCL SPIR-V ops.
455 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
456   using OpConversionPattern::OpConversionPattern;
457 
458   LogicalResult
459   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
460                   ConversionPatternRewriter &rewriter) const override {
461     Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462         op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463         adaptor.getOperands()[0], rewriter);
464     rewriter.replaceOp(op, result);
465 
466     return success();
467   }
468 };
469 
470 //===----------------------------------------------------------------------===//
471 // BitwiseOp
472 //===----------------------------------------------------------------------===//
473 
474 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
475 /// other than the BinaryOpPatternPattern because if the operands are boolean
476 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
477 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
478 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
479 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
480   using OpConversionPattern<Op>::OpConversionPattern;
481 
482   LogicalResult
483   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
484                   ConversionPatternRewriter &rewriter) const override {
485     assert(adaptor.getOperands().size() == 2);
486     Type dstType = this->getTypeConverter()->convertType(op.getType());
487     if (!dstType)
488       return getTypeConversionFailure(rewriter, op);
489 
490     if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491       rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492           op, dstType, adaptor.getOperands());
493     } else {
494       rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495           op, dstType, adaptor.getOperands());
496     }
497     return success();
498   }
499 };
500 
501 //===----------------------------------------------------------------------===//
502 // XOrIOp
503 //===----------------------------------------------------------------------===//
504 
505 /// Converts arith.xori to SPIR-V operations.
506 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
507   using OpConversionPattern::OpConversionPattern;
508 
509   LogicalResult
510   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
511                   ConversionPatternRewriter &rewriter) const override {
512     assert(adaptor.getOperands().size() == 2);
513 
514     if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
515       return failure();
516 
517     Type dstType = getTypeConverter()->convertType(op.getType());
518     if (!dstType)
519       return getTypeConversionFailure(rewriter, op);
520 
521     rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
522                                                      adaptor.getOperands());
523 
524     return success();
525   }
526 };
527 
528 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
529 /// vector of i1.
530 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
531   using OpConversionPattern::OpConversionPattern;
532 
533   LogicalResult
534   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535                   ConversionPatternRewriter &rewriter) const override {
536     assert(adaptor.getOperands().size() == 2);
537 
538     if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539       return failure();
540 
541     Type dstType = getTypeConverter()->convertType(op.getType());
542     if (!dstType)
543       return getTypeConversionFailure(rewriter, op);
544 
545     rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
546         op, dstType, adaptor.getOperands());
547     return success();
548   }
549 };
550 
551 //===----------------------------------------------------------------------===//
552 // UIToFPOp
553 //===----------------------------------------------------------------------===//
554 
555 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
556 /// of i1.
557 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
558   using OpConversionPattern::OpConversionPattern;
559 
560   LogicalResult
561   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
562                   ConversionPatternRewriter &rewriter) const override {
563     Type srcType = adaptor.getOperands().front().getType();
564     if (!isBoolScalarOrVector(srcType))
565       return failure();
566 
567     Type dstType = getTypeConverter()->convertType(op.getType());
568     if (!dstType)
569       return getTypeConversionFailure(rewriter, op);
570 
571     Location loc = op.getLoc();
572     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
573     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
574     rewriter.replaceOpWithNewOp<spirv::SelectOp>(
575         op, dstType, adaptor.getOperands().front(), one, zero);
576     return success();
577   }
578 };
579 
580 //===----------------------------------------------------------------------===//
581 // ExtSIOp
582 //===----------------------------------------------------------------------===//
583 
584 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
585 /// of i1.
586 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
587   using OpConversionPattern::OpConversionPattern;
588 
589   LogicalResult
590   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
591                   ConversionPatternRewriter &rewriter) const override {
592     Value operand = adaptor.getIn();
593     if (!isBoolScalarOrVector(operand.getType()))
594       return failure();
595 
596     Location loc = op.getLoc();
597     Type dstType = getTypeConverter()->convertType(op.getType());
598     if (!dstType)
599       return getTypeConversionFailure(rewriter, op);
600 
601     Value allOnes;
602     if (auto intTy = dyn_cast<IntegerType>(dstType)) {
603       unsigned componentBitwidth = intTy.getWidth();
604       allOnes = rewriter.create<spirv::ConstantOp>(
605           loc, intTy,
606           rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607     } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
608       unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609       allOnes = rewriter.create<spirv::ConstantOp>(
610           loc, vectorTy,
611           SplatElementsAttr::get(vectorTy,
612                                  APInt::getAllOnes(componentBitwidth)));
613     } else {
614       return rewriter.notifyMatchFailure(
615           loc, llvm::formatv("unhandled type: {0}", dstType));
616     }
617 
618     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
619     rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
620                                                  zero);
621     return success();
622   }
623 };
624 
625 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
626 /// vector of i1.
627 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
628   using OpConversionPattern::OpConversionPattern;
629 
630   LogicalResult
631   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
632                   ConversionPatternRewriter &rewriter) const override {
633     Type srcType = adaptor.getIn().getType();
634     if (isBoolScalarOrVector(srcType))
635       return failure();
636 
637     Type dstType = getTypeConverter()->convertType(op.getType());
638     if (!dstType)
639       return getTypeConversionFailure(rewriter, op);
640 
641     if (dstType == srcType) {
642       // We can have the same source and destination type due to type emulation.
643       // Perform bit shifting to make sure we have the proper leading set bits.
644 
645       unsigned srcBW =
646           getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
647       unsigned dstBW =
648           getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
649       assert(srcBW < dstBW);
650       Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
651                                                   rewriter, op.getLoc());
652 
653       // First shift left to sequeeze out all leading bits beyond the original
654       // bitwidth. Here we need to use the original source and result type's
655       // bitwidth.
656       auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
657           op.getLoc(), dstType, adaptor.getIn(), shiftSize);
658 
659       // Then we perform arithmetic right shift to make sure we have the right
660       // sign bits for negative values.
661       rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
662           op, dstType, shiftLOp, shiftSize);
663     } else {
664       rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
665                                                      adaptor.getOperands());
666     }
667 
668     return success();
669   }
670 };
671 
672 //===----------------------------------------------------------------------===//
673 // ExtUIOp
674 //===----------------------------------------------------------------------===//
675 
676 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
677 /// of i1.
678 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
679   using OpConversionPattern::OpConversionPattern;
680 
681   LogicalResult
682   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
683                   ConversionPatternRewriter &rewriter) const override {
684     Type srcType = adaptor.getOperands().front().getType();
685     if (!isBoolScalarOrVector(srcType))
686       return failure();
687 
688     Type dstType = getTypeConverter()->convertType(op.getType());
689     if (!dstType)
690       return getTypeConversionFailure(rewriter, op);
691 
692     Location loc = op.getLoc();
693     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695     rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696         op, dstType, adaptor.getOperands().front(), one, zero);
697     return success();
698   }
699 };
700 
701 /// Converts arith.extui for cases where the type of source is neither i1 nor
702 /// vector of i1.
703 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
704   using OpConversionPattern::OpConversionPattern;
705 
706   LogicalResult
707   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
708                   ConversionPatternRewriter &rewriter) const override {
709     Type srcType = adaptor.getIn().getType();
710     if (isBoolScalarOrVector(srcType))
711       return failure();
712 
713     Type dstType = getTypeConverter()->convertType(op.getType());
714     if (!dstType)
715       return getTypeConversionFailure(rewriter, op);
716 
717     if (dstType == srcType) {
718       // We can have the same source and destination type due to type emulation.
719       // Perform bit masking to make sure we don't pollute downstream consumers
720       // with unwanted bits. Here we need to use the original source type's
721       // bitwidth.
722       unsigned bitwidth =
723           getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
724       Value mask = getScalarOrVectorConstInt(
725           dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
726           op.getLoc());
727       rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
728                                                        adaptor.getIn(), mask);
729     } else {
730       rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
731                                                      adaptor.getOperands());
732     }
733     return success();
734   }
735 };
736 
737 //===----------------------------------------------------------------------===//
738 // TruncIOp
739 //===----------------------------------------------------------------------===//
740 
741 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
742 /// of i1.
743 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
744   using OpConversionPattern::OpConversionPattern;
745 
746   LogicalResult
747   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
748                   ConversionPatternRewriter &rewriter) const override {
749     Type dstType = getTypeConverter()->convertType(op.getType());
750     if (!dstType)
751       return getTypeConversionFailure(rewriter, op);
752 
753     if (!isBoolScalarOrVector(dstType))
754       return failure();
755 
756     Location loc = op.getLoc();
757     auto srcType = adaptor.getOperands().front().getType();
758     // Check if (x & 1) == 1.
759     Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760     Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
761         loc, srcType, adaptor.getOperands()[0], mask);
762     Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
763 
764     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
765     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
766     rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
767     return success();
768   }
769 };
770 
771 /// Converts arith.trunci for cases where the type of result is neither i1
772 /// nor vector of i1.
773 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
774   using OpConversionPattern::OpConversionPattern;
775 
776   LogicalResult
777   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
778                   ConversionPatternRewriter &rewriter) const override {
779     Type srcType = adaptor.getIn().getType();
780     Type dstType = getTypeConverter()->convertType(op.getType());
781     if (!dstType)
782       return getTypeConversionFailure(rewriter, op);
783 
784     if (isBoolScalarOrVector(dstType))
785       return failure();
786 
787     if (dstType == srcType) {
788       // We can have the same source and destination type due to type emulation.
789       // Perform bit masking to make sure we don't pollute downstream consumers
790       // with unwanted bits. Here we need to use the original result type's
791       // bitwidth.
792       unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
793       Value mask = getScalarOrVectorConstInt(
794           dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
795       rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
796                                                        adaptor.getIn(), mask);
797     } else {
798       // Given this is truncation, either SConvertOp or UConvertOp works.
799       rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
800                                                      adaptor.getOperands());
801     }
802     return success();
803   }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // TypeCastingOp
808 //===----------------------------------------------------------------------===//
809 
810 static std::optional<spirv::FPRoundingMode>
811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812   switch (roundingMode) {
813   case arith::RoundingMode::downward:
814     return spirv::FPRoundingMode::RTN;
815   case arith::RoundingMode::to_nearest_even:
816     return spirv::FPRoundingMode::RTE;
817   case arith::RoundingMode::toward_zero:
818     return spirv::FPRoundingMode::RTZ;
819   case arith::RoundingMode::upward:
820     return spirv::FPRoundingMode::RTP;
821   case arith::RoundingMode::to_nearest_away:
822     // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
823     // (as of SPIR-V 1.6)
824     return std::nullopt;
825   }
826   llvm_unreachable("Unhandled rounding mode");
827 }
828 
829 /// Converts type-casting standard operations to SPIR-V operations.
830 template <typename Op, typename SPIRVOp>
831 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
832   using OpConversionPattern<Op>::OpConversionPattern;
833 
834   LogicalResult
835   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836                   ConversionPatternRewriter &rewriter) const override {
837     assert(adaptor.getOperands().size() == 1);
838     Type srcType = adaptor.getOperands().front().getType();
839     Type dstType = this->getTypeConverter()->convertType(op.getType());
840     if (!dstType)
841       return getTypeConversionFailure(rewriter, op);
842 
843     if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
844       return failure();
845 
846     if (dstType == srcType) {
847       // Due to type conversion, we are seeing the same source and target type.
848       // Then we can just erase this operation by forwarding its operand.
849       rewriter.replaceOp(op, adaptor.getOperands().front());
850     } else {
851       auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
852           op, dstType, adaptor.getOperands());
853       if (auto roundingModeOp =
854               dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
855         if (arith::RoundingModeAttr roundingMode =
856                 roundingModeOp.getRoundingModeAttr()) {
857           if (auto rm =
858                   convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
859             newOp->setAttr(
860                 getDecorationString(spirv::Decoration::FPRoundingMode),
861                 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
862           } else {
863             return rewriter.notifyMatchFailure(
864                 op->getLoc(),
865                 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
866           }
867         }
868       }
869     }
870     return success();
871   }
872 };
873 
874 //===----------------------------------------------------------------------===//
875 // CmpIOp
876 //===----------------------------------------------------------------------===//
877 
878 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
879 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
880 public:
881   using OpConversionPattern::OpConversionPattern;
882 
883   LogicalResult
884   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
885                   ConversionPatternRewriter &rewriter) const override {
886     Type srcType = op.getLhs().getType();
887     if (!isBoolScalarOrVector(srcType))
888       return failure();
889     Type dstType = getTypeConverter()->convertType(srcType);
890     if (!dstType)
891       return getTypeConversionFailure(rewriter, op, srcType);
892 
893     switch (op.getPredicate()) {
894     case arith::CmpIPredicate::eq: {
895       rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
896                                                          adaptor.getRhs());
897       return success();
898     }
899     case arith::CmpIPredicate::ne: {
900       rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
901           op, adaptor.getLhs(), adaptor.getRhs());
902       return success();
903     }
904     case arith::CmpIPredicate::uge:
905     case arith::CmpIPredicate::ugt:
906     case arith::CmpIPredicate::ule:
907     case arith::CmpIPredicate::ult: {
908       // There are no direct corresponding instructions in SPIR-V for such
909       // cases. Extend them to 32-bit and do comparision then.
910       Type type = rewriter.getI32Type();
911       if (auto vectorType = dyn_cast<VectorType>(dstType))
912         type = VectorType::get(vectorType.getShape(), type);
913       Value extLhs =
914           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
915       Value extRhs =
916           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
917 
918       rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
919                                                  extRhs);
920       return success();
921     }
922     default:
923       break;
924     }
925     return failure();
926   }
927 };
928 
929 /// Converts integer compare operation to SPIR-V ops.
930 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
931 public:
932   using OpConversionPattern::OpConversionPattern;
933 
934   LogicalResult
935   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
936                   ConversionPatternRewriter &rewriter) const override {
937     Type srcType = op.getLhs().getType();
938     if (isBoolScalarOrVector(srcType))
939       return failure();
940     Type dstType = getTypeConverter()->convertType(srcType);
941     if (!dstType)
942       return getTypeConversionFailure(rewriter, op, srcType);
943 
944     switch (op.getPredicate()) {
945 #define DISPATCH(cmpPredicate, spirvOp)                                        \
946   case cmpPredicate:                                                           \
947     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
948         !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType &&      \
949         !hasSameBitwidth(srcType, dstType)) {                                  \
950       return op.emitError(                                                     \
951           "bitwidth emulation is not implemented yet on unsigned op");         \
952     }                                                                          \
953     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
954                                          adaptor.getRhs());                    \
955     return success();
956 
957       DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
958       DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
959       DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
960       DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
961       DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
962       DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
963       DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
964       DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
965       DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
966       DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
967 
968 #undef DISPATCH
969     }
970     return failure();
971   }
972 };
973 
974 //===----------------------------------------------------------------------===//
975 // CmpFOpPattern
976 //===----------------------------------------------------------------------===//
977 
978 /// Converts floating-point comparison operations to SPIR-V ops.
979 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
980 public:
981   using OpConversionPattern::OpConversionPattern;
982 
983   LogicalResult
984   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
985                   ConversionPatternRewriter &rewriter) const override {
986     switch (op.getPredicate()) {
987 #define DISPATCH(cmpPredicate, spirvOp)                                        \
988   case cmpPredicate:                                                           \
989     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
990                                          adaptor.getRhs());                    \
991     return success();
992 
993       // Ordered.
994       DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
995       DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
996       DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
997       DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
998       DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
999       DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1000       // Unordered.
1001       DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1002       DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1003       DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1004       DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1005       DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1006       DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1007 
1008 #undef DISPATCH
1009 
1010     default:
1011       break;
1012     }
1013     return failure();
1014   }
1015 };
1016 
1017 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
1018 /// Kernel capability.
1019 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1020 public:
1021   using OpConversionPattern::OpConversionPattern;
1022 
1023   LogicalResult
1024   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1025                   ConversionPatternRewriter &rewriter) const override {
1026     if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1027       rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1028                                                     adaptor.getRhs());
1029       return success();
1030     }
1031 
1032     if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1033       rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1034                                                       adaptor.getRhs());
1035       return success();
1036     }
1037 
1038     return failure();
1039   }
1040 };
1041 
1042 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
1043 /// require additional capability.
1044 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1045 public:
1046   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
1047 
1048   LogicalResult
1049   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1050                   ConversionPatternRewriter &rewriter) const override {
1051     if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1052         op.getPredicate() != arith::CmpFPredicate::UNO)
1053       return failure();
1054 
1055     Location loc = op.getLoc();
1056 
1057     Value replace;
1058     if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1059       if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1060         // Ordered comparsion checks if neither operand is NaN.
1061         replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1062       } else {
1063         // Unordered comparsion checks if either operand is NaN.
1064         replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1065       }
1066     } else {
1067       Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1068       Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1069 
1070       replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1071       if (op.getPredicate() == arith::CmpFPredicate::ORD)
1072         replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1073     }
1074 
1075     rewriter.replaceOp(op, replace);
1076     return success();
1077   }
1078 };
1079 
1080 //===----------------------------------------------------------------------===//
1081 // AddUIExtendedOp
1082 //===----------------------------------------------------------------------===//
1083 
1084 /// Converts arith.addui_extended to spirv.IAddCarry.
1085 class AddUIExtendedOpPattern final
1086     : public OpConversionPattern<arith::AddUIExtendedOp> {
1087 public:
1088   using OpConversionPattern::OpConversionPattern;
1089   LogicalResult
1090   matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1091                   ConversionPatternRewriter &rewriter) const override {
1092     Type dstElemTy = adaptor.getLhs().getType();
1093     Location loc = op->getLoc();
1094     Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1095                                                        adaptor.getRhs());
1096 
1097     Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1098         loc, result, llvm::ArrayRef(0));
1099     Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1100         loc, result, llvm::ArrayRef(1));
1101 
1102     // Convert the carry value to boolean.
1103     Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1104     Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1105 
1106     rewriter.replaceOp(op, {sumResult, carryResult});
1107     return success();
1108   }
1109 };
1110 
1111 //===----------------------------------------------------------------------===//
1112 // MulIExtendedOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1116 template <typename ArithMulOp, typename SPIRVMulOp>
1117 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1118 public:
1119   using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1120   LogicalResult
1121   matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1122                   ConversionPatternRewriter &rewriter) const override {
1123     Location loc = op->getLoc();
1124     Value result =
1125         rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1126 
1127     Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1128                                                            llvm::ArrayRef(0));
1129     Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1130                                                             llvm::ArrayRef(1));
1131 
1132     rewriter.replaceOp(op, {low, high});
1133     return success();
1134   }
1135 };
1136 
1137 //===----------------------------------------------------------------------===//
1138 // SelectOp
1139 //===----------------------------------------------------------------------===//
1140 
1141 /// Converts arith.select to spirv.Select.
1142 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1143 public:
1144   using OpConversionPattern::OpConversionPattern;
1145   LogicalResult
1146   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1147                   ConversionPatternRewriter &rewriter) const override {
1148     rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1149                                                  adaptor.getTrueValue(),
1150                                                  adaptor.getFalseValue());
1151     return success();
1152   }
1153 };
1154 
1155 //===----------------------------------------------------------------------===//
1156 // MinimumFOp, MaximumFOp
1157 //===----------------------------------------------------------------------===//
1158 
1159 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1160 /// spirv.CL.fmax/fmin.
1161 template <typename Op, typename SPIRVOp>
1162 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1163 public:
1164   using OpConversionPattern<Op>::OpConversionPattern;
1165   LogicalResult
1166   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1167                   ConversionPatternRewriter &rewriter) const override {
1168     auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1169     Type dstType = converter->convertType(op.getType());
1170     if (!dstType)
1171       return getTypeConversionFailure(rewriter, op);
1172 
1173     // arith.maximumf/minimumf:
1174     //   "if one of the arguments is NaN, then the result is also NaN."
1175     // spirv.GL.FMax/FMin
1176     //   "which operand is the result is undefined if one of the operands
1177     //   is a NaN."
1178     // spirv.CL.fmax/fmin:
1179     //   "If one argument is a NaN, Fmin returns the other argument."
1180 
1181     Location loc = op.getLoc();
1182     Value spirvOp =
1183         rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1184 
1185     if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1186       rewriter.replaceOp(op, spirvOp);
1187       return success();
1188     }
1189 
1190     Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1191     Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1192 
1193     Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1194                                                      adaptor.getLhs(), spirvOp);
1195     Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1196                                                      adaptor.getRhs(), select1);
1197 
1198     rewriter.replaceOp(op, select2);
1199     return success();
1200   }
1201 };
1202 
1203 //===----------------------------------------------------------------------===//
1204 // MinNumFOp, MaxNumFOp
1205 //===----------------------------------------------------------------------===//
1206 
1207 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1208 /// spirv.CL.fmax/fmin.
1209 template <typename Op, typename SPIRVOp>
1210 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1211   template <typename TargetOp>
1212   constexpr bool shouldInsertNanGuards() const {
1213     return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1214   }
1215 
1216 public:
1217   using OpConversionPattern<Op>::OpConversionPattern;
1218   LogicalResult
1219   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1220                   ConversionPatternRewriter &rewriter) const override {
1221     auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1222     Type dstType = converter->convertType(op.getType());
1223     if (!dstType)
1224       return getTypeConversionFailure(rewriter, op);
1225 
1226     // arith.maxnumf/minnumf:
1227     //   "If one of the arguments is NaN, then the result is the other
1228     //   argument."
1229     // spirv.GL.FMax/FMin
1230     //   "which operand is the result is undefined if one of the operands
1231     //   is a NaN."
1232     // spirv.CL.fmax/fmin:
1233     //   "If one argument is a NaN, Fmin returns the other argument."
1234 
1235     Location loc = op.getLoc();
1236     Value spirvOp =
1237         rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1238 
1239     if (!shouldInsertNanGuards<SPIRVOp>() ||
1240         bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1241       rewriter.replaceOp(op, spirvOp);
1242       return success();
1243     }
1244 
1245     Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1246     Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1247 
1248     Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1249                                                      adaptor.getRhs(), spirvOp);
1250     Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1251                                                      adaptor.getLhs(), select1);
1252 
1253     rewriter.replaceOp(op, select2);
1254     return success();
1255   }
1256 };
1257 
1258 } // namespace
1259 
1260 //===----------------------------------------------------------------------===//
1261 // Pattern Population
1262 //===----------------------------------------------------------------------===//
1263 
/// Adds to `patterns` the conversion patterns that lower arith ops to the
/// SPIR-V dialect. Each pattern uses `typeConverter` for type conversion.
///
/// The bulk of the patterns are added without an explicit benefit.
/// CmpFOpNanKernelPattern is added separately with benefit 2 so that it is
/// preferred over CmpFOpNanNonePattern when Kernel capability is available.
void mlir::arith::populateArithToSPIRVPatterns(
    const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
    ConstantCompositeOpPattern,
    ConstantScalarOpPattern,
    ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
    ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
    ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
    spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
    spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
    spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
    RemSIOpGLPattern, RemSIOpCLPattern,
    BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
    BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
    XOrIOpLogicalPattern, XOrIOpBooleanPattern,
    ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
    spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
    spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
    spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
    spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
    spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
    spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
    // Casts: the *I1 patterns cover the boolean cases that the generic
    // patterns reject.
    ExtUIPattern, ExtUII1Pattern,
    ExtSIPattern, ExtSII1Pattern,
    TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
    TruncIPattern, TruncII1Pattern,
    TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
    TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
    TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
    TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
    TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
    TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
    TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
    // Comparisons.
    CmpIOpBooleanPattern, CmpIOpPattern,
    CmpFOpNanNonePattern, CmpFOpPattern,
    AddUIExtendedOpPattern,
    MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
    MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
    SelectOpPattern,

    // Min/max lowerings to spirv.GL.* ops.
    MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
    MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
    spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
    spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
    spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
    spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,

    // Min/max lowerings to spirv.CL.* ops. NOTE(review): these target the
    // same arith ops as the spirv.GL.* group above. Presumably the conversion
    // target's legality checks decide which lowering survives; confirm.
    MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
    MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
    spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
    spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
    spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
    spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp>
  >(typeConverter, patterns.getContext());
  // clang-format on

  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
  // capability is available.
  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
                                       /*benefit=*/2);
}
1333 
1334 //===----------------------------------------------------------------------===//
1335 // Pass Definition
1336 //===----------------------------------------------------------------------===//
1337 
1338 namespace {
1339 struct ConvertArithToSPIRVPass
1340     : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1341   void runOnOperation() override {
1342     Operation *op = getOperation();
1343     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
1344     std::unique_ptr<SPIRVConversionTarget> target =
1345         SPIRVConversionTarget::get(targetAttr);
1346 
1347     SPIRVConversionOptions options;
1348     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1349     SPIRVTypeConverter typeConverter(targetAttr, options);
1350 
1351     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1352     // in patterns for other dialects.
1353     target->addLegalOp<UnrealizedConversionCastOp>();
1354 
1355     // Fail hard when there are any remaining 'arith' ops.
1356     target->addIllegalDialect<arith::ArithDialect>();
1357 
1358     RewritePatternSet patterns(&getContext());
1359     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1360 
1361     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1362       signalPassFailure();
1363   }
1364 };
1365 } // namespace
1366 
1367 std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
1368   return std::make_unique<ConvertArithToSPIRVPass>();
1369 }
1370