1995c3984SLei Zhang //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2995c3984SLei Zhang // 3995c3984SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4995c3984SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 5995c3984SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6995c3984SLei Zhang // 7995c3984SLei Zhang //===----------------------------------------------------------------------===// 8995c3984SLei Zhang // 9995c3984SLei Zhang // This file implements patterns to convert Math dialect to SPIR-V dialect. 10995c3984SLei Zhang // 11995c3984SLei Zhang //===----------------------------------------------------------------------===// 12995c3984SLei Zhang 13a54f4eaeSMogball #include "../SPIRVCommon/Pattern.h" 14995c3984SLei Zhang #include "mlir/Dialect/Math/IR/Math.h" 15995c3984SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16995c3984SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17995c3984SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18533ec929SLei Zhang #include "mlir/IR/BuiltinTypes.h" 19b9e642afSRobert Suderman #include "mlir/IR/TypeUtilities.h" 20cc020a22SLei Zhang #include "mlir/Transforms/DialectConversion.h" 217f7e33c2SJakub Kuderski #include "llvm/ADT/STLExtras.h" 22995c3984SLei Zhang #include "llvm/Support/Debug.h" 237f7e33c2SJakub Kuderski #include "llvm/Support/FormatVariadic.h" 24995c3984SLei Zhang 25995c3984SLei Zhang #define DEBUG_TYPE "math-to-spirv-pattern" 26995c3984SLei Zhang 27995c3984SLei Zhang using namespace mlir; 28995c3984SLei Zhang 29995c3984SLei Zhang //===----------------------------------------------------------------------===// 30cc020a22SLei Zhang // Utility functions 31cc020a22SLei Zhang //===----------------------------------------------------------------------===// 32cc020a22SLei Zhang 33cc020a22SLei Zhang /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the 34cc020a22SLei Zhang /// given type is not a 32-bit scalar/vector type. 35cc020a22SLei Zhang static Value getScalarOrVectorI32Constant(Type type, int value, 36cc020a22SLei Zhang OpBuilder &builder, Location loc) { 375550c821STres Popp if (auto vectorType = dyn_cast<VectorType>(type)) { 38cc020a22SLei Zhang if (!vectorType.getElementType().isInteger(32)) 39cc020a22SLei Zhang return nullptr; 40cc020a22SLei Zhang SmallVector<int> values(vectorType.getNumElements(), value); 41cc020a22SLei Zhang return builder.create<spirv::ConstantOp>(loc, type, 42cc020a22SLei Zhang builder.getI32VectorAttr(values)); 43cc020a22SLei Zhang } 44cc020a22SLei Zhang if (type.isInteger(32)) 45cc020a22SLei Zhang return builder.create<spirv::ConstantOp>(loc, type, 46cc020a22SLei Zhang builder.getI32IntegerAttr(value)); 47cc020a22SLei Zhang 48cc020a22SLei Zhang return nullptr; 49cc020a22SLei Zhang } 50cc020a22SLei Zhang 517f7e33c2SJakub Kuderski /// Check if the type is supported by math-to-spirv conversion. We expect to 527f7e33c2SJakub Kuderski /// only see scalars and vectors at this point, with higher-level types already 537f7e33c2SJakub Kuderski /// lowered. 547f7e33c2SJakub Kuderski static bool isSupportedSourceType(Type originalType) { 557f7e33c2SJakub Kuderski if (originalType.isIntOrIndexOrFloat()) 567f7e33c2SJakub Kuderski return true; 577f7e33c2SJakub Kuderski 585550c821STres Popp if (auto vecTy = dyn_cast<VectorType>(originalType)) { 597f7e33c2SJakub Kuderski if (!vecTy.getElementType().isIntOrIndexOrFloat()) 607f7e33c2SJakub Kuderski return false; 617f7e33c2SJakub Kuderski if (vecTy.isScalable()) 627f7e33c2SJakub Kuderski return false; 637f7e33c2SJakub Kuderski if (vecTy.getRank() > 1) 647f7e33c2SJakub Kuderski return false; 657f7e33c2SJakub Kuderski 667f7e33c2SJakub Kuderski return true; 677f7e33c2SJakub Kuderski } 687f7e33c2SJakub Kuderski 697f7e33c2SJakub Kuderski return false; 707f7e33c2SJakub Kuderski } 717f7e33c2SJakub Kuderski 727f7e33c2SJakub Kuderski /// Check if all `sourceOp` types are supported by math-to-spirv conversion. 737f7e33c2SJakub Kuderski /// Notify of a match failure othwerise and return a `failure` result. 747f7e33c2SJakub Kuderski /// This is intended to simplify type checks in `OpConversionPattern`s. 757f7e33c2SJakub Kuderski static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, 767f7e33c2SJakub Kuderski Operation *sourceOp) { 777f7e33c2SJakub Kuderski auto allTypes = llvm::to_vector(sourceOp->getOperandTypes()); 787f7e33c2SJakub Kuderski llvm::append_range(allTypes, sourceOp->getResultTypes()); 797f7e33c2SJakub Kuderski 807f7e33c2SJakub Kuderski for (Type ty : allTypes) { 817f7e33c2SJakub Kuderski if (!isSupportedSourceType(ty)) { 827f7e33c2SJakub Kuderski return rewriter.notifyMatchFailure( 837f7e33c2SJakub Kuderski sourceOp, 847f7e33c2SJakub Kuderski llvm::formatv( 857f7e33c2SJakub Kuderski "unsupported source type for Math to SPIR-V conversion: {0}", 867f7e33c2SJakub Kuderski ty)); 877f7e33c2SJakub Kuderski } 887f7e33c2SJakub Kuderski } 897f7e33c2SJakub Kuderski 907f7e33c2SJakub Kuderski return success(); 917f7e33c2SJakub Kuderski } 927f7e33c2SJakub Kuderski 93cc020a22SLei Zhang //===----------------------------------------------------------------------===// 94995c3984SLei Zhang // Operation conversion 95995c3984SLei Zhang //===----------------------------------------------------------------------===// 96995c3984SLei Zhang 97995c3984SLei Zhang // Note that DRR cannot be used for the patterns in this file: we may need to 98995c3984SLei Zhang // convert type along the way, which requires ConversionPattern. DRR generates 99995c3984SLei Zhang // normal RewritePattern. 100995c3984SLei Zhang 101995c3984SLei Zhang namespace { 1027f7e33c2SJakub Kuderski /// Converts elementwise unary, binary, and ternary standard operations to 1037f7e33c2SJakub Kuderski /// SPIR-V operations. Checks that source `Op` types are supported. 1047f7e33c2SJakub Kuderski template <typename Op, typename SPIRVOp> 1057f7e33c2SJakub Kuderski struct CheckedElementwiseOpPattern final 1067f7e33c2SJakub Kuderski : public spirv::ElementwiseOpPattern<Op, SPIRVOp> { 1077f7e33c2SJakub Kuderski using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>; 1087f7e33c2SJakub Kuderski using BasePattern::BasePattern; 1097f7e33c2SJakub Kuderski 1107f7e33c2SJakub Kuderski LogicalResult 1117f7e33c2SJakub Kuderski matchAndRewrite(Op op, typename Op::Adaptor adaptor, 1127f7e33c2SJakub Kuderski ConversionPatternRewriter &rewriter) const override { 1137f7e33c2SJakub Kuderski if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res)) 1147f7e33c2SJakub Kuderski return res; 1157f7e33c2SJakub Kuderski 1167f7e33c2SJakub Kuderski return BasePattern::matchAndRewrite(op, adaptor, rewriter); 1177f7e33c2SJakub Kuderski } 1187f7e33c2SJakub Kuderski }; 1197f7e33c2SJakub Kuderski 120533ec929SLei Zhang /// Converts math.copysign to SPIR-V ops. 1217f7e33c2SJakub Kuderski struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { 122533ec929SLei Zhang using OpConversionPattern::OpConversionPattern; 123533ec929SLei Zhang 124533ec929SLei Zhang LogicalResult 125533ec929SLei Zhang matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, 126533ec929SLei Zhang ConversionPatternRewriter &rewriter) const override { 1277f7e33c2SJakub Kuderski if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp); 1287f7e33c2SJakub Kuderski failed(res)) 1297f7e33c2SJakub Kuderski return res; 1307f7e33c2SJakub Kuderski 1317f7e33c2SJakub Kuderski Type type = getTypeConverter()->convertType(copySignOp.getType()); 132533ec929SLei Zhang if (!type) 133533ec929SLei Zhang return failure(); 134533ec929SLei Zhang 135533ec929SLei Zhang FloatType floatType; 1365550c821STres Popp if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) { 137533ec929SLei Zhang floatType = scalarType; 1385550c821STres Popp } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) { 1395550c821STres Popp floatType = cast<FloatType>(vectorType.getElementType()); 140533ec929SLei Zhang } else { 141533ec929SLei Zhang return failure(); 142533ec929SLei Zhang } 143533ec929SLei Zhang 144533ec929SLei Zhang Location loc = copySignOp.getLoc(); 145533ec929SLei Zhang int bitwidth = floatType.getWidth(); 146533ec929SLei Zhang Type intType = rewriter.getIntegerType(bitwidth); 1475f14aee3SStella Stamenova uint64_t intValue = uint64_t(1) << (bitwidth - 1); 148533ec929SLei Zhang 149533ec929SLei Zhang Value signMask = rewriter.create<spirv::ConstantOp>( 1505f14aee3SStella Stamenova loc, intType, rewriter.getIntegerAttr(intType, intValue)); 151533ec929SLei Zhang Value valueMask = rewriter.create<spirv::ConstantOp>( 1525f14aee3SStella Stamenova loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); 153533ec929SLei Zhang 1545550c821STres Popp if (auto vectorType = dyn_cast<VectorType>(type)) { 155533ec929SLei Zhang assert(vectorType.getRank() == 1); 156533ec929SLei Zhang int count = vectorType.getNumElements(); 157533ec929SLei Zhang intType = VectorType::get(count, intType); 158533ec929SLei Zhang 159533ec929SLei Zhang SmallVector<Value> signSplat(count, signMask); 160533ec929SLei Zhang signMask = 161533ec929SLei Zhang rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); 162533ec929SLei Zhang 163533ec929SLei Zhang SmallVector<Value> valueSplat(count, valueMask); 164533ec929SLei Zhang valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, 165533ec929SLei Zhang valueSplat); 166533ec929SLei Zhang } 167533ec929SLei Zhang 168533ec929SLei Zhang Value lhsCast = 169533ec929SLei Zhang rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); 170533ec929SLei Zhang Value rhsCast = 171533ec929SLei Zhang rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); 172533ec929SLei Zhang 173533ec929SLei Zhang Value value = rewriter.create<spirv::BitwiseAndOp>( 174533ec929SLei Zhang loc, intType, ValueRange{lhsCast, valueMask}); 175533ec929SLei Zhang Value sign = rewriter.create<spirv::BitwiseAndOp>( 176533ec929SLei Zhang loc, intType, ValueRange{rhsCast, signMask}); 177533ec929SLei Zhang 178533ec929SLei Zhang Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, 179533ec929SLei Zhang ValueRange{value, sign}); 180533ec929SLei Zhang rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); 181533ec929SLei Zhang return success(); 182533ec929SLei Zhang } 183533ec929SLei Zhang }; 184533ec929SLei Zhang 185cc020a22SLei Zhang /// Converts math.ctlz to SPIR-V ops. 186cc020a22SLei Zhang /// 187cc020a22SLei Zhang /// SPIR-V does not have a direct operations for counting leading zeros. If 18852b630daSJakub Kuderski /// Shader capability is supported, we can leverage GL FindUMsb to calculate 189cc020a22SLei Zhang /// it. 1907f7e33c2SJakub Kuderski struct CountLeadingZerosPattern final 191cc020a22SLei Zhang : public OpConversionPattern<math::CountLeadingZerosOp> { 192cc020a22SLei Zhang using OpConversionPattern::OpConversionPattern; 193cc020a22SLei Zhang 194cc020a22SLei Zhang LogicalResult 195cc020a22SLei Zhang matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, 196cc020a22SLei Zhang ConversionPatternRewriter &rewriter) const override { 1977f7e33c2SJakub Kuderski if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res)) 1987f7e33c2SJakub Kuderski return res; 1997f7e33c2SJakub Kuderski 2007f7e33c2SJakub Kuderski Type type = getTypeConverter()->convertType(countOp.getType()); 201cc020a22SLei Zhang if (!type) 202cc020a22SLei Zhang return failure(); 203cc020a22SLei Zhang 204cc020a22SLei Zhang // We can only support 32-bit integer types for now. 205cc020a22SLei Zhang unsigned bitwidth = 0; 2065550c821STres Popp if (isa<IntegerType>(type)) 207cc020a22SLei Zhang bitwidth = type.getIntOrFloatBitWidth(); 2085550c821STres Popp if (auto vectorType = dyn_cast<VectorType>(type)) 209cc020a22SLei Zhang bitwidth = vectorType.getElementTypeBitWidth(); 210cc020a22SLei Zhang if (bitwidth != 32) 211cc020a22SLei Zhang return failure(); 212cc020a22SLei Zhang 213cc020a22SLei Zhang Location loc = countOp.getLoc(); 2142320a4aeSLei Zhang Value input = adaptor.getOperand(); 2152320a4aeSLei Zhang Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc); 216cc020a22SLei Zhang Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); 2172320a4aeSLei Zhang Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); 2182320a4aeSLei Zhang 21952b630daSJakub Kuderski Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input); 2202320a4aeSLei Zhang // We need to subtract from 31 given that the index returned by GLSL 2212320a4aeSLei Zhang // FindUMsb is counted from the least significant bit. Theoretically this 2222320a4aeSLei Zhang // also gives the correct result even if the integer has all zero bits, in 22352b630daSJakub Kuderski // which case GL FindUMsb would return -1. 2242320a4aeSLei Zhang Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb); 2252320a4aeSLei Zhang // However, certain Vulkan implementations have driver bugs for the corner 2262320a4aeSLei Zhang // case where the input is zero. And.. it can be smart to optimize a select 2272320a4aeSLei Zhang // only involving the corner case. So separately compute the result when the 2282320a4aeSLei Zhang // input is either zero or one. 2292320a4aeSLei Zhang Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input); 2302320a4aeSLei Zhang Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1); 2312320a4aeSLei Zhang rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput, 2322320a4aeSLei Zhang subMsb); 233cc020a22SLei Zhang return success(); 234cc020a22SLei Zhang } 235cc020a22SLei Zhang }; 236cc020a22SLei Zhang 2373e746c6dSRob Suderman /// Converts math.expm1 to SPIR-V ops. 2383e746c6dSRob Suderman /// 2393e746c6dSRob Suderman /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to 2403e746c6dSRob Suderman /// these operations. 2413e746c6dSRob Suderman template <typename ExpOp> 242533ec929SLei Zhang struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { 243533ec929SLei Zhang using OpConversionPattern::OpConversionPattern; 2443e746c6dSRob Suderman 2453e746c6dSRob Suderman LogicalResult 2463e746c6dSRob Suderman matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, 2473e746c6dSRob Suderman ConversionPatternRewriter &rewriter) const override { 2483e746c6dSRob Suderman assert(adaptor.getOperands().size() == 1); 2497f7e33c2SJakub Kuderski if (LogicalResult res = checkSourceOpTypes(rewriter, operation); 2507f7e33c2SJakub Kuderski failed(res)) 2517f7e33c2SJakub Kuderski return res; 2527f7e33c2SJakub Kuderski 2533e746c6dSRob Suderman Location loc = operation.getLoc(); 2547f7e33c2SJakub Kuderski Type type = this->getTypeConverter()->convertType(operation.getType()); 2557f7e33c2SJakub Kuderski if (!type) 2567f7e33c2SJakub Kuderski return failure(); 2577f7e33c2SJakub Kuderski 2587f7e33c2SJakub Kuderski Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); 2593e746c6dSRob Suderman auto one = spirv::ConstantOp::getOne(type, loc, rewriter); 2603e746c6dSRob Suderman rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); 2613e746c6dSRob Suderman return success(); 2623e746c6dSRob Suderman } 2633e746c6dSRob Suderman }; 2643e746c6dSRob Suderman 265995c3984SLei Zhang /// Converts math.log1p to SPIR-V ops. 266995c3984SLei Zhang /// 267995c3984SLei Zhang /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 268995c3984SLei Zhang /// these operations. 26975a1bee0SButygin template <typename LogOp> 270533ec929SLei Zhang struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 271533ec929SLei Zhang using OpConversionPattern::OpConversionPattern; 272995c3984SLei Zhang 273995c3984SLei Zhang LogicalResult 274b54c724bSRiver Riddle matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 275995c3984SLei Zhang ConversionPatternRewriter &rewriter) const override { 276b54c724bSRiver Riddle assert(adaptor.getOperands().size() == 1); 2777f7e33c2SJakub Kuderski if (LogicalResult res = checkSourceOpTypes(rewriter, operation); 2787f7e33c2SJakub Kuderski failed(res)) 2797f7e33c2SJakub Kuderski return res; 2807f7e33c2SJakub Kuderski 281995c3984SLei Zhang Location loc = operation.getLoc(); 2827f7e33c2SJakub Kuderski Type type = this->getTypeConverter()->convertType(operation.getType()); 2837f7e33c2SJakub Kuderski if (!type) 2847f7e33c2SJakub Kuderski return failure(); 2857f7e33c2SJakub Kuderski 286995c3984SLei Zhang auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 2877f7e33c2SJakub Kuderski Value onePlus = 2883e746c6dSRob Suderman rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); 28975a1bee0SButygin rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 290995c3984SLei Zhang return success(); 291995c3984SLei Zhang } 292995c3984SLei Zhang }; 29306c6758aSLei Zhang 2948d165136Smeehatpa /// Converts math.log2 and math.log10 to SPIR-V ops. 2958d165136Smeehatpa /// 2968d165136Smeehatpa /// SPIR-V does not have direct operations for log2 and log10. Explicitly 2978d165136Smeehatpa /// lower to these operations using: 2988d165136Smeehatpa /// log2(x) = log(x) * 1/log(2) 2998d165136Smeehatpa /// log10(x) = log(x) * 1/log(10) 3008d165136Smeehatpa 3018d165136Smeehatpa template <typename MathLogOp, typename SpirvLogOp> 3028d165136Smeehatpa struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> { 3038d165136Smeehatpa using OpConversionPattern<MathLogOp>::OpConversionPattern; 3048d165136Smeehatpa using typename OpConversionPattern<MathLogOp>::OpAdaptor; 3058d165136Smeehatpa 3068d165136Smeehatpa static constexpr double log2Reciprocal = 3078d165136Smeehatpa 1.442695040888963407359924681001892137426645954152985934135449407; 3088d165136Smeehatpa static constexpr double log10Reciprocal = 3098d165136Smeehatpa 0.4342944819032518276511289189166050822943970058036665661144537832; 3108d165136Smeehatpa 3118d165136Smeehatpa LogicalResult 3128d165136Smeehatpa matchAndRewrite(MathLogOp operation, OpAdaptor adaptor, 3138d165136Smeehatpa ConversionPatternRewriter &rewriter) const override { 3148d165136Smeehatpa assert(adaptor.getOperands().size() == 1); 3158d165136Smeehatpa if (LogicalResult res = checkSourceOpTypes(rewriter, operation); 3168d165136Smeehatpa failed(res)) 3178d165136Smeehatpa return res; 3188d165136Smeehatpa 3198d165136Smeehatpa Location loc = operation.getLoc(); 3208d165136Smeehatpa Type type = this->getTypeConverter()->convertType(operation.getType()); 3218d165136Smeehatpa if (!type) 3228d165136Smeehatpa return rewriter.notifyMatchFailure(operation, "type conversion failed"); 3238d165136Smeehatpa 3248d165136Smeehatpa auto getConstantValue = [&](double value) { 3258d165136Smeehatpa if (auto floatType = dyn_cast<FloatType>(type)) { 3268d165136Smeehatpa return rewriter.create<spirv::ConstantOp>( 3278d165136Smeehatpa loc, type, rewriter.getFloatAttr(floatType, value)); 3288d165136Smeehatpa } 3298d165136Smeehatpa if (auto vectorType = dyn_cast<VectorType>(type)) { 3308d165136Smeehatpa Type elemType = vectorType.getElementType(); 3318d165136Smeehatpa 3328d165136Smeehatpa if (isa<FloatType>(elemType)) { 3338d165136Smeehatpa return rewriter.create<spirv::ConstantOp>( 3348d165136Smeehatpa loc, type, 3358d165136Smeehatpa DenseFPElementsAttr::get( 3368d165136Smeehatpa vectorType, FloatAttr::get(elemType, value).getValue())); 3378d165136Smeehatpa } 3388d165136Smeehatpa } 3398d165136Smeehatpa 3408d165136Smeehatpa llvm_unreachable("unimplemented types for log2/log10"); 3418d165136Smeehatpa }; 3428d165136Smeehatpa 3438d165136Smeehatpa Value constantValue = getConstantValue( 3448d165136Smeehatpa std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal 3458d165136Smeehatpa : log10Reciprocal); 3468d165136Smeehatpa Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand()); 3478d165136Smeehatpa rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log, 3488d165136Smeehatpa constantValue); 3498d165136Smeehatpa return success(); 3508d165136Smeehatpa } 3518d165136Smeehatpa }; 3528d165136Smeehatpa 35306c6758aSLei Zhang /// Converts math.powf to SPIRV-Ops. 35406c6758aSLei Zhang struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { 35506c6758aSLei Zhang using OpConversionPattern::OpConversionPattern; 35606c6758aSLei Zhang 35706c6758aSLei Zhang LogicalResult 35806c6758aSLei Zhang matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor, 35906c6758aSLei Zhang ConversionPatternRewriter &rewriter) const override { 3607f7e33c2SJakub Kuderski if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res)) 3617f7e33c2SJakub Kuderski return res; 3627f7e33c2SJakub Kuderski 3637f7e33c2SJakub Kuderski Type dstType = getTypeConverter()->convertType(powfOp.getType()); 36406c6758aSLei Zhang if (!dstType) 36506c6758aSLei Zhang return failure(); 36606c6758aSLei Zhang 36758839f2eSDaniel Garvey // Get the scalar float type. 36858839f2eSDaniel Garvey FloatType scalarFloatType; 3695550c821STres Popp if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) { 37058839f2eSDaniel Garvey scalarFloatType = scalarType; 3715550c821STres Popp } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) { 3725550c821STres Popp scalarFloatType = cast<FloatType>(vectorType.getElementType()); 37358839f2eSDaniel Garvey } else { 37458839f2eSDaniel Garvey return failure(); 37558839f2eSDaniel Garvey } 37658839f2eSDaniel Garvey 37758839f2eSDaniel Garvey // Get int type of the same shape as the float type. 37858839f2eSDaniel Garvey Type scalarIntType = rewriter.getIntegerType(32); 37958839f2eSDaniel Garvey Type intType = scalarIntType; 380*b9314a82SDmitriy Smirnov auto operandType = adaptor.getRhs().getType(); 381*b9314a82SDmitriy Smirnov if (auto vectorType = dyn_cast<VectorType>(operandType)) { 38258839f2eSDaniel Garvey auto shape = vectorType.getShape(); 38358839f2eSDaniel Garvey intType = VectorType::get(shape, scalarIntType); 38458839f2eSDaniel Garvey } 38558839f2eSDaniel Garvey 38652b630daSJakub Kuderski // Per GL Pow extended instruction spec: 38706c6758aSLei Zhang // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." 38806c6758aSLei Zhang Location loc = powfOp.getLoc(); 389*b9314a82SDmitriy Smirnov Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter); 39006c6758aSLei Zhang Value lessThan = 39106c6758aSLei Zhang rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero); 392*b9314a82SDmitriy Smirnov 393*b9314a82SDmitriy Smirnov // Per C/C++ spec: 394*b9314a82SDmitriy Smirnov // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is 395*b9314a82SDmitriy Smirnov // > finite and negative and exponent is finite and non-integer. 396*b9314a82SDmitriy Smirnov // Calculate the reminder from the exponent and check whether it is zero. 397*b9314a82SDmitriy Smirnov Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter); 398*b9314a82SDmitriy Smirnov Value expRem = 399*b9314a82SDmitriy Smirnov rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne); 400*b9314a82SDmitriy Smirnov Value expRemNonZero = 401*b9314a82SDmitriy Smirnov rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero); 402*b9314a82SDmitriy Smirnov Value cmpNegativeWithFractionalExp = 403*b9314a82SDmitriy Smirnov rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan); 404*b9314a82SDmitriy Smirnov // Create NaN result and replace base value if conditions are met. 405*b9314a82SDmitriy Smirnov const auto &floatSemantics = scalarFloatType.getFloatSemantics(); 406*b9314a82SDmitriy Smirnov const auto nan = APFloat::getNaN(floatSemantics); 407*b9314a82SDmitriy Smirnov Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan); 408*b9314a82SDmitriy Smirnov if (auto vectorType = dyn_cast<VectorType>(operandType)) 409*b9314a82SDmitriy Smirnov nanAttr = DenseElementsAttr::get(vectorType, nan); 410*b9314a82SDmitriy Smirnov 411*b9314a82SDmitriy Smirnov Value NanValue = 412*b9314a82SDmitriy Smirnov rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr); 413*b9314a82SDmitriy Smirnov Value lhs = rewriter.create<spirv::SelectOp>( 414*b9314a82SDmitriy Smirnov loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs()); 415*b9314a82SDmitriy Smirnov Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs); 41658839f2eSDaniel Garvey 41758839f2eSDaniel Garvey // TODO: The following just forcefully casts y into an integer value in 41858839f2eSDaniel Garvey // order to properly propagate the sign, assuming integer y cases. It 41958839f2eSDaniel Garvey // doesn't cover other cases and should be fixed. 42058839f2eSDaniel Garvey 42158839f2eSDaniel Garvey // Cast exponent to integer and calculate exponent % 2 != 0. 42258839f2eSDaniel Garvey Value intRhs = 42358839f2eSDaniel Garvey rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs()); 42458839f2eSDaniel Garvey Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter); 42558839f2eSDaniel Garvey Value bitwiseAndOne = 42658839f2eSDaniel Garvey rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne); 42758839f2eSDaniel Garvey Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne); 42858839f2eSDaniel Garvey 42958839f2eSDaniel Garvey // calculate pow based on abs(lhs)^rhs. 43052b630daSJakub Kuderski Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs()); 43106c6758aSLei Zhang Value negate = rewriter.create<spirv::FNegateOp>(loc, pow); 43258839f2eSDaniel Garvey // if the exponent is odd and lhs < 0, negate the result. 43358839f2eSDaniel Garvey Value shouldNegate = 43458839f2eSDaniel Garvey rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd); 43558839f2eSDaniel Garvey rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate, 43658839f2eSDaniel Garvey pow); 43706c6758aSLei Zhang return success(); 43806c6758aSLei Zhang } 43906c6758aSLei Zhang }; 44006c6758aSLei Zhang 441b9e642afSRobert Suderman /// Converts math.round to GLSL SPIRV extended ops. 442b9e642afSRobert Suderman struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> { 443b9e642afSRobert Suderman using OpConversionPattern::OpConversionPattern; 444b9e642afSRobert Suderman 445b9e642afSRobert Suderman LogicalResult 446b9e642afSRobert Suderman matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor, 447b9e642afSRobert Suderman ConversionPatternRewriter &rewriter) const override { 4487f7e33c2SJakub Kuderski if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res)) 4497f7e33c2SJakub Kuderski return res; 4507f7e33c2SJakub Kuderski 451b9e642afSRobert Suderman Location loc = roundOp.getLoc(); 4527f7e33c2SJakub Kuderski Value operand = roundOp.getOperand(); 4537f7e33c2SJakub Kuderski Type ty = operand.getType(); 4547f7e33c2SJakub Kuderski Type ety = getElementTypeOrSelf(ty); 455b9e642afSRobert Suderman 456b9e642afSRobert Suderman auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); 457b9e642afSRobert Suderman auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); 458b9e642afSRobert Suderman Value half; 4595550c821STres Popp if (VectorType vty = dyn_cast<VectorType>(ty)) { 460b9e642afSRobert Suderman half = rewriter.create<spirv::ConstantOp>( 461b9e642afSRobert Suderman loc, vty, 462b9e642afSRobert Suderman DenseElementsAttr::get(vty, 463b9e642afSRobert Suderman rewriter.getFloatAttr(ety, 0.5).getValue())); 464b9e642afSRobert Suderman } else { 465b9e642afSRobert Suderman half = rewriter.create<spirv::ConstantOp>( 466b9e642afSRobert Suderman loc, ty, rewriter.getFloatAttr(ety, 0.5)); 467b9e642afSRobert Suderman } 468b9e642afSRobert Suderman 46952b630daSJakub Kuderski auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand); 47052b630daSJakub Kuderski auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs); 471b9e642afSRobert Suderman auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor); 472b9e642afSRobert Suderman auto greater = 473b9e642afSRobert Suderman rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half); 474b9e642afSRobert Suderman auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero); 475b9e642afSRobert Suderman auto add = rewriter.create<spirv::FAddOp>(loc, floor, select); 476b9e642afSRobert Suderman rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand); 477b9e642afSRobert Suderman return success(); 478b9e642afSRobert Suderman } 479b9e642afSRobert Suderman }; 480b9e642afSRobert Suderman 481995c3984SLei Zhang } // namespace 482995c3984SLei Zhang 483995c3984SLei Zhang //===----------------------------------------------------------------------===// 484995c3984SLei Zhang // Pattern population 485995c3984SLei Zhang //===----------------------------------------------------------------------===// 486995c3984SLei Zhang 487995c3984SLei Zhang namespace mlir { 488206fad0eSMatthias Springer void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, 489995c3984SLei Zhang RewritePatternSet &patterns) { 490533ec929SLei Zhang // Core patterns 491533ec929SLei Zhang patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); 49275a1bee0SButygin 49375a1bee0SButygin // GLSL patterns 494d9edc1a5SThomas Raoux patterns 49552b630daSJakub Kuderski .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>, 4968d165136Smeehatpa Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>, 4978d165136Smeehatpa Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>, 49852b630daSJakub Kuderski ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern, 4997f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>, 5007f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>, 50149777d7fSmeehatpa CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>, 5027f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>, 5037f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>, 5047f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>, 5057f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>, 5067f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>, 5077f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>, 5087f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>, 5097f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>, 5107f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>, 5117f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>, 5127f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>( 513995c3984SLei Zhang typeConverter, patterns.getContext()); 51475a1bee0SButygin 51575a1bee0SButygin // OpenCL patterns 5163930cc68SJakub Kuderski patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>, 5178d165136Smeehatpa Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>, 5188d165136Smeehatpa Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>, 5197f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>, 520266b5bc1SNishant Patel CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>, 52149777d7fSmeehatpa CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>, 52249777d7fSmeehatpa CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>, 5237f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>, 5247f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>, 5257f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>, 5267f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>, 5277f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>, 5287f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>, 5297f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>, 5307f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>, 5317f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>, 5327f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>, 5337f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>, 5347f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>, 5357f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>, 5367f7e33c2SJakub Kuderski CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>( 53775a1bee0SButygin typeConverter, patterns.getContext()); 538995c3984SLei Zhang } 539995c3984SLei Zhang 540995c3984SLei Zhang } // namespace mlir 541