xref: /llvm-project/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1c957fe0fSJakub Kuderski //===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2c957fe0fSJakub Kuderski //
3c957fe0fSJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c957fe0fSJakub Kuderski // See https://llvm.org/LICENSE.txt for license information.
5c957fe0fSJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c957fe0fSJakub Kuderski //
7c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
8c957fe0fSJakub Kuderski //
9c957fe0fSJakub Kuderski // This file implements SPIR-V transforms used when targeting WebGPU.
10c957fe0fSJakub Kuderski //
11c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
12c957fe0fSJakub Kuderski 
13c957fe0fSJakub Kuderski #include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
14c957fe0fSJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15c957fe0fSJakub Kuderski #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
16c957fe0fSJakub Kuderski #include "mlir/IR/BuiltinAttributes.h"
17c957fe0fSJakub Kuderski #include "mlir/IR/Location.h"
181b822453SJakub Kuderski #include "mlir/IR/PatternMatch.h"
19c957fe0fSJakub Kuderski #include "mlir/IR/TypeUtilities.h"
20c957fe0fSJakub Kuderski #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2147232beaSJakub Kuderski #include "llvm/ADT/ArrayRef.h"
2247232beaSJakub Kuderski #include "llvm/ADT/STLExtras.h"
23c957fe0fSJakub Kuderski #include "llvm/Support/FormatVariadic.h"
24c957fe0fSJakub Kuderski 
2547232beaSJakub Kuderski #include <array>
2647232beaSJakub Kuderski #include <cstdint>
2747232beaSJakub Kuderski 
28c957fe0fSJakub Kuderski namespace mlir {
29c957fe0fSJakub Kuderski namespace spirv {
30c957fe0fSJakub Kuderski #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31c957fe0fSJakub Kuderski #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32c957fe0fSJakub Kuderski } // namespace spirv
33c957fe0fSJakub Kuderski } // namespace mlir
34c957fe0fSJakub Kuderski 
35c957fe0fSJakub Kuderski namespace mlir {
36c957fe0fSJakub Kuderski namespace spirv {
37c957fe0fSJakub Kuderski namespace {
38c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
39c957fe0fSJakub Kuderski // Helpers
40c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
41d61ec513SJakub Kuderski static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42c957fe0fSJakub Kuderski   APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
435550c821STres Popp   if (auto intTy = dyn_cast<IntegerType>(type))
44c957fe0fSJakub Kuderski     return IntegerAttr::get(intTy, sizedValue);
45c957fe0fSJakub Kuderski 
466089d612SRahul Kayaith   return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
47c957fe0fSJakub Kuderski }
48c957fe0fSJakub Kuderski 
49d61ec513SJakub Kuderski static Value lowerExtendedMultiplication(Operation *mulOp,
50d61ec513SJakub Kuderski                                          PatternRewriter &rewriter, Value lhs,
51d61ec513SJakub Kuderski                                          Value rhs, bool signExtendArguments) {
521b822453SJakub Kuderski   Location loc = mulOp->getLoc();
53c957fe0fSJakub Kuderski   Type argTy = lhs.getType();
5447232beaSJakub Kuderski   // Emulate 64-bit multiplication by splitting each input element of type i32
5547232beaSJakub Kuderski   // into 2 16-bit digits of type i32. This is so that the intermediate
5647232beaSJakub Kuderski   // multiplications and additions do not overflow. We extract these 16-bit
5747232beaSJakub Kuderski   // digits from i32 vector elements by masking (low digit) and shifting right
5847232beaSJakub Kuderski   // (high digit).
59c957fe0fSJakub Kuderski   //
6047232beaSJakub Kuderski   // The multiplication algorithm used is the standard (long) multiplication.
6147232beaSJakub Kuderski   // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
621b822453SJakub Kuderski   // digits.
631b822453SJakub Kuderski   //   - With zero-extended arguments, we end up emitting only 4 multiplications
641b822453SJakub Kuderski   //     and 4 additions after constant folding.
651b822453SJakub Kuderski   //   - With sign-extended arguments, we end up emitting 8 multiplications and
661b822453SJakub Kuderski   //     and 12 additions after CSE.
67c957fe0fSJakub Kuderski   Value cstLowMask = rewriter.create<ConstantOp>(
68c957fe0fSJakub Kuderski       loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
6947232beaSJakub Kuderski   auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70c957fe0fSJakub Kuderski     return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
71c957fe0fSJakub Kuderski   };
72c957fe0fSJakub Kuderski 
73c957fe0fSJakub Kuderski   Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
74c957fe0fSJakub Kuderski                                             getScalarOrSplatAttr(argTy, 16));
7547232beaSJakub Kuderski   auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76c957fe0fSJakub Kuderski     return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
77c957fe0fSJakub Kuderski   };
78c957fe0fSJakub Kuderski 
791b822453SJakub Kuderski   auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
801b822453SJakub Kuderski     // We only need to shift arithmetically by 15, but the extra
811b822453SJakub Kuderski     // sign-extension bit will be truncated by the logical shift, so this is
821b822453SJakub Kuderski     // fine. We do not have to introduce an extra constant since any
831b822453SJakub Kuderski     // value in [15, 32) would do.
841b822453SJakub Kuderski     return getHighDigit(
851b822453SJakub Kuderski         rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
861b822453SJakub Kuderski   };
871b822453SJakub Kuderski 
8847232beaSJakub Kuderski   Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
8947232beaSJakub Kuderski                                            getScalarOrSplatAttr(argTy, 0));
90c957fe0fSJakub Kuderski 
9147232beaSJakub Kuderski   Value lhsLow = getLowDigit(lhs);
9247232beaSJakub Kuderski   Value lhsHigh = getHighDigit(lhs);
931b822453SJakub Kuderski   Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
9447232beaSJakub Kuderski   Value rhsLow = getLowDigit(rhs);
9547232beaSJakub Kuderski   Value rhsHigh = getHighDigit(rhs);
961b822453SJakub Kuderski   Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
9747232beaSJakub Kuderski 
981b822453SJakub Kuderski   std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
991b822453SJakub Kuderski   std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
10047232beaSJakub Kuderski   std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
10147232beaSJakub Kuderski 
10247232beaSJakub Kuderski   for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
10347232beaSJakub Kuderski     for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
1041b822453SJakub Kuderski       if (i + j >= resultDigits.size())
1051b822453SJakub Kuderski         continue;
1061b822453SJakub Kuderski 
1071b822453SJakub Kuderski       if (lhsDigit == cst0 || rhsDigit == cst0)
1081b822453SJakub Kuderski         continue;
1091b822453SJakub Kuderski 
11047232beaSJakub Kuderski       Value &thisResDigit = resultDigits[i + j];
11147232beaSJakub Kuderski       Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
11247232beaSJakub Kuderski       Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
11347232beaSJakub Kuderski       thisResDigit = getLowDigit(current);
11447232beaSJakub Kuderski 
11547232beaSJakub Kuderski       if (i + j + 1 != resultDigits.size()) {
11647232beaSJakub Kuderski         Value &nextResDigit = resultDigits[i + j + 1];
11747232beaSJakub Kuderski         Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
11847232beaSJakub Kuderski                                                     getHighDigit(current));
11947232beaSJakub Kuderski         nextResDigit = carry;
12047232beaSJakub Kuderski       }
12147232beaSJakub Kuderski     }
12247232beaSJakub Kuderski   }
12347232beaSJakub Kuderski 
12447232beaSJakub Kuderski   auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
12547232beaSJakub Kuderski     Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
12647232beaSJakub Kuderski     return rewriter.create<BitwiseOrOp>(loc, low, highBits);
12747232beaSJakub Kuderski   };
12847232beaSJakub Kuderski   Value low = combineDigits(resultDigits[0], resultDigits[1]);
12947232beaSJakub Kuderski   Value high = combineDigits(resultDigits[2], resultDigits[3]);
130c957fe0fSJakub Kuderski 
1311b822453SJakub Kuderski   return rewriter.create<CompositeConstructOp>(
13236117cc4Sserge-sans-paille       loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
1331b822453SJakub Kuderski }
1341b822453SJakub Kuderski 
1351b822453SJakub Kuderski //===----------------------------------------------------------------------===//
1361b822453SJakub Kuderski // Rewrite Patterns
1371b822453SJakub Kuderski //===----------------------------------------------------------------------===//
1381b822453SJakub Kuderski 
1391b822453SJakub Kuderski template <typename MulExtendedOp, bool SignExtendArguments>
1401b822453SJakub Kuderski struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
1411b822453SJakub Kuderski   using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
1421b822453SJakub Kuderski 
1431b822453SJakub Kuderski   LogicalResult matchAndRewrite(MulExtendedOp op,
1441b822453SJakub Kuderski                                 PatternRewriter &rewriter) const override {
1451b822453SJakub Kuderski     Location loc = op->getLoc();
1461b822453SJakub Kuderski     Value lhs = op.getOperand1();
1471b822453SJakub Kuderski     Value rhs = op.getOperand2();
1481b822453SJakub Kuderski 
1491b822453SJakub Kuderski     // Currently, WGSL only supports 32-bit integer types. Any other integer
1501b822453SJakub Kuderski     // types should already have been promoted/demoted to i32.
1515550c821STres Popp     auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
1521b822453SJakub Kuderski     if (elemTy.getIntOrFloatBitWidth() != 32)
1531b822453SJakub Kuderski       return rewriter.notifyMatchFailure(
1541b822453SJakub Kuderski           loc,
1551b822453SJakub Kuderski           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
1561b822453SJakub Kuderski 
1571b822453SJakub Kuderski     Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
1581b822453SJakub Kuderski                                             SignExtendArguments);
1591b822453SJakub Kuderski     rewriter.replaceOp(op, mul);
160c957fe0fSJakub Kuderski     return success();
161c957fe0fSJakub Kuderski   }
162c957fe0fSJakub Kuderski };
163c957fe0fSJakub Kuderski 
1641b822453SJakub Kuderski using ExpandSMulExtendedPattern =
1651b822453SJakub Kuderski     ExpandMulExtendedPattern<SMulExtendedOp, true>;
1661b822453SJakub Kuderski using ExpandUMulExtendedPattern =
1671b822453SJakub Kuderski     ExpandMulExtendedPattern<UMulExtendedOp, false>;
1681b822453SJakub Kuderski 
1696ddc03d9SFinn Plummer struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
1706ddc03d9SFinn Plummer   using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
1716ddc03d9SFinn Plummer 
1726ddc03d9SFinn Plummer   LogicalResult matchAndRewrite(IAddCarryOp op,
1736ddc03d9SFinn Plummer                                 PatternRewriter &rewriter) const override {
1746ddc03d9SFinn Plummer     Location loc = op->getLoc();
1756ddc03d9SFinn Plummer     Value lhs = op.getOperand1();
1766ddc03d9SFinn Plummer     Value rhs = op.getOperand2();
1776ddc03d9SFinn Plummer 
1786ddc03d9SFinn Plummer     // Currently, WGSL only supports 32-bit integer types. Any other integer
1796ddc03d9SFinn Plummer     // types should already have been promoted/demoted to i32.
1806ddc03d9SFinn Plummer     Type argTy = lhs.getType();
1816ddc03d9SFinn Plummer     auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
1826ddc03d9SFinn Plummer     if (elemTy.getIntOrFloatBitWidth() != 32)
1836ddc03d9SFinn Plummer       return rewriter.notifyMatchFailure(
1846ddc03d9SFinn Plummer           loc,
1856ddc03d9SFinn Plummer           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
1866ddc03d9SFinn Plummer 
1876ddc03d9SFinn Plummer     Value one =
1886ddc03d9SFinn Plummer         rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
1896ddc03d9SFinn Plummer     Value zero =
1906ddc03d9SFinn Plummer         rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
1916ddc03d9SFinn Plummer 
1926ddc03d9SFinn Plummer     // Calculate the carry by checking if the addition resulted in an overflow.
1936ddc03d9SFinn Plummer     Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
1946ddc03d9SFinn Plummer     Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
1956ddc03d9SFinn Plummer     Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
1966ddc03d9SFinn Plummer 
1976ddc03d9SFinn Plummer     Value add = rewriter.create<CompositeConstructOp>(
1986ddc03d9SFinn Plummer         loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
1996ddc03d9SFinn Plummer 
2006ddc03d9SFinn Plummer     rewriter.replaceOp(op, add);
2016ddc03d9SFinn Plummer     return success();
2026ddc03d9SFinn Plummer   }
2036ddc03d9SFinn Plummer };
2046ddc03d9SFinn Plummer 
205d61ec513SJakub Kuderski struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
206d61ec513SJakub Kuderski   using OpRewritePattern::OpRewritePattern;
207d61ec513SJakub Kuderski 
208d61ec513SJakub Kuderski   LogicalResult matchAndRewrite(IsInfOp op,
209d61ec513SJakub Kuderski                                 PatternRewriter &rewriter) const override {
210d61ec513SJakub Kuderski     // We assume values to be finite and turn `IsInf` info `false`.
211d61ec513SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
212d61ec513SJakub Kuderski         op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
213d61ec513SJakub Kuderski     return success();
214d61ec513SJakub Kuderski   }
215d61ec513SJakub Kuderski };
216d61ec513SJakub Kuderski 
217d61ec513SJakub Kuderski struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
218d61ec513SJakub Kuderski   using OpRewritePattern::OpRewritePattern;
219d61ec513SJakub Kuderski 
220d61ec513SJakub Kuderski   LogicalResult matchAndRewrite(IsNanOp op,
221d61ec513SJakub Kuderski                                 PatternRewriter &rewriter) const override {
222d61ec513SJakub Kuderski     // We assume values to be finite and turn `IsNan` info `false`.
223d61ec513SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
224d61ec513SJakub Kuderski         op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
225d61ec513SJakub Kuderski     return success();
226d61ec513SJakub Kuderski   }
227d61ec513SJakub Kuderski };
228d61ec513SJakub Kuderski 
229c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
230c957fe0fSJakub Kuderski // Passes
231c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
232d61ec513SJakub Kuderski struct WebGPUPreparePass final
233d61ec513SJakub Kuderski     : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
234c957fe0fSJakub Kuderski   void runOnOperation() override {
235c957fe0fSJakub Kuderski     RewritePatternSet patterns(&getContext());
236c957fe0fSJakub Kuderski     populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
237d61ec513SJakub Kuderski     populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
238c957fe0fSJakub Kuderski 
239*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
240c957fe0fSJakub Kuderski       signalPassFailure();
241c957fe0fSJakub Kuderski   }
242c957fe0fSJakub Kuderski };
243c957fe0fSJakub Kuderski } // namespace
244c957fe0fSJakub Kuderski 
245c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
246c957fe0fSJakub Kuderski // Public Interface
247c957fe0fSJakub Kuderski //===----------------------------------------------------------------------===//
248c957fe0fSJakub Kuderski void populateSPIRVExpandExtendedMultiplicationPatterns(
249c957fe0fSJakub Kuderski     RewritePatternSet &patterns) {
250c957fe0fSJakub Kuderski   // WGSL currently does not support extended multiplication ops, see:
251c957fe0fSJakub Kuderski   // https://github.com/gpuweb/gpuweb/issues/1565.
252d61ec513SJakub Kuderski   patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
253d61ec513SJakub Kuderski                ExpandAddCarryPattern>(patterns.getContext());
254c957fe0fSJakub Kuderski }
255d61ec513SJakub Kuderski 
256d61ec513SJakub Kuderski void populateSPIRVExpandNonFiniteArithmeticPatterns(
257d61ec513SJakub Kuderski     RewritePatternSet &patterns) {
258d61ec513SJakub Kuderski   // WGSL currently does not support `isInf` and `isNan`, see:
259d61ec513SJakub Kuderski   // https://github.com/gpuweb/gpuweb/pull/2311.
260d61ec513SJakub Kuderski   patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
261d61ec513SJakub Kuderski }
262d61ec513SJakub Kuderski 
263c957fe0fSJakub Kuderski } // namespace spirv
264c957fe0fSJakub Kuderski } // namespace mlir
265