//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements SPIR-V transforms used when targetting WebGPU. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include #include namespace mlir { namespace spirv { #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" } // namespace spirv } // namespace mlir namespace mlir { namespace spirv { namespace { //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// static Attribute getScalarOrSplatAttr(Type type, int64_t value) { APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); if (auto intTy = dyn_cast(type)) return IntegerAttr::get(intTy, sizedValue); return SplatElementsAttr::get(cast(type), sizedValue); } static Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter, Value lhs, Value rhs, bool signExtendArguments) { Location loc = mulOp->getLoc(); Type argTy = lhs.getType(); // Emulate 64-bit multiplication by splitting each input element of type i32 // into 2 16-bit digits of type i32. This is so that the intermediate // multiplications and additions do not overflow. We extract these 16-bit // digits from i32 vector elements by masking (low digit) and shifting right // (high digit). // // The multiplication algorithm used is the standard (long) multiplication. // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit // digits. // - With zero-extended arguments, we end up emitting only 4 multiplications // and 4 additions after constant folding. // - With sign-extended arguments, we end up emitting 8 multiplications and // and 12 additions after CSE. Value cstLowMask = rewriter.create( loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1)); auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) { return rewriter.create(loc, val, cstLowMask); }; Value cst16 = rewriter.create(loc, lhs.getType(), getScalarOrSplatAttr(argTy, 16)); auto getHighDigit = [&rewriter, loc, cst16](Value val) { return rewriter.create(loc, val, cst16); }; auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) { // We only need to shift arithmetically by 15, but the extra // sign-extension bit will be truncated by the logical shift, so this is // fine. We do not have to introduce an extra constant since any // value in [15, 32) would do. return getHighDigit( rewriter.create(loc, val, cst16)); }; Value cst0 = rewriter.create(loc, lhs.getType(), getScalarOrSplatAttr(argTy, 0)); Value lhsLow = getLowDigit(lhs); Value lhsHigh = getHighDigit(lhs); Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0; Value rhsLow = getLowDigit(rhs); Value rhsHigh = getHighDigit(rhs); Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0; std::array lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt}; std::array rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt}; std::array resultDigits = {cst0, cst0, cst0, cst0}; for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) { for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) { if (i + j >= resultDigits.size()) continue; if (lhsDigit == cst0 || rhsDigit == cst0) continue; Value &thisResDigit = resultDigits[i + j]; Value mul = rewriter.create(loc, lhsDigit, rhsDigit); Value current = rewriter.createOrFold(loc, thisResDigit, mul); thisResDigit = getLowDigit(current); if (i + j + 1 != resultDigits.size()) { Value &nextResDigit = resultDigits[i + j + 1]; Value carry = rewriter.createOrFold(loc, nextResDigit, getHighDigit(current)); nextResDigit = carry; } } } auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) { Value highBits = rewriter.create(loc, high, cst16); return rewriter.create(loc, low, highBits); }; Value low = combineDigits(resultDigits[0], resultDigits[1]); Value high = combineDigits(resultDigits[2], resultDigits[3]); return rewriter.create( loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high})); } //===----------------------------------------------------------------------===// // Rewrite Patterns //===----------------------------------------------------------------------===// template struct ExpandMulExtendedPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(MulExtendedOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = op.getOperand1(); Value rhs = op.getOperand2(); // Currently, WGSL only supports 32-bit integer types. Any other integer // types should already have been promoted/demoted to i32. auto elemTy = cast(getElementTypeOrSelf(lhs.getType())); if (elemTy.getIntOrFloatBitWidth() != 32) return rewriter.notifyMatchFailure( loc, llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy)); Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs, SignExtendArguments); rewriter.replaceOp(op, mul); return success(); } }; using ExpandSMulExtendedPattern = ExpandMulExtendedPattern; using ExpandUMulExtendedPattern = ExpandMulExtendedPattern; struct ExpandAddCarryPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IAddCarryOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = op.getOperand1(); Value rhs = op.getOperand2(); // Currently, WGSL only supports 32-bit integer types. Any other integer // types should already have been promoted/demoted to i32. Type argTy = lhs.getType(); auto elemTy = cast(getElementTypeOrSelf(argTy)); if (elemTy.getIntOrFloatBitWidth() != 32) return rewriter.notifyMatchFailure( loc, llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy)); Value one = rewriter.create(loc, argTy, getScalarOrSplatAttr(argTy, 1)); Value zero = rewriter.create(loc, argTy, getScalarOrSplatAttr(argTy, 0)); // Calculate the carry by checking if the addition resulted in an overflow. Value out = rewriter.create(loc, lhs, rhs); Value cmp = rewriter.create(loc, out, lhs); Value carry = rewriter.create(loc, cmp, one, zero); Value add = rewriter.create( loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry})); rewriter.replaceOp(op, add); return success(); } }; struct ExpandIsInfPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IsInfOp op, PatternRewriter &rewriter) const override { // We assume values to be finite and turn `IsInf` info `false`. rewriter.replaceOpWithNewOp( op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); return success(); } }; struct ExpandIsNanPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IsNanOp op, PatternRewriter &rewriter) const override { // We assume values to be finite and turn `IsNan` info `false`. rewriter.replaceOpWithNewOp( op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); return success(); } }; //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// struct WebGPUPreparePass final : impl::SPIRVWebGPUPreparePassBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateSPIRVExpandExtendedMultiplicationPatterns(patterns); populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; } // namespace //===----------------------------------------------------------------------===// // Public Interface //===----------------------------------------------------------------------===// void populateSPIRVExpandExtendedMultiplicationPatterns( RewritePatternSet &patterns) { // WGSL currently does not support extended multiplication ops, see: // https://github.com/gpuweb/gpuweb/issues/1565. patterns.add(patterns.getContext()); } void populateSPIRVExpandNonFiniteArithmeticPatterns( RewritePatternSet &patterns) { // WGSL currently does not support `isInf` and `isNan`, see: // https://github.com/gpuweb/gpuweb/pull/2311. patterns.add(patterns.getContext()); } } // namespace spirv } // namespace mlir