//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements SPIR-V transforms used when targeting WebGPU.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include <array>
#include <cstdint>

namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
33c957fe0fSJakub Kuderski } // namespace mlir 34c957fe0fSJakub Kuderski 35c957fe0fSJakub Kuderski namespace mlir { 36c957fe0fSJakub Kuderski namespace spirv { 37c957fe0fSJakub Kuderski namespace { 38c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===// 39c957fe0fSJakub Kuderski // Helpers 40c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===// 41d61ec513SJakub Kuderski static Attribute getScalarOrSplatAttr(Type type, int64_t value) { 42c957fe0fSJakub Kuderski APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); 435550c821STres Popp if (auto intTy = dyn_cast<IntegerType>(type)) 44c957fe0fSJakub Kuderski return IntegerAttr::get(intTy, sizedValue); 45c957fe0fSJakub Kuderski 466089d612SRahul Kayaith return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue); 47c957fe0fSJakub Kuderski } 48c957fe0fSJakub Kuderski 49d61ec513SJakub Kuderski static Value lowerExtendedMultiplication(Operation *mulOp, 50d61ec513SJakub Kuderski PatternRewriter &rewriter, Value lhs, 51d61ec513SJakub Kuderski Value rhs, bool signExtendArguments) { 521b822453SJakub Kuderski Location loc = mulOp->getLoc(); 53c957fe0fSJakub Kuderski Type argTy = lhs.getType(); 5447232beaSJakub Kuderski // Emulate 64-bit multiplication by splitting each input element of type i32 5547232beaSJakub Kuderski // into 2 16-bit digits of type i32. This is so that the intermediate 5647232beaSJakub Kuderski // multiplications and additions do not overflow. We extract these 16-bit 5747232beaSJakub Kuderski // digits from i32 vector elements by masking (low digit) and shifting right 5847232beaSJakub Kuderski // (high digit). 59c957fe0fSJakub Kuderski // 6047232beaSJakub Kuderski // The multiplication algorithm used is the standard (long) multiplication. 6147232beaSJakub Kuderski // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit 621b822453SJakub Kuderski // digits. 
631b822453SJakub Kuderski // - With zero-extended arguments, we end up emitting only 4 multiplications 641b822453SJakub Kuderski // and 4 additions after constant folding. 651b822453SJakub Kuderski // - With sign-extended arguments, we end up emitting 8 multiplications and 661b822453SJakub Kuderski // and 12 additions after CSE. 67c957fe0fSJakub Kuderski Value cstLowMask = rewriter.create<ConstantOp>( 68c957fe0fSJakub Kuderski loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1)); 6947232beaSJakub Kuderski auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) { 70c957fe0fSJakub Kuderski return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask); 71c957fe0fSJakub Kuderski }; 72c957fe0fSJakub Kuderski 73c957fe0fSJakub Kuderski Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(), 74c957fe0fSJakub Kuderski getScalarOrSplatAttr(argTy, 16)); 7547232beaSJakub Kuderski auto getHighDigit = [&rewriter, loc, cst16](Value val) { 76c957fe0fSJakub Kuderski return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16); 77c957fe0fSJakub Kuderski }; 78c957fe0fSJakub Kuderski 791b822453SJakub Kuderski auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) { 801b822453SJakub Kuderski // We only need to shift arithmetically by 15, but the extra 811b822453SJakub Kuderski // sign-extension bit will be truncated by the logical shift, so this is 821b822453SJakub Kuderski // fine. We do not have to introduce an extra constant since any 831b822453SJakub Kuderski // value in [15, 32) would do. 
841b822453SJakub Kuderski return getHighDigit( 851b822453SJakub Kuderski rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16)); 861b822453SJakub Kuderski }; 871b822453SJakub Kuderski 8847232beaSJakub Kuderski Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(), 8947232beaSJakub Kuderski getScalarOrSplatAttr(argTy, 0)); 90c957fe0fSJakub Kuderski 9147232beaSJakub Kuderski Value lhsLow = getLowDigit(lhs); 9247232beaSJakub Kuderski Value lhsHigh = getHighDigit(lhs); 931b822453SJakub Kuderski Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0; 9447232beaSJakub Kuderski Value rhsLow = getLowDigit(rhs); 9547232beaSJakub Kuderski Value rhsHigh = getHighDigit(rhs); 961b822453SJakub Kuderski Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0; 9747232beaSJakub Kuderski 981b822453SJakub Kuderski std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt}; 991b822453SJakub Kuderski std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt}; 10047232beaSJakub Kuderski std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0}; 10147232beaSJakub Kuderski 10247232beaSJakub Kuderski for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) { 10347232beaSJakub Kuderski for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) { 1041b822453SJakub Kuderski if (i + j >= resultDigits.size()) 1051b822453SJakub Kuderski continue; 1061b822453SJakub Kuderski 1071b822453SJakub Kuderski if (lhsDigit == cst0 || rhsDigit == cst0) 1081b822453SJakub Kuderski continue; 1091b822453SJakub Kuderski 11047232beaSJakub Kuderski Value &thisResDigit = resultDigits[i + j]; 11147232beaSJakub Kuderski Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit); 11247232beaSJakub Kuderski Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul); 11347232beaSJakub Kuderski thisResDigit = getLowDigit(current); 11447232beaSJakub Kuderski 11547232beaSJakub Kuderski if (i + j + 1 != resultDigits.size()) { 11647232beaSJakub Kuderski Value &nextResDigit = 
resultDigits[i + j + 1]; 11747232beaSJakub Kuderski Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit, 11847232beaSJakub Kuderski getHighDigit(current)); 11947232beaSJakub Kuderski nextResDigit = carry; 12047232beaSJakub Kuderski } 12147232beaSJakub Kuderski } 12247232beaSJakub Kuderski } 12347232beaSJakub Kuderski 12447232beaSJakub Kuderski auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) { 12547232beaSJakub Kuderski Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16); 12647232beaSJakub Kuderski return rewriter.create<BitwiseOrOp>(loc, low, highBits); 12747232beaSJakub Kuderski }; 12847232beaSJakub Kuderski Value low = combineDigits(resultDigits[0], resultDigits[1]); 12947232beaSJakub Kuderski Value high = combineDigits(resultDigits[2], resultDigits[3]); 130c957fe0fSJakub Kuderski 1311b822453SJakub Kuderski return rewriter.create<CompositeConstructOp>( 13236117cc4Sserge-sans-paille loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high})); 1331b822453SJakub Kuderski } 1341b822453SJakub Kuderski 1351b822453SJakub Kuderski //===----------------------------------------------------------------------===// 1361b822453SJakub Kuderski // Rewrite Patterns 1371b822453SJakub Kuderski //===----------------------------------------------------------------------===// 1381b822453SJakub Kuderski 1391b822453SJakub Kuderski template <typename MulExtendedOp, bool SignExtendArguments> 1401b822453SJakub Kuderski struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> { 1411b822453SJakub Kuderski using OpRewritePattern<MulExtendedOp>::OpRewritePattern; 1421b822453SJakub Kuderski 1431b822453SJakub Kuderski LogicalResult matchAndRewrite(MulExtendedOp op, 1441b822453SJakub Kuderski PatternRewriter &rewriter) const override { 1451b822453SJakub Kuderski Location loc = op->getLoc(); 1461b822453SJakub Kuderski Value lhs = op.getOperand1(); 1471b822453SJakub Kuderski Value rhs = op.getOperand2(); 1481b822453SJakub 
Kuderski 1491b822453SJakub Kuderski // Currently, WGSL only supports 32-bit integer types. Any other integer 1501b822453SJakub Kuderski // types should already have been promoted/demoted to i32. 1515550c821STres Popp auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType())); 1521b822453SJakub Kuderski if (elemTy.getIntOrFloatBitWidth() != 32) 1531b822453SJakub Kuderski return rewriter.notifyMatchFailure( 1541b822453SJakub Kuderski loc, 1551b822453SJakub Kuderski llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy)); 1561b822453SJakub Kuderski 1571b822453SJakub Kuderski Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs, 1581b822453SJakub Kuderski SignExtendArguments); 1591b822453SJakub Kuderski rewriter.replaceOp(op, mul); 160c957fe0fSJakub Kuderski return success(); 161c957fe0fSJakub Kuderski } 162c957fe0fSJakub Kuderski }; 163c957fe0fSJakub Kuderski 1641b822453SJakub Kuderski using ExpandSMulExtendedPattern = 1651b822453SJakub Kuderski ExpandMulExtendedPattern<SMulExtendedOp, true>; 1661b822453SJakub Kuderski using ExpandUMulExtendedPattern = 1671b822453SJakub Kuderski ExpandMulExtendedPattern<UMulExtendedOp, false>; 1681b822453SJakub Kuderski 1696ddc03d9SFinn Plummer struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> { 1706ddc03d9SFinn Plummer using OpRewritePattern<IAddCarryOp>::OpRewritePattern; 1716ddc03d9SFinn Plummer 1726ddc03d9SFinn Plummer LogicalResult matchAndRewrite(IAddCarryOp op, 1736ddc03d9SFinn Plummer PatternRewriter &rewriter) const override { 1746ddc03d9SFinn Plummer Location loc = op->getLoc(); 1756ddc03d9SFinn Plummer Value lhs = op.getOperand1(); 1766ddc03d9SFinn Plummer Value rhs = op.getOperand2(); 1776ddc03d9SFinn Plummer 1786ddc03d9SFinn Plummer // Currently, WGSL only supports 32-bit integer types. Any other integer 1796ddc03d9SFinn Plummer // types should already have been promoted/demoted to i32. 
1806ddc03d9SFinn Plummer Type argTy = lhs.getType(); 1816ddc03d9SFinn Plummer auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy)); 1826ddc03d9SFinn Plummer if (elemTy.getIntOrFloatBitWidth() != 32) 1836ddc03d9SFinn Plummer return rewriter.notifyMatchFailure( 1846ddc03d9SFinn Plummer loc, 1856ddc03d9SFinn Plummer llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy)); 1866ddc03d9SFinn Plummer 1876ddc03d9SFinn Plummer Value one = 1886ddc03d9SFinn Plummer rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1)); 1896ddc03d9SFinn Plummer Value zero = 1906ddc03d9SFinn Plummer rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0)); 1916ddc03d9SFinn Plummer 1926ddc03d9SFinn Plummer // Calculate the carry by checking if the addition resulted in an overflow. 1936ddc03d9SFinn Plummer Value out = rewriter.create<IAddOp>(loc, lhs, rhs); 1946ddc03d9SFinn Plummer Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs); 1956ddc03d9SFinn Plummer Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero); 1966ddc03d9SFinn Plummer 1976ddc03d9SFinn Plummer Value add = rewriter.create<CompositeConstructOp>( 1986ddc03d9SFinn Plummer loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry})); 1996ddc03d9SFinn Plummer 2006ddc03d9SFinn Plummer rewriter.replaceOp(op, add); 2016ddc03d9SFinn Plummer return success(); 2026ddc03d9SFinn Plummer } 2036ddc03d9SFinn Plummer }; 2046ddc03d9SFinn Plummer 205d61ec513SJakub Kuderski struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> { 206d61ec513SJakub Kuderski using OpRewritePattern::OpRewritePattern; 207d61ec513SJakub Kuderski 208d61ec513SJakub Kuderski LogicalResult matchAndRewrite(IsInfOp op, 209d61ec513SJakub Kuderski PatternRewriter &rewriter) const override { 210d61ec513SJakub Kuderski // We assume values to be finite and turn `IsInf` info `false`. 
211d61ec513SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::ConstantOp>( 212d61ec513SJakub Kuderski op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); 213d61ec513SJakub Kuderski return success(); 214d61ec513SJakub Kuderski } 215d61ec513SJakub Kuderski }; 216d61ec513SJakub Kuderski 217d61ec513SJakub Kuderski struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> { 218d61ec513SJakub Kuderski using OpRewritePattern::OpRewritePattern; 219d61ec513SJakub Kuderski 220d61ec513SJakub Kuderski LogicalResult matchAndRewrite(IsNanOp op, 221d61ec513SJakub Kuderski PatternRewriter &rewriter) const override { 222d61ec513SJakub Kuderski // We assume values to be finite and turn `IsNan` info `false`. 223d61ec513SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::ConstantOp>( 224d61ec513SJakub Kuderski op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); 225d61ec513SJakub Kuderski return success(); 226d61ec513SJakub Kuderski } 227d61ec513SJakub Kuderski }; 228d61ec513SJakub Kuderski 229c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===// 230c957fe0fSJakub Kuderski // Passes 231c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===// 232d61ec513SJakub Kuderski struct WebGPUPreparePass final 233d61ec513SJakub Kuderski : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> { 234c957fe0fSJakub Kuderski void runOnOperation() override { 235c957fe0fSJakub Kuderski RewritePatternSet patterns(&getContext()); 236c957fe0fSJakub Kuderski populateSPIRVExpandExtendedMultiplicationPatterns(patterns); 237d61ec513SJakub Kuderski populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); 238c957fe0fSJakub Kuderski 239*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 240c957fe0fSJakub Kuderski signalPassFailure(); 241c957fe0fSJakub Kuderski } 242c957fe0fSJakub Kuderski }; 243c957fe0fSJakub Kuderski } // namespace 244c957fe0fSJakub Kuderski 
245c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===// 246c957fe0fSJakub Kuderski // Public Interface 247c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===// 248c957fe0fSJakub Kuderski void populateSPIRVExpandExtendedMultiplicationPatterns( 249c957fe0fSJakub Kuderski RewritePatternSet &patterns) { 250c957fe0fSJakub Kuderski // WGSL currently does not support extended multiplication ops, see: 251c957fe0fSJakub Kuderski // https://github.com/gpuweb/gpuweb/issues/1565. 252d61ec513SJakub Kuderski patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern, 253d61ec513SJakub Kuderski ExpandAddCarryPattern>(patterns.getContext()); 254c957fe0fSJakub Kuderski } 255d61ec513SJakub Kuderski 256d61ec513SJakub Kuderski void populateSPIRVExpandNonFiniteArithmeticPatterns( 257d61ec513SJakub Kuderski RewritePatternSet &patterns) { 258d61ec513SJakub Kuderski // WGSL currently does not support `isInf` and `isNan`, see: 259d61ec513SJakub Kuderski // https://github.com/gpuweb/gpuweb/pull/2311. 260d61ec513SJakub Kuderski patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext()); 261d61ec513SJakub Kuderski } 262d61ec513SJakub Kuderski 263c957fe0fSJakub Kuderski } // namespace spirv 264c957fe0fSJakub Kuderski } // namespace mlir 265