xref: /llvm-project/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements SPIR-V transforms used when targeting WebGPU.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Location.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 #include <array>
26 #include <cstdint>
27 
28 namespace mlir {
29 namespace spirv {
30 #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32 } // namespace spirv
33 } // namespace mlir
34 
35 namespace mlir {
36 namespace spirv {
37 namespace {
38 //===----------------------------------------------------------------------===//
39 // Helpers
40 //===----------------------------------------------------------------------===//
41 static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42   APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
43   if (auto intTy = dyn_cast<IntegerType>(type))
44     return IntegerAttr::get(intTy, sizedValue);
45 
46   return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
47 }
48 
49 static Value lowerExtendedMultiplication(Operation *mulOp,
50                                          PatternRewriter &rewriter, Value lhs,
51                                          Value rhs, bool signExtendArguments) {
52   Location loc = mulOp->getLoc();
53   Type argTy = lhs.getType();
54   // Emulate 64-bit multiplication by splitting each input element of type i32
55   // into 2 16-bit digits of type i32. This is so that the intermediate
56   // multiplications and additions do not overflow. We extract these 16-bit
57   // digits from i32 vector elements by masking (low digit) and shifting right
58   // (high digit).
59   //
60   // The multiplication algorithm used is the standard (long) multiplication.
61   // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
62   // digits.
63   //   - With zero-extended arguments, we end up emitting only 4 multiplications
64   //     and 4 additions after constant folding.
65   //   - With sign-extended arguments, we end up emitting 8 multiplications and
66   //     and 12 additions after CSE.
67   Value cstLowMask = rewriter.create<ConstantOp>(
68       loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69   auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70     return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
71   };
72 
73   Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
74                                             getScalarOrSplatAttr(argTy, 16));
75   auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76     return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
77   };
78 
79   auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
80     // We only need to shift arithmetically by 15, but the extra
81     // sign-extension bit will be truncated by the logical shift, so this is
82     // fine. We do not have to introduce an extra constant since any
83     // value in [15, 32) would do.
84     return getHighDigit(
85         rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
86   };
87 
88   Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
89                                            getScalarOrSplatAttr(argTy, 0));
90 
91   Value lhsLow = getLowDigit(lhs);
92   Value lhsHigh = getHighDigit(lhs);
93   Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
94   Value rhsLow = getLowDigit(rhs);
95   Value rhsHigh = getHighDigit(rhs);
96   Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
97 
98   std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99   std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100   std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
101 
102   for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103     for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
104       if (i + j >= resultDigits.size())
105         continue;
106 
107       if (lhsDigit == cst0 || rhsDigit == cst0)
108         continue;
109 
110       Value &thisResDigit = resultDigits[i + j];
111       Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
112       Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
113       thisResDigit = getLowDigit(current);
114 
115       if (i + j + 1 != resultDigits.size()) {
116         Value &nextResDigit = resultDigits[i + j + 1];
117         Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118                                                     getHighDigit(current));
119         nextResDigit = carry;
120       }
121     }
122   }
123 
124   auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125     Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
126     return rewriter.create<BitwiseOrOp>(loc, low, highBits);
127   };
128   Value low = combineDigits(resultDigits[0], resultDigits[1]);
129   Value high = combineDigits(resultDigits[2], resultDigits[3]);
130 
131   return rewriter.create<CompositeConstructOp>(
132       loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // Rewrite Patterns
137 //===----------------------------------------------------------------------===//
138 
139 template <typename MulExtendedOp, bool SignExtendArguments>
140 struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
141   using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
142 
143   LogicalResult matchAndRewrite(MulExtendedOp op,
144                                 PatternRewriter &rewriter) const override {
145     Location loc = op->getLoc();
146     Value lhs = op.getOperand1();
147     Value rhs = op.getOperand2();
148 
149     // Currently, WGSL only supports 32-bit integer types. Any other integer
150     // types should already have been promoted/demoted to i32.
151     auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
152     if (elemTy.getIntOrFloatBitWidth() != 32)
153       return rewriter.notifyMatchFailure(
154           loc,
155           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
156 
157     Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
158                                             SignExtendArguments);
159     rewriter.replaceOp(op, mul);
160     return success();
161   }
162 };
163 
164 using ExpandSMulExtendedPattern =
165     ExpandMulExtendedPattern<SMulExtendedOp, true>;
166 using ExpandUMulExtendedPattern =
167     ExpandMulExtendedPattern<UMulExtendedOp, false>;
168 
169 struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
170   using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
171 
172   LogicalResult matchAndRewrite(IAddCarryOp op,
173                                 PatternRewriter &rewriter) const override {
174     Location loc = op->getLoc();
175     Value lhs = op.getOperand1();
176     Value rhs = op.getOperand2();
177 
178     // Currently, WGSL only supports 32-bit integer types. Any other integer
179     // types should already have been promoted/demoted to i32.
180     Type argTy = lhs.getType();
181     auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
182     if (elemTy.getIntOrFloatBitWidth() != 32)
183       return rewriter.notifyMatchFailure(
184           loc,
185           llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
186 
187     Value one =
188         rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
189     Value zero =
190         rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
191 
192     // Calculate the carry by checking if the addition resulted in an overflow.
193     Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
194     Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
195     Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
196 
197     Value add = rewriter.create<CompositeConstructOp>(
198         loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
199 
200     rewriter.replaceOp(op, add);
201     return success();
202   }
203 };
204 
205 struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
206   using OpRewritePattern::OpRewritePattern;
207 
208   LogicalResult matchAndRewrite(IsInfOp op,
209                                 PatternRewriter &rewriter) const override {
210     // We assume values to be finite and turn `IsInf` info `false`.
211     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
212         op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
213     return success();
214   }
215 };
216 
217 struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
218   using OpRewritePattern::OpRewritePattern;
219 
220   LogicalResult matchAndRewrite(IsNanOp op,
221                                 PatternRewriter &rewriter) const override {
222     // We assume values to be finite and turn `IsNan` info `false`.
223     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
224         op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
225     return success();
226   }
227 };
228 
229 //===----------------------------------------------------------------------===//
230 // Passes
231 //===----------------------------------------------------------------------===//
232 struct WebGPUPreparePass final
233     : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
234   void runOnOperation() override {
235     RewritePatternSet patterns(&getContext());
236     populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
237     populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
238 
239     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
240       signalPassFailure();
241   }
242 };
243 } // namespace
244 
245 //===----------------------------------------------------------------------===//
246 // Public Interface
247 //===----------------------------------------------------------------------===//
248 void populateSPIRVExpandExtendedMultiplicationPatterns(
249     RewritePatternSet &patterns) {
250   // WGSL currently does not support extended multiplication ops, see:
251   // https://github.com/gpuweb/gpuweb/issues/1565.
252   patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
253                ExpandAddCarryPattern>(patterns.getContext());
254 }
255 
256 void populateSPIRVExpandNonFiniteArithmeticPatterns(
257     RewritePatternSet &patterns) {
258   // WGSL currently does not support `isInf` and `isNan`, see:
259   // https://github.com/gpuweb/gpuweb/pull/2311.
260   patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
261 }
262 
263 } // namespace spirv
264 } // namespace mlir
265