//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU subgroup MMA ops to
// SPIR-V cooperative matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup MMA elementwise op
/// when the elementwise op is directly supported on cooperative matrix types.
/// Returns false if it is not.
///
/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
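///
/// For example, an `addf` elementwise op on cooperative matrix operands maps
/// to `spirv.FAdd` (types below are illustrative):
///
///   %sum = spirv.FAdd %a, %b :
///            !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>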
static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                ValueRange operands) {
  assert(isa<spirv::CooperativeMatrixType>(coopType));

  switch (op.getOpType()) {
  case gpu::MMAElementwiseOp::ADDF:
    builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::ADDI:
    builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBF:
    builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBI:
    builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVF:
    builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVS:
    builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVU:
    builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATEF:
    builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATES:
    builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::EXTF:
    builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
    return true;
  default:
    break;
  }
  return false;
}

/// Returns true if all operands share the same cooperative matrix type.
static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
  assert(!operands.empty());
  if (!llvm::all_equal(
          llvm::map_range(operands, [](Value v) { return v.getType(); })))
    return false;

  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
}

namespace {
/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix
/// ops.
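///
/// For example (illustrative types):
///
///   %m = gpu.subgroup_mma_constant_matrix %cst :
///          !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes
///
///   %m = spirv.CompositeConstruct %cst :
///          (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>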
struct WmmaConstantOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(adaptor.getOperands().size() == 1);
    Value cst = adaptor.getOperands().front();
    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
    return success();
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    return success(
        createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the matrix-times-scalar case.
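///
/// For example, a `mulf` of a matrix by a constant splat lowers to
/// `spirv.MatrixTimesScalar` (illustrative types):
///
///   %r = spirv.MatrixTimesScalar %matrix, %scalar :
///          !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16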
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getOperands().size() != 2)
      return failure();

    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
      return failure();

    // Use the original operands to check whether one of the operands is a
    // splat scalar value.
    Value lhs = op.getOperands().front();
    Value rhs = op.getOperands().back();
    Value splat = nullptr;
    Value matrix = nullptr;
    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      splat = adaptor.getOperands().front();
      matrix = adaptor.getOperands().back();
    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      matrix = adaptor.getOperands().front();
      splat = adaptor.getOperands().back();
    }
    if (!splat || !matrix)
      return rewriter.notifyMatchFailure(op, "no splat operand");

    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
    if (!cc) {
      return rewriter.notifyMatchFailure(op,
                                         "splat is not a composite construct");
    }

    assert(cc.getConstituents().size() == 1);
    Value scalar = cc.getConstituents().front();

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
        op, coopType, ValueRange{matrix, scalar});
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

/// Converts the GPU MMA load op to `spirv.KHR.CooperativeMatrixLoad` in the
/// SPIR-V dialect.
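///
/// For example (illustrative types):
///
///   %m = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
///          !spirv.ptr<f16, StorageBuffer>, i32
///          -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>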
struct WmmaLoadOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
    MemRefType memrefType = op.getSrcMemref().getType();
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), loc, rewriter);

    auto coopType =
        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
        op, coopType, bufferPtr, strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA store op to `spirv.KHR.CooperativeMatrixStore` in the
/// SPIR-V dialect.
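///
/// For example (illustrative types):
///
///   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
///     !spirv.ptr<f16, StorageBuffer>,
///     !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, i32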
struct WmmaStoreOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
                             adaptor.getIndices(), loc, rewriter);

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA compute op to `spirv.KHR.CooperativeMatrixMulAdd` in
/// the SPIR-V dialect.
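///
/// For example (illustrative types):
///
///   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
///          !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
///          !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
///          -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>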
struct WmmaMmaOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
        subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
        adaptor.getOpC());
    return success();
  }
};

} // namespace
} // namespace khr
} // namespace mlir

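// A minimal usage sketch for the two populate functions below (the
// surrounding pass setup shown here is illustrative, not part of this file):
//
//   SPIRVTypeConverter typeConverter(targetEnvAttr);
//   populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter);
//   RewritePatternSet patterns(context);
//   populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
//                                                         patterns);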
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
    const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
  using namespace mlir;
  MLIRContext *context = patterns.getContext();
  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
  // Give this pattern a higher benefit so it prevails over the default
  // elementwise lowering.
  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
                                                          /*benefit=*/2);
}

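// Maps `!gpu.mma_matrix` types to `!spirv.coopmatrix` types, e.g.
// (illustrative):
//
//   !gpu.mma_matrix<16x16xf16, "AOp">
//     -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>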
void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
    mlir::SPIRVTypeConverter &typeConverter) {
  typeConverter.addConversion([](gpu::MMAMatrixType type) {
    ArrayRef<int64_t> retTypeShape = type.getShape();
    Type elementType = type.getElementType();
    auto use =
        llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
            .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
            .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
            .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);

    return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
                                             retTypeShape[1],
                                             spirv::Scope::Subgroup, use);
  });
}