//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements SPIR-V transforms used when targeting WebGPU. WGSL does
// not support some SPIR-V ops directly, so the patterns below expand:
//   - SMulExtendedOp / UMulExtendedOp into 32-bit arithmetic on 16-bit
//     "digits" (long multiplication),
//   - IAddCarryOp into an IAddOp plus an unsigned-overflow comparison,
//   - IsInfOp / IsNanOp into a constant `false` (values are assumed finite).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include <array>
#include <cstdint>

namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir

namespace mlir {
namespace spirv {
namespace {
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

/// Returns a constant attribute holding `value` for `type`.
///
/// `value` is converted to an APInt whose width is the bit width of `type`'s
/// element type (or of `type` itself when it is a scalar). If `type` is an
/// IntegerType, the result is an IntegerAttr. Otherwise `type` must be a
/// ShapedType (the `cast` asserts this) and the result is a splat elements
/// attribute. Callers in this file pass the 32-bit operand types of the
/// extended arithmetic ops and the result types of IsInfOp/IsNanOp.
static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
  APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
  if (auto intTy = dyn_cast<IntegerType>(type))
    return IntegerAttr::get(intTy, sizedValue);

  return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}

/// Expands a full-width multiplication of `lhs` by `rhs` into 32-bit
/// arithmetic.
///
/// Returns a CompositeConstructOp of `mulOp`'s first result type that holds
/// {low 32 bits, high 32 bits} of the 64-bit product. All new ops are created
/// at `mulOp`'s location. `mulOp` itself is neither modified nor replaced here;
/// the caller does that.
///
/// When `signExtendArguments` is true, the operands are treated as signed
/// (sign-extended to 64 bits). Otherwise they are zero-extended. The operands
/// are expected to have a 32-bit integer element type; the caller in this
/// file checks this before calling.
static Value lowerExtendedMultiplication(Operation *mulOp,
                                         PatternRewriter &rewriter, Value lhs,
                                         Value rhs, bool signExtendArguments) {
  Location loc = mulOp->getLoc();
  Type argTy = lhs.getType();
  // Emulate 64-bit multiplication by splitting each input element of type i32
  // into 2 16-bit digits of type i32. This is so that the intermediate
  // multiplications and additions do not overflow. We extract these 16-bit
  // digits from i32 vector elements by masking (low digit) and shifting right
  // (high digit).
  //
  // The multiplication algorithm used is the standard (long) multiplication.
  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
  // digits.
  // - With zero-extended arguments, we end up emitting only 4 multiplications
  //   and 4 additions after constant folding.
  // - With sign-extended arguments, we end up emitting 8 multiplications and
  //   12 additions after CSE.
  Value cstLowMask = rewriter.create<ConstantOp>(
      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
  // Low digit: the bottom 16 bits (val & 0xFFFF).
  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
  };

  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                            getScalarOrSplatAttr(argTy, 16));
  // High digit: the top 16 bits, shifted down logically (val >> 16).
  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
  };

  // Sign digit: 0xFFFF for negative inputs, 0 otherwise. It is used as both
  // upper digits of the sign-extended 64-bit value.
  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
    // We only need to shift arithmetically by 15, but the extra
    // sign-extension bit will be truncated by the logical shift, so this is
    // fine. We do not have to introduce an extra constant since any
    // value in [15, 32) would do.
    return getHighDigit(
        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
  };

  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                           getScalarOrSplatAttr(argTy, 0));

  Value lhsLow = getLowDigit(lhs);
  Value lhsHigh = getHighDigit(lhs);
  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
  Value rhsLow = getLowDigit(rhs);
  Value rhsHigh = getHighDigit(rhs);
  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;

  // Digits are ordered from least to most significant: the two real digits
  // followed by two extension digits (sign or zero).
  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};

  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
    for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
      // Partial products at digit position 4 or higher only affect bits >= 64,
      // which are not part of the result.
      if (i + j >= resultDigits.size())
        continue;

      // This compares Values by identity, not numerically: it skips only the
      // zero-extension digits, which are literally `cst0` when
      // `signExtendArguments` is false.
      if (lhsDigit == cst0 || rhsDigit == cst0)
        continue;

      // Add the partial product into digit i + j and keep the low 16 bits
      // there. The high 16 bits become a carry into digit i + j + 1, unless
      // i + j is already the most significant digit.
      Value &thisResDigit = resultDigits[i + j];
      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
      Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
      thisResDigit = getLowDigit(current);

      if (i + j + 1 != resultDigits.size()) {
        Value &nextResDigit = resultDigits[i + j + 1];
        Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
                                                    getHighDigit(current));
        nextResDigit = carry;
      }
    }
  }

  // Reassemble two 16-bit digits into one 32-bit word: low | (high << 16).
  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
  };
  Value low = combineDigits(resultDigits[0], resultDigits[1]);
  Value high = combineDigits(resultDigits[2], resultDigits[3]);

  return rewriter.create<CompositeConstructOp>(
      loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
}

//===----------------------------------------------------------------------===//
// Rewrite Patterns
//===----------------------------------------------------------------------===//

/// Replaces `MulExtendedOp` (SMulExtendedOp or UMulExtendedOp) with its 32-bit
/// expansion from lowerExtendedMultiplication. `SignExtendArguments` selects
/// signed (true) or unsigned (false) treatment of the operands. The pattern
/// reports a match failure when the operand element type is not 32 bits wide.
template <typename MulExtendedOp, bool SignExtendArguments>
struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
  using OpRewritePattern<MulExtendedOp>::OpRewritePattern;

  LogicalResult
  matchAndRewrite(MulExtendedOp op,
                  PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();

    // Currently, WGSL only supports 32-bit integer types. Any other integer
    // types should already have been promoted/demoted to i32.
    auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
    if (elemTy.getIntOrFloatBitWidth() != 32)
      return rewriter.notifyMatchFailure(
          loc,
          llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));

    Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
                                            SignExtendArguments);
    rewriter.replaceOp(op, mul);
    return success();
  }
};

using ExpandSMulExtendedPattern =
    ExpandMulExtendedPattern<SMulExtendedOp, true>;
using ExpandUMulExtendedPattern =
    ExpandMulExtendedPattern<UMulExtendedOp, false>;

/// Replaces IAddCarryOp with a composite {sum, carry}.
///
/// `sum` is the wrapping IAddOp of the operands. `carry` is 1 when `sum` is
/// unsigned-less-than `lhs` (the addition wrapped around), and 0 otherwise.
/// The pattern reports a match failure when the operand element type is not
/// 32 bits wide.
struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
  using OpRewritePattern<IAddCarryOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IAddCarryOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();

    // Currently, WGSL only supports 32-bit integer types. Any other integer
    // types should already have been promoted/demoted to i32.
    Type argTy = lhs.getType();
    auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
    if (elemTy.getIntOrFloatBitWidth() != 32)
      return rewriter.notifyMatchFailure(
          loc,
          llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));

    Value one =
        rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
    Value zero =
        rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));

    // Calculate the carry by checking if the addition resulted in an overflow.
    // An unsigned sum that wrapped around is smaller than either operand.
    Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
    Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
    Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);

    Value add = rewriter.create<CompositeConstructOp>(
        loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));

    rewriter.replaceOp(op, add);
    return success();
  }
};

/// Replaces IsInfOp with a constant zero (`false`) of the op's result type.
struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IsInfOp op,
                                PatternRewriter &rewriter) const override {
    // We assume values to be finite and turn `IsInf` into `false`.
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
    return success();
  }
};

/// Replaces IsNanOp with a constant zero (`false`) of the op's result type.
struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IsNanOp op,
                                PatternRewriter &rewriter) const override {
    // We assume values to be finite and turn `IsNan` into `false`.
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//

/// Greedily applies all of the expansion patterns in this file (extended
/// multiplication, add-with-carry, IsInf/IsNan) to the pass's root operation.
/// The pass is marked as failed if the greedy driver reports failure.
struct WebGPUPreparePass final
    : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
    populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Public Interface
//===----------------------------------------------------------------------===//

/// Adds the patterns that expand SMulExtendedOp, UMulExtendedOp and
/// IAddCarryOp into plain 32-bit arithmetic.
void populateSPIRVExpandExtendedMultiplicationPatterns(
    RewritePatternSet &patterns) {
  // WGSL currently does not support extended multiplication ops, see:
  // https://github.com/gpuweb/gpuweb/issues/1565.
  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
               ExpandAddCarryPattern>(patterns.getContext());
}

/// Adds the patterns that replace IsInfOp and IsNanOp with constant `false`.
/// This is only sound if the program never produces Inf or NaN values.
void populateSPIRVExpandNonFiniteArithmeticPatterns(
    RewritePatternSet &patterns) {
  // WGSL currently does not support `isInf` and `isNan`, see:
  // https://github.com/gpuweb/gpuweb/pull/2311.
  patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
}

} // namespace spirv
} // namespace mlir