xref: /llvm-project/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (revision b9314a82196a656e2bcc48459123a98ccc02a54d)
1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 #define DEBUG_TYPE "math-to-spirv-pattern"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Utility functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
34 /// given type is not a 32-bit scalar/vector type.
35 static Value getScalarOrVectorI32Constant(Type type, int value,
36                                           OpBuilder &builder, Location loc) {
37   if (auto vectorType = dyn_cast<VectorType>(type)) {
38     if (!vectorType.getElementType().isInteger(32))
39       return nullptr;
40     SmallVector<int> values(vectorType.getNumElements(), value);
41     return builder.create<spirv::ConstantOp>(loc, type,
42                                              builder.getI32VectorAttr(values));
43   }
44   if (type.isInteger(32))
45     return builder.create<spirv::ConstantOp>(loc, type,
46                                              builder.getI32IntegerAttr(value));
47 
48   return nullptr;
49 }
50 
51 /// Check if the type is supported by math-to-spirv conversion. We expect to
52 /// only see scalars and vectors at this point, with higher-level types already
53 /// lowered.
54 static bool isSupportedSourceType(Type originalType) {
55   if (originalType.isIntOrIndexOrFloat())
56     return true;
57 
58   if (auto vecTy = dyn_cast<VectorType>(originalType)) {
59     if (!vecTy.getElementType().isIntOrIndexOrFloat())
60       return false;
61     if (vecTy.isScalable())
62       return false;
63     if (vecTy.getRank() > 1)
64       return false;
65 
66     return true;
67   }
68 
69   return false;
70 }
71 
72 /// Check if all `sourceOp` types are supported by math-to-spirv conversion.
73 /// Notify of a match failure othwerise and return a `failure` result.
74 /// This is intended to simplify type checks in `OpConversionPattern`s.
75 static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
76                                         Operation *sourceOp) {
77   auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
78   llvm::append_range(allTypes, sourceOp->getResultTypes());
79 
80   for (Type ty : allTypes) {
81     if (!isSupportedSourceType(ty)) {
82       return rewriter.notifyMatchFailure(
83           sourceOp,
84           llvm::formatv(
85               "unsupported source type for Math to SPIR-V conversion: {0}",
86               ty));
87     }
88   }
89 
90   return success();
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // Operation conversion
95 //===----------------------------------------------------------------------===//
96 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// only generates plain RewritePatterns.
100 
101 namespace {
102 /// Converts elementwise unary, binary, and ternary standard operations to
103 /// SPIR-V operations. Checks that source `Op` types are supported.
104 template <typename Op, typename SPIRVOp>
105 struct CheckedElementwiseOpPattern final
106     : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
107   using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
108   using BasePattern::BasePattern;
109 
110   LogicalResult
111   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
112                   ConversionPatternRewriter &rewriter) const override {
113     if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
114       return res;
115 
116     return BasePattern::matchAndRewrite(op, adaptor, rewriter);
117   }
118 };
119 
120 /// Converts math.copysign to SPIR-V ops.
121 struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
122   using OpConversionPattern::OpConversionPattern;
123 
124   LogicalResult
125   matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126                   ConversionPatternRewriter &rewriter) const override {
127     if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
128         failed(res))
129       return res;
130 
131     Type type = getTypeConverter()->convertType(copySignOp.getType());
132     if (!type)
133       return failure();
134 
135     FloatType floatType;
136     if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137       floatType = scalarType;
138     } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139       floatType = cast<FloatType>(vectorType.getElementType());
140     } else {
141       return failure();
142     }
143 
144     Location loc = copySignOp.getLoc();
145     int bitwidth = floatType.getWidth();
146     Type intType = rewriter.getIntegerType(bitwidth);
147     uint64_t intValue = uint64_t(1) << (bitwidth - 1);
148 
149     Value signMask = rewriter.create<spirv::ConstantOp>(
150         loc, intType, rewriter.getIntegerAttr(intType, intValue));
151     Value valueMask = rewriter.create<spirv::ConstantOp>(
152         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
153 
154     if (auto vectorType = dyn_cast<VectorType>(type)) {
155       assert(vectorType.getRank() == 1);
156       int count = vectorType.getNumElements();
157       intType = VectorType::get(count, intType);
158 
159       SmallVector<Value> signSplat(count, signMask);
160       signMask =
161           rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
162 
163       SmallVector<Value> valueSplat(count, valueMask);
164       valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
165                                                                valueSplat);
166     }
167 
168     Value lhsCast =
169         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
170     Value rhsCast =
171         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
172 
173     Value value = rewriter.create<spirv::BitwiseAndOp>(
174         loc, intType, ValueRange{lhsCast, valueMask});
175     Value sign = rewriter.create<spirv::BitwiseAndOp>(
176         loc, intType, ValueRange{rhsCast, signMask});
177 
178     Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
179                                                        ValueRange{value, sign});
180     rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
181     return success();
182   }
183 };
184 
185 /// Converts math.ctlz to SPIR-V ops.
186 ///
187 /// SPIR-V does not have a direct operations for counting leading zeros. If
188 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
189 /// it.
190 struct CountLeadingZerosPattern final
191     : public OpConversionPattern<math::CountLeadingZerosOp> {
192   using OpConversionPattern::OpConversionPattern;
193 
194   LogicalResult
195   matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
196                   ConversionPatternRewriter &rewriter) const override {
197     if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
198       return res;
199 
200     Type type = getTypeConverter()->convertType(countOp.getType());
201     if (!type)
202       return failure();
203 
204     // We can only support 32-bit integer types for now.
205     unsigned bitwidth = 0;
206     if (isa<IntegerType>(type))
207       bitwidth = type.getIntOrFloatBitWidth();
208     if (auto vectorType = dyn_cast<VectorType>(type))
209       bitwidth = vectorType.getElementTypeBitWidth();
210     if (bitwidth != 32)
211       return failure();
212 
213     Location loc = countOp.getLoc();
214     Value input = adaptor.getOperand();
215     Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
216     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
217     Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
218 
219     Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
220     // We need to subtract from 31 given that the index returned by GLSL
221     // FindUMsb is counted from the least significant bit. Theoretically this
222     // also gives the correct result even if the integer has all zero bits, in
223     // which case GL FindUMsb would return -1.
224     Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
225     // However, certain Vulkan implementations have driver bugs for the corner
226     // case where the input is zero. And.. it can be smart to optimize a select
227     // only involving the corner case. So separately compute the result when the
228     // input is either zero or one.
229     Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
230     Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
231     rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
232                                                  subMsb);
233     return success();
234   }
235 };
236 
237 /// Converts math.expm1 to SPIR-V ops.
238 ///
239 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
240 /// these operations.
241 template <typename ExpOp>
242 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
243   using OpConversionPattern::OpConversionPattern;
244 
245   LogicalResult
246   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
247                   ConversionPatternRewriter &rewriter) const override {
248     assert(adaptor.getOperands().size() == 1);
249     if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
250         failed(res))
251       return res;
252 
253     Location loc = operation.getLoc();
254     Type type = this->getTypeConverter()->convertType(operation.getType());
255     if (!type)
256       return failure();
257 
258     Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
259     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
260     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
261     return success();
262   }
263 };
264 
265 /// Converts math.log1p to SPIR-V ops.
266 ///
267 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
268 /// these operations.
269 template <typename LogOp>
270 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
271   using OpConversionPattern::OpConversionPattern;
272 
273   LogicalResult
274   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
275                   ConversionPatternRewriter &rewriter) const override {
276     assert(adaptor.getOperands().size() == 1);
277     if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
278         failed(res))
279       return res;
280 
281     Location loc = operation.getLoc();
282     Type type = this->getTypeConverter()->convertType(operation.getType());
283     if (!type)
284       return failure();
285 
286     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
287     Value onePlus =
288         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
289     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
290     return success();
291   }
292 };
293 
294 /// Converts math.log2 and math.log10 to SPIR-V ops.
295 ///
296 /// SPIR-V does not have direct operations for log2 and log10. Explicitly
297 /// lower to these operations using:
298 ///   log2(x) = log(x) * 1/log(2)
299 ///   log10(x) = log(x) * 1/log(10)
300 
301 template <typename MathLogOp, typename SpirvLogOp>
302 struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
303   using OpConversionPattern<MathLogOp>::OpConversionPattern;
304   using typename OpConversionPattern<MathLogOp>::OpAdaptor;
305 
306   static constexpr double log2Reciprocal =
307       1.442695040888963407359924681001892137426645954152985934135449407;
308   static constexpr double log10Reciprocal =
309       0.4342944819032518276511289189166050822943970058036665661144537832;
310 
311   LogicalResult
312   matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
313                   ConversionPatternRewriter &rewriter) const override {
314     assert(adaptor.getOperands().size() == 1);
315     if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
316         failed(res))
317       return res;
318 
319     Location loc = operation.getLoc();
320     Type type = this->getTypeConverter()->convertType(operation.getType());
321     if (!type)
322       return rewriter.notifyMatchFailure(operation, "type conversion failed");
323 
324     auto getConstantValue = [&](double value) {
325       if (auto floatType = dyn_cast<FloatType>(type)) {
326         return rewriter.create<spirv::ConstantOp>(
327             loc, type, rewriter.getFloatAttr(floatType, value));
328       }
329       if (auto vectorType = dyn_cast<VectorType>(type)) {
330         Type elemType = vectorType.getElementType();
331 
332         if (isa<FloatType>(elemType)) {
333           return rewriter.create<spirv::ConstantOp>(
334               loc, type,
335               DenseFPElementsAttr::get(
336                   vectorType, FloatAttr::get(elemType, value).getValue()));
337         }
338       }
339 
340       llvm_unreachable("unimplemented types for log2/log10");
341     };
342 
343     Value constantValue = getConstantValue(
344         std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345                                                 : log10Reciprocal);
346     Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
347     rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
348                                                constantValue);
349     return success();
350   }
351 };
352 
353 /// Converts math.powf to SPIRV-Ops.
354 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
355   using OpConversionPattern::OpConversionPattern;
356 
357   LogicalResult
358   matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
359                   ConversionPatternRewriter &rewriter) const override {
360     if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
361       return res;
362 
363     Type dstType = getTypeConverter()->convertType(powfOp.getType());
364     if (!dstType)
365       return failure();
366 
367     // Get the scalar float type.
368     FloatType scalarFloatType;
369     if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
370       scalarFloatType = scalarType;
371     } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
372       scalarFloatType = cast<FloatType>(vectorType.getElementType());
373     } else {
374       return failure();
375     }
376 
377     // Get int type of the same shape as the float type.
378     Type scalarIntType = rewriter.getIntegerType(32);
379     Type intType = scalarIntType;
380     auto operandType = adaptor.getRhs().getType();
381     if (auto vectorType = dyn_cast<VectorType>(operandType)) {
382       auto shape = vectorType.getShape();
383       intType = VectorType::get(shape, scalarIntType);
384     }
385 
386     // Per GL Pow extended instruction spec:
387     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
388     Location loc = powfOp.getLoc();
389     Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
390     Value lessThan =
391         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
392 
393     // Per C/C++ spec:
394     // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
395     // > finite and negative and exponent is finite and non-integer.
396     // Calculate the reminder from the exponent and check whether it is zero.
397     Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
398     Value expRem =
399         rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
400     Value expRemNonZero =
401         rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
402     Value cmpNegativeWithFractionalExp =
403         rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
404     // Create NaN result and replace base value if conditions are met.
405     const auto &floatSemantics = scalarFloatType.getFloatSemantics();
406     const auto nan = APFloat::getNaN(floatSemantics);
407     Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
408     if (auto vectorType = dyn_cast<VectorType>(operandType))
409       nanAttr = DenseElementsAttr::get(vectorType, nan);
410 
411     Value NanValue =
412         rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
413     Value lhs = rewriter.create<spirv::SelectOp>(
414         loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
415     Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
416 
417     // TODO: The following just forcefully casts y into an integer value in
418     // order to properly propagate the sign, assuming integer y cases. It
419     // doesn't cover other cases and should be fixed.
420 
421     // Cast exponent to integer and calculate exponent % 2 != 0.
422     Value intRhs =
423         rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
424     Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425     Value bitwiseAndOne =
426         rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
427     Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
428 
429     // calculate pow based on abs(lhs)^rhs.
430     Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
431     Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
432     // if the exponent is odd and lhs < 0, negate the result.
433     Value shouldNegate =
434         rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
435     rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
436                                                  pow);
437     return success();
438   }
439 };
440 
441 /// Converts math.round to GLSL SPIRV extended ops.
442 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
443   using OpConversionPattern::OpConversionPattern;
444 
445   LogicalResult
446   matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
447                   ConversionPatternRewriter &rewriter) const override {
448     if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
449       return res;
450 
451     Location loc = roundOp.getLoc();
452     Value operand = roundOp.getOperand();
453     Type ty = operand.getType();
454     Type ety = getElementTypeOrSelf(ty);
455 
456     auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
457     auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
458     Value half;
459     if (VectorType vty = dyn_cast<VectorType>(ty)) {
460       half = rewriter.create<spirv::ConstantOp>(
461           loc, vty,
462           DenseElementsAttr::get(vty,
463                                  rewriter.getFloatAttr(ety, 0.5).getValue()));
464     } else {
465       half = rewriter.create<spirv::ConstantOp>(
466           loc, ty, rewriter.getFloatAttr(ety, 0.5));
467     }
468 
469     auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
470     auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
471     auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
472     auto greater =
473         rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
474     auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
475     auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
476     rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
477     return success();
478   }
479 };
480 
481 } // namespace
482 
483 //===----------------------------------------------------------------------===//
484 // Pattern population
485 //===----------------------------------------------------------------------===//
486 
487 namespace mlir {
488 void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
489                                  RewritePatternSet &patterns) {
490   // Core patterns
491   patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
492 
493   // GLSL patterns
494   patterns
495       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
496            Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
497            Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
498            ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
499            CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
500            CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
501            CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
502            CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
503            CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
504            CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
505            CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
506            CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
507            CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
508            CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
509            CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
510            CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
511            CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
512            CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
513           typeConverter, patterns.getContext());
514 
515   // OpenCL patterns
516   patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
517                Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
518                Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
519                CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
520                CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
521                CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
522                CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
523                CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
524                CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
525                CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
526                CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
527                CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
528                CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
529                CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
530                CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
531                CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
532                CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
533                CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
534                CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
535                CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
536                CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
537       typeConverter, patterns.getContext());
538 }
539 
540 } // namespace mlir
541