xref: /llvm-project/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (revision b9314a82196a656e2bcc48459123a98ccc02a54d)
1995c3984SLei Zhang //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2995c3984SLei Zhang //
3995c3984SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4995c3984SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
5995c3984SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6995c3984SLei Zhang //
7995c3984SLei Zhang //===----------------------------------------------------------------------===//
8995c3984SLei Zhang //
9995c3984SLei Zhang // This file implements patterns to convert Math dialect to SPIR-V dialect.
10995c3984SLei Zhang //
11995c3984SLei Zhang //===----------------------------------------------------------------------===//
12995c3984SLei Zhang 
13a54f4eaeSMogball #include "../SPIRVCommon/Pattern.h"
14995c3984SLei Zhang #include "mlir/Dialect/Math/IR/Math.h"
15995c3984SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16995c3984SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17995c3984SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18533ec929SLei Zhang #include "mlir/IR/BuiltinTypes.h"
19b9e642afSRobert Suderman #include "mlir/IR/TypeUtilities.h"
20cc020a22SLei Zhang #include "mlir/Transforms/DialectConversion.h"
217f7e33c2SJakub Kuderski #include "llvm/ADT/STLExtras.h"
22995c3984SLei Zhang #include "llvm/Support/Debug.h"
237f7e33c2SJakub Kuderski #include "llvm/Support/FormatVariadic.h"
24995c3984SLei Zhang 
25995c3984SLei Zhang #define DEBUG_TYPE "math-to-spirv-pattern"
26995c3984SLei Zhang 
27995c3984SLei Zhang using namespace mlir;
28995c3984SLei Zhang 
29995c3984SLei Zhang //===----------------------------------------------------------------------===//
30cc020a22SLei Zhang // Utility functions
31cc020a22SLei Zhang //===----------------------------------------------------------------------===//
32cc020a22SLei Zhang 
33cc020a22SLei Zhang /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
34cc020a22SLei Zhang /// given type is not a 32-bit scalar/vector type.
35cc020a22SLei Zhang static Value getScalarOrVectorI32Constant(Type type, int value,
36cc020a22SLei Zhang                                           OpBuilder &builder, Location loc) {
375550c821STres Popp   if (auto vectorType = dyn_cast<VectorType>(type)) {
38cc020a22SLei Zhang     if (!vectorType.getElementType().isInteger(32))
39cc020a22SLei Zhang       return nullptr;
40cc020a22SLei Zhang     SmallVector<int> values(vectorType.getNumElements(), value);
41cc020a22SLei Zhang     return builder.create<spirv::ConstantOp>(loc, type,
42cc020a22SLei Zhang                                              builder.getI32VectorAttr(values));
43cc020a22SLei Zhang   }
44cc020a22SLei Zhang   if (type.isInteger(32))
45cc020a22SLei Zhang     return builder.create<spirv::ConstantOp>(loc, type,
46cc020a22SLei Zhang                                              builder.getI32IntegerAttr(value));
47cc020a22SLei Zhang 
48cc020a22SLei Zhang   return nullptr;
49cc020a22SLei Zhang }
50cc020a22SLei Zhang 
517f7e33c2SJakub Kuderski /// Check if the type is supported by math-to-spirv conversion. We expect to
527f7e33c2SJakub Kuderski /// only see scalars and vectors at this point, with higher-level types already
537f7e33c2SJakub Kuderski /// lowered.
547f7e33c2SJakub Kuderski static bool isSupportedSourceType(Type originalType) {
557f7e33c2SJakub Kuderski   if (originalType.isIntOrIndexOrFloat())
567f7e33c2SJakub Kuderski     return true;
577f7e33c2SJakub Kuderski 
585550c821STres Popp   if (auto vecTy = dyn_cast<VectorType>(originalType)) {
597f7e33c2SJakub Kuderski     if (!vecTy.getElementType().isIntOrIndexOrFloat())
607f7e33c2SJakub Kuderski       return false;
617f7e33c2SJakub Kuderski     if (vecTy.isScalable())
627f7e33c2SJakub Kuderski       return false;
637f7e33c2SJakub Kuderski     if (vecTy.getRank() > 1)
647f7e33c2SJakub Kuderski       return false;
657f7e33c2SJakub Kuderski 
667f7e33c2SJakub Kuderski     return true;
677f7e33c2SJakub Kuderski   }
687f7e33c2SJakub Kuderski 
697f7e33c2SJakub Kuderski   return false;
707f7e33c2SJakub Kuderski }
717f7e33c2SJakub Kuderski 
727f7e33c2SJakub Kuderski /// Check if all `sourceOp` types are supported by math-to-spirv conversion.
737f7e33c2SJakub Kuderski /// Notify of a match failure othwerise and return a `failure` result.
747f7e33c2SJakub Kuderski /// This is intended to simplify type checks in `OpConversionPattern`s.
757f7e33c2SJakub Kuderski static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
767f7e33c2SJakub Kuderski                                         Operation *sourceOp) {
777f7e33c2SJakub Kuderski   auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
787f7e33c2SJakub Kuderski   llvm::append_range(allTypes, sourceOp->getResultTypes());
797f7e33c2SJakub Kuderski 
807f7e33c2SJakub Kuderski   for (Type ty : allTypes) {
817f7e33c2SJakub Kuderski     if (!isSupportedSourceType(ty)) {
827f7e33c2SJakub Kuderski       return rewriter.notifyMatchFailure(
837f7e33c2SJakub Kuderski           sourceOp,
847f7e33c2SJakub Kuderski           llvm::formatv(
857f7e33c2SJakub Kuderski               "unsupported source type for Math to SPIR-V conversion: {0}",
867f7e33c2SJakub Kuderski               ty));
877f7e33c2SJakub Kuderski     }
887f7e33c2SJakub Kuderski   }
897f7e33c2SJakub Kuderski 
907f7e33c2SJakub Kuderski   return success();
917f7e33c2SJakub Kuderski }
927f7e33c2SJakub Kuderski 
93cc020a22SLei Zhang //===----------------------------------------------------------------------===//
94995c3984SLei Zhang // Operation conversion
95995c3984SLei Zhang //===----------------------------------------------------------------------===//
96995c3984SLei Zhang 
97995c3984SLei Zhang // Note that DRR cannot be used for the patterns in this file: we may need to
98995c3984SLei Zhang // convert type along the way, which requires ConversionPattern. DRR generates
99995c3984SLei Zhang // normal RewritePattern.
100995c3984SLei Zhang 
101995c3984SLei Zhang namespace {
1027f7e33c2SJakub Kuderski /// Converts elementwise unary, binary, and ternary standard operations to
1037f7e33c2SJakub Kuderski /// SPIR-V operations. Checks that source `Op` types are supported.
1047f7e33c2SJakub Kuderski template <typename Op, typename SPIRVOp>
1057f7e33c2SJakub Kuderski struct CheckedElementwiseOpPattern final
1067f7e33c2SJakub Kuderski     : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
1077f7e33c2SJakub Kuderski   using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
1087f7e33c2SJakub Kuderski   using BasePattern::BasePattern;
1097f7e33c2SJakub Kuderski 
1107f7e33c2SJakub Kuderski   LogicalResult
1117f7e33c2SJakub Kuderski   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1127f7e33c2SJakub Kuderski                   ConversionPatternRewriter &rewriter) const override {
1137f7e33c2SJakub Kuderski     if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
1147f7e33c2SJakub Kuderski       return res;
1157f7e33c2SJakub Kuderski 
1167f7e33c2SJakub Kuderski     return BasePattern::matchAndRewrite(op, adaptor, rewriter);
1177f7e33c2SJakub Kuderski   }
1187f7e33c2SJakub Kuderski };
1197f7e33c2SJakub Kuderski 
120533ec929SLei Zhang /// Converts math.copysign to SPIR-V ops.
1217f7e33c2SJakub Kuderski struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
122533ec929SLei Zhang   using OpConversionPattern::OpConversionPattern;
123533ec929SLei Zhang 
124533ec929SLei Zhang   LogicalResult
125533ec929SLei Zhang   matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126533ec929SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1277f7e33c2SJakub Kuderski     if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
1287f7e33c2SJakub Kuderski         failed(res))
1297f7e33c2SJakub Kuderski       return res;
1307f7e33c2SJakub Kuderski 
1317f7e33c2SJakub Kuderski     Type type = getTypeConverter()->convertType(copySignOp.getType());
132533ec929SLei Zhang     if (!type)
133533ec929SLei Zhang       return failure();
134533ec929SLei Zhang 
135533ec929SLei Zhang     FloatType floatType;
1365550c821STres Popp     if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137533ec929SLei Zhang       floatType = scalarType;
1385550c821STres Popp     } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
1395550c821STres Popp       floatType = cast<FloatType>(vectorType.getElementType());
140533ec929SLei Zhang     } else {
141533ec929SLei Zhang       return failure();
142533ec929SLei Zhang     }
143533ec929SLei Zhang 
144533ec929SLei Zhang     Location loc = copySignOp.getLoc();
145533ec929SLei Zhang     int bitwidth = floatType.getWidth();
146533ec929SLei Zhang     Type intType = rewriter.getIntegerType(bitwidth);
1475f14aee3SStella Stamenova     uint64_t intValue = uint64_t(1) << (bitwidth - 1);
148533ec929SLei Zhang 
149533ec929SLei Zhang     Value signMask = rewriter.create<spirv::ConstantOp>(
1505f14aee3SStella Stamenova         loc, intType, rewriter.getIntegerAttr(intType, intValue));
151533ec929SLei Zhang     Value valueMask = rewriter.create<spirv::ConstantOp>(
1525f14aee3SStella Stamenova         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
153533ec929SLei Zhang 
1545550c821STres Popp     if (auto vectorType = dyn_cast<VectorType>(type)) {
155533ec929SLei Zhang       assert(vectorType.getRank() == 1);
156533ec929SLei Zhang       int count = vectorType.getNumElements();
157533ec929SLei Zhang       intType = VectorType::get(count, intType);
158533ec929SLei Zhang 
159533ec929SLei Zhang       SmallVector<Value> signSplat(count, signMask);
160533ec929SLei Zhang       signMask =
161533ec929SLei Zhang           rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
162533ec929SLei Zhang 
163533ec929SLei Zhang       SmallVector<Value> valueSplat(count, valueMask);
164533ec929SLei Zhang       valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
165533ec929SLei Zhang                                                                valueSplat);
166533ec929SLei Zhang     }
167533ec929SLei Zhang 
168533ec929SLei Zhang     Value lhsCast =
169533ec929SLei Zhang         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
170533ec929SLei Zhang     Value rhsCast =
171533ec929SLei Zhang         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
172533ec929SLei Zhang 
173533ec929SLei Zhang     Value value = rewriter.create<spirv::BitwiseAndOp>(
174533ec929SLei Zhang         loc, intType, ValueRange{lhsCast, valueMask});
175533ec929SLei Zhang     Value sign = rewriter.create<spirv::BitwiseAndOp>(
176533ec929SLei Zhang         loc, intType, ValueRange{rhsCast, signMask});
177533ec929SLei Zhang 
178533ec929SLei Zhang     Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
179533ec929SLei Zhang                                                        ValueRange{value, sign});
180533ec929SLei Zhang     rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
181533ec929SLei Zhang     return success();
182533ec929SLei Zhang   }
183533ec929SLei Zhang };
184533ec929SLei Zhang 
185cc020a22SLei Zhang /// Converts math.ctlz to SPIR-V ops.
186cc020a22SLei Zhang ///
187cc020a22SLei Zhang /// SPIR-V does not have a direct operations for counting leading zeros. If
18852b630daSJakub Kuderski /// Shader capability is supported, we can leverage GL FindUMsb to calculate
189cc020a22SLei Zhang /// it.
1907f7e33c2SJakub Kuderski struct CountLeadingZerosPattern final
191cc020a22SLei Zhang     : public OpConversionPattern<math::CountLeadingZerosOp> {
192cc020a22SLei Zhang   using OpConversionPattern::OpConversionPattern;
193cc020a22SLei Zhang 
194cc020a22SLei Zhang   LogicalResult
195cc020a22SLei Zhang   matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
196cc020a22SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1977f7e33c2SJakub Kuderski     if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
1987f7e33c2SJakub Kuderski       return res;
1997f7e33c2SJakub Kuderski 
2007f7e33c2SJakub Kuderski     Type type = getTypeConverter()->convertType(countOp.getType());
201cc020a22SLei Zhang     if (!type)
202cc020a22SLei Zhang       return failure();
203cc020a22SLei Zhang 
204cc020a22SLei Zhang     // We can only support 32-bit integer types for now.
205cc020a22SLei Zhang     unsigned bitwidth = 0;
2065550c821STres Popp     if (isa<IntegerType>(type))
207cc020a22SLei Zhang       bitwidth = type.getIntOrFloatBitWidth();
2085550c821STres Popp     if (auto vectorType = dyn_cast<VectorType>(type))
209cc020a22SLei Zhang       bitwidth = vectorType.getElementTypeBitWidth();
210cc020a22SLei Zhang     if (bitwidth != 32)
211cc020a22SLei Zhang       return failure();
212cc020a22SLei Zhang 
213cc020a22SLei Zhang     Location loc = countOp.getLoc();
2142320a4aeSLei Zhang     Value input = adaptor.getOperand();
2152320a4aeSLei Zhang     Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
216cc020a22SLei Zhang     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
2172320a4aeSLei Zhang     Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
2182320a4aeSLei Zhang 
21952b630daSJakub Kuderski     Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
2202320a4aeSLei Zhang     // We need to subtract from 31 given that the index returned by GLSL
2212320a4aeSLei Zhang     // FindUMsb is counted from the least significant bit. Theoretically this
2222320a4aeSLei Zhang     // also gives the correct result even if the integer has all zero bits, in
22352b630daSJakub Kuderski     // which case GL FindUMsb would return -1.
2242320a4aeSLei Zhang     Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
2252320a4aeSLei Zhang     // However, certain Vulkan implementations have driver bugs for the corner
2262320a4aeSLei Zhang     // case where the input is zero. And.. it can be smart to optimize a select
2272320a4aeSLei Zhang     // only involving the corner case. So separately compute the result when the
2282320a4aeSLei Zhang     // input is either zero or one.
2292320a4aeSLei Zhang     Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
2302320a4aeSLei Zhang     Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
2312320a4aeSLei Zhang     rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
2322320a4aeSLei Zhang                                                  subMsb);
233cc020a22SLei Zhang     return success();
234cc020a22SLei Zhang   }
235cc020a22SLei Zhang };
236cc020a22SLei Zhang 
2373e746c6dSRob Suderman /// Converts math.expm1 to SPIR-V ops.
2383e746c6dSRob Suderman ///
2393e746c6dSRob Suderman /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
2403e746c6dSRob Suderman /// these operations.
2413e746c6dSRob Suderman template <typename ExpOp>
242533ec929SLei Zhang struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
243533ec929SLei Zhang   using OpConversionPattern::OpConversionPattern;
2443e746c6dSRob Suderman 
2453e746c6dSRob Suderman   LogicalResult
2463e746c6dSRob Suderman   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
2473e746c6dSRob Suderman                   ConversionPatternRewriter &rewriter) const override {
2483e746c6dSRob Suderman     assert(adaptor.getOperands().size() == 1);
2497f7e33c2SJakub Kuderski     if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
2507f7e33c2SJakub Kuderski         failed(res))
2517f7e33c2SJakub Kuderski       return res;
2527f7e33c2SJakub Kuderski 
2533e746c6dSRob Suderman     Location loc = operation.getLoc();
2547f7e33c2SJakub Kuderski     Type type = this->getTypeConverter()->convertType(operation.getType());
2557f7e33c2SJakub Kuderski     if (!type)
2567f7e33c2SJakub Kuderski       return failure();
2577f7e33c2SJakub Kuderski 
2587f7e33c2SJakub Kuderski     Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
2593e746c6dSRob Suderman     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
2603e746c6dSRob Suderman     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
2613e746c6dSRob Suderman     return success();
2623e746c6dSRob Suderman   }
2633e746c6dSRob Suderman };
2643e746c6dSRob Suderman 
265995c3984SLei Zhang /// Converts math.log1p to SPIR-V ops.
266995c3984SLei Zhang ///
267995c3984SLei Zhang /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
268995c3984SLei Zhang /// these operations.
26975a1bee0SButygin template <typename LogOp>
270533ec929SLei Zhang struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
271533ec929SLei Zhang   using OpConversionPattern::OpConversionPattern;
272995c3984SLei Zhang 
273995c3984SLei Zhang   LogicalResult
274b54c724bSRiver Riddle   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
275995c3984SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
276b54c724bSRiver Riddle     assert(adaptor.getOperands().size() == 1);
2777f7e33c2SJakub Kuderski     if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
2787f7e33c2SJakub Kuderski         failed(res))
2797f7e33c2SJakub Kuderski       return res;
2807f7e33c2SJakub Kuderski 
281995c3984SLei Zhang     Location loc = operation.getLoc();
2827f7e33c2SJakub Kuderski     Type type = this->getTypeConverter()->convertType(operation.getType());
2837f7e33c2SJakub Kuderski     if (!type)
2847f7e33c2SJakub Kuderski       return failure();
2857f7e33c2SJakub Kuderski 
286995c3984SLei Zhang     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
2877f7e33c2SJakub Kuderski     Value onePlus =
2883e746c6dSRob Suderman         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
28975a1bee0SButygin     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
290995c3984SLei Zhang     return success();
291995c3984SLei Zhang   }
292995c3984SLei Zhang };
29306c6758aSLei Zhang 
2948d165136Smeehatpa /// Converts math.log2 and math.log10 to SPIR-V ops.
2958d165136Smeehatpa ///
2968d165136Smeehatpa /// SPIR-V does not have direct operations for log2 and log10. Explicitly
2978d165136Smeehatpa /// lower to these operations using:
2988d165136Smeehatpa ///   log2(x) = log(x) * 1/log(2)
2998d165136Smeehatpa ///   log10(x) = log(x) * 1/log(10)
3008d165136Smeehatpa 
3018d165136Smeehatpa template <typename MathLogOp, typename SpirvLogOp>
3028d165136Smeehatpa struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
3038d165136Smeehatpa   using OpConversionPattern<MathLogOp>::OpConversionPattern;
3048d165136Smeehatpa   using typename OpConversionPattern<MathLogOp>::OpAdaptor;
3058d165136Smeehatpa 
3068d165136Smeehatpa   static constexpr double log2Reciprocal =
3078d165136Smeehatpa       1.442695040888963407359924681001892137426645954152985934135449407;
3088d165136Smeehatpa   static constexpr double log10Reciprocal =
3098d165136Smeehatpa       0.4342944819032518276511289189166050822943970058036665661144537832;
3108d165136Smeehatpa 
3118d165136Smeehatpa   LogicalResult
3128d165136Smeehatpa   matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
3138d165136Smeehatpa                   ConversionPatternRewriter &rewriter) const override {
3148d165136Smeehatpa     assert(adaptor.getOperands().size() == 1);
3158d165136Smeehatpa     if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
3168d165136Smeehatpa         failed(res))
3178d165136Smeehatpa       return res;
3188d165136Smeehatpa 
3198d165136Smeehatpa     Location loc = operation.getLoc();
3208d165136Smeehatpa     Type type = this->getTypeConverter()->convertType(operation.getType());
3218d165136Smeehatpa     if (!type)
3228d165136Smeehatpa       return rewriter.notifyMatchFailure(operation, "type conversion failed");
3238d165136Smeehatpa 
3248d165136Smeehatpa     auto getConstantValue = [&](double value) {
3258d165136Smeehatpa       if (auto floatType = dyn_cast<FloatType>(type)) {
3268d165136Smeehatpa         return rewriter.create<spirv::ConstantOp>(
3278d165136Smeehatpa             loc, type, rewriter.getFloatAttr(floatType, value));
3288d165136Smeehatpa       }
3298d165136Smeehatpa       if (auto vectorType = dyn_cast<VectorType>(type)) {
3308d165136Smeehatpa         Type elemType = vectorType.getElementType();
3318d165136Smeehatpa 
3328d165136Smeehatpa         if (isa<FloatType>(elemType)) {
3338d165136Smeehatpa           return rewriter.create<spirv::ConstantOp>(
3348d165136Smeehatpa               loc, type,
3358d165136Smeehatpa               DenseFPElementsAttr::get(
3368d165136Smeehatpa                   vectorType, FloatAttr::get(elemType, value).getValue()));
3378d165136Smeehatpa         }
3388d165136Smeehatpa       }
3398d165136Smeehatpa 
3408d165136Smeehatpa       llvm_unreachable("unimplemented types for log2/log10");
3418d165136Smeehatpa     };
3428d165136Smeehatpa 
3438d165136Smeehatpa     Value constantValue = getConstantValue(
3448d165136Smeehatpa         std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
3458d165136Smeehatpa                                                 : log10Reciprocal);
3468d165136Smeehatpa     Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
3478d165136Smeehatpa     rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
3488d165136Smeehatpa                                                constantValue);
3498d165136Smeehatpa     return success();
3508d165136Smeehatpa   }
3518d165136Smeehatpa };
3528d165136Smeehatpa 
35306c6758aSLei Zhang /// Converts math.powf to SPIRV-Ops.
35406c6758aSLei Zhang struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
35506c6758aSLei Zhang   using OpConversionPattern::OpConversionPattern;
35606c6758aSLei Zhang 
35706c6758aSLei Zhang   LogicalResult
35806c6758aSLei Zhang   matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
35906c6758aSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
3607f7e33c2SJakub Kuderski     if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
3617f7e33c2SJakub Kuderski       return res;
3627f7e33c2SJakub Kuderski 
3637f7e33c2SJakub Kuderski     Type dstType = getTypeConverter()->convertType(powfOp.getType());
36406c6758aSLei Zhang     if (!dstType)
36506c6758aSLei Zhang       return failure();
36606c6758aSLei Zhang 
36758839f2eSDaniel Garvey     // Get the scalar float type.
36858839f2eSDaniel Garvey     FloatType scalarFloatType;
3695550c821STres Popp     if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
37058839f2eSDaniel Garvey       scalarFloatType = scalarType;
3715550c821STres Popp     } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
3725550c821STres Popp       scalarFloatType = cast<FloatType>(vectorType.getElementType());
37358839f2eSDaniel Garvey     } else {
37458839f2eSDaniel Garvey       return failure();
37558839f2eSDaniel Garvey     }
37658839f2eSDaniel Garvey 
37758839f2eSDaniel Garvey     // Get int type of the same shape as the float type.
37858839f2eSDaniel Garvey     Type scalarIntType = rewriter.getIntegerType(32);
37958839f2eSDaniel Garvey     Type intType = scalarIntType;
380*b9314a82SDmitriy Smirnov     auto operandType = adaptor.getRhs().getType();
381*b9314a82SDmitriy Smirnov     if (auto vectorType = dyn_cast<VectorType>(operandType)) {
38258839f2eSDaniel Garvey       auto shape = vectorType.getShape();
38358839f2eSDaniel Garvey       intType = VectorType::get(shape, scalarIntType);
38458839f2eSDaniel Garvey     }
38558839f2eSDaniel Garvey 
38652b630daSJakub Kuderski     // Per GL Pow extended instruction spec:
38706c6758aSLei Zhang     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
38806c6758aSLei Zhang     Location loc = powfOp.getLoc();
389*b9314a82SDmitriy Smirnov     Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
39006c6758aSLei Zhang     Value lessThan =
39106c6758aSLei Zhang         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
392*b9314a82SDmitriy Smirnov 
393*b9314a82SDmitriy Smirnov     // Per C/C++ spec:
394*b9314a82SDmitriy Smirnov     // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
395*b9314a82SDmitriy Smirnov     // > finite and negative and exponent is finite and non-integer.
396*b9314a82SDmitriy Smirnov     // Calculate the reminder from the exponent and check whether it is zero.
397*b9314a82SDmitriy Smirnov     Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
398*b9314a82SDmitriy Smirnov     Value expRem =
399*b9314a82SDmitriy Smirnov         rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
400*b9314a82SDmitriy Smirnov     Value expRemNonZero =
401*b9314a82SDmitriy Smirnov         rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
402*b9314a82SDmitriy Smirnov     Value cmpNegativeWithFractionalExp =
403*b9314a82SDmitriy Smirnov         rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
404*b9314a82SDmitriy Smirnov     // Create NaN result and replace base value if conditions are met.
405*b9314a82SDmitriy Smirnov     const auto &floatSemantics = scalarFloatType.getFloatSemantics();
406*b9314a82SDmitriy Smirnov     const auto nan = APFloat::getNaN(floatSemantics);
407*b9314a82SDmitriy Smirnov     Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
408*b9314a82SDmitriy Smirnov     if (auto vectorType = dyn_cast<VectorType>(operandType))
409*b9314a82SDmitriy Smirnov       nanAttr = DenseElementsAttr::get(vectorType, nan);
410*b9314a82SDmitriy Smirnov 
411*b9314a82SDmitriy Smirnov     Value NanValue =
412*b9314a82SDmitriy Smirnov         rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
413*b9314a82SDmitriy Smirnov     Value lhs = rewriter.create<spirv::SelectOp>(
414*b9314a82SDmitriy Smirnov         loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
415*b9314a82SDmitriy Smirnov     Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
41658839f2eSDaniel Garvey 
41758839f2eSDaniel Garvey     // TODO: The following just forcefully casts y into an integer value in
41858839f2eSDaniel Garvey     // order to properly propagate the sign, assuming integer y cases. It
41958839f2eSDaniel Garvey     // doesn't cover other cases and should be fixed.
42058839f2eSDaniel Garvey 
42158839f2eSDaniel Garvey     // Cast exponent to integer and calculate exponent % 2 != 0.
42258839f2eSDaniel Garvey     Value intRhs =
42358839f2eSDaniel Garvey         rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
42458839f2eSDaniel Garvey     Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
42558839f2eSDaniel Garvey     Value bitwiseAndOne =
42658839f2eSDaniel Garvey         rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
42758839f2eSDaniel Garvey     Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
42858839f2eSDaniel Garvey 
42958839f2eSDaniel Garvey     // calculate pow based on abs(lhs)^rhs.
43052b630daSJakub Kuderski     Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
43106c6758aSLei Zhang     Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
43258839f2eSDaniel Garvey     // if the exponent is odd and lhs < 0, negate the result.
43358839f2eSDaniel Garvey     Value shouldNegate =
43458839f2eSDaniel Garvey         rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
43558839f2eSDaniel Garvey     rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
43658839f2eSDaniel Garvey                                                  pow);
43706c6758aSLei Zhang     return success();
43806c6758aSLei Zhang   }
43906c6758aSLei Zhang };
44006c6758aSLei Zhang 
441b9e642afSRobert Suderman /// Converts math.round to GLSL SPIRV extended ops.
442b9e642afSRobert Suderman struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
443b9e642afSRobert Suderman   using OpConversionPattern::OpConversionPattern;
444b9e642afSRobert Suderman 
445b9e642afSRobert Suderman   LogicalResult
446b9e642afSRobert Suderman   matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
447b9e642afSRobert Suderman                   ConversionPatternRewriter &rewriter) const override {
4487f7e33c2SJakub Kuderski     if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
4497f7e33c2SJakub Kuderski       return res;
4507f7e33c2SJakub Kuderski 
451b9e642afSRobert Suderman     Location loc = roundOp.getLoc();
4527f7e33c2SJakub Kuderski     Value operand = roundOp.getOperand();
4537f7e33c2SJakub Kuderski     Type ty = operand.getType();
4547f7e33c2SJakub Kuderski     Type ety = getElementTypeOrSelf(ty);
455b9e642afSRobert Suderman 
456b9e642afSRobert Suderman     auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
457b9e642afSRobert Suderman     auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
458b9e642afSRobert Suderman     Value half;
4595550c821STres Popp     if (VectorType vty = dyn_cast<VectorType>(ty)) {
460b9e642afSRobert Suderman       half = rewriter.create<spirv::ConstantOp>(
461b9e642afSRobert Suderman           loc, vty,
462b9e642afSRobert Suderman           DenseElementsAttr::get(vty,
463b9e642afSRobert Suderman                                  rewriter.getFloatAttr(ety, 0.5).getValue()));
464b9e642afSRobert Suderman     } else {
465b9e642afSRobert Suderman       half = rewriter.create<spirv::ConstantOp>(
466b9e642afSRobert Suderman           loc, ty, rewriter.getFloatAttr(ety, 0.5));
467b9e642afSRobert Suderman     }
468b9e642afSRobert Suderman 
46952b630daSJakub Kuderski     auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
47052b630daSJakub Kuderski     auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
471b9e642afSRobert Suderman     auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
472b9e642afSRobert Suderman     auto greater =
473b9e642afSRobert Suderman         rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
474b9e642afSRobert Suderman     auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
475b9e642afSRobert Suderman     auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
476b9e642afSRobert Suderman     rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
477b9e642afSRobert Suderman     return success();
478b9e642afSRobert Suderman   }
479b9e642afSRobert Suderman };
480b9e642afSRobert Suderman 
481995c3984SLei Zhang } // namespace
482995c3984SLei Zhang 
483995c3984SLei Zhang //===----------------------------------------------------------------------===//
484995c3984SLei Zhang // Pattern population
485995c3984SLei Zhang //===----------------------------------------------------------------------===//
486995c3984SLei Zhang 
487995c3984SLei Zhang namespace mlir {
488206fad0eSMatthias Springer void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
489995c3984SLei Zhang                                  RewritePatternSet &patterns) {
490533ec929SLei Zhang   // Core patterns
491533ec929SLei Zhang   patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
49275a1bee0SButygin 
49375a1bee0SButygin   // GLSL patterns
494d9edc1a5SThomas Raoux   patterns
49552b630daSJakub Kuderski       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
4968d165136Smeehatpa            Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
4978d165136Smeehatpa            Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
49852b630daSJakub Kuderski            ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
4997f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
5007f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
50149777d7fSmeehatpa            CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
5027f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
5037f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
5047f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
5057f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
5067f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
5077f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
5087f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
5097f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
5107f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
5117f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
5127f7e33c2SJakub Kuderski            CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
513995c3984SLei Zhang           typeConverter, patterns.getContext());
51475a1bee0SButygin 
51575a1bee0SButygin   // OpenCL patterns
5163930cc68SJakub Kuderski   patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
5178d165136Smeehatpa                Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
5188d165136Smeehatpa                Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
5197f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
520266b5bc1SNishant Patel                CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
52149777d7fSmeehatpa                CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
52249777d7fSmeehatpa                CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
5237f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
5247f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
5257f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
5267f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
5277f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
5287f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
5297f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
5307f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
5317f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
5327f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
5337f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
5347f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
5357f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
5367f7e33c2SJakub Kuderski                CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
53775a1bee0SButygin       typeConverter, patterns.getContext());
538995c3984SLei Zhang }
539995c3984SLei Zhang 
540995c3984SLei Zhang } // namespace mlir
541