//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// NVVM Dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

namespace {

/// Checks if all the operands of the op being lowered are of LLVM Types. The
/// types are expected to be converted by the `LLVMTypeConverter` before the op
/// is actually lowered. If the type of an operand is not already converted, it
/// hints at a missing type conversion and failure is returned in that case.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      })) {
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  }

  return success();
}

/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName == "AOp")
    return NVVM::MMAFrag::a;
  if (operandName == "BOp")
    return NVVM::MMAFrag::b;
  if (operandName == "COp")
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}

static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                      : NVVM::MMATypes::tf32;

  if (type.getElementType().isSignedInteger(8))
    return NVVM::MMATypes::s8;
  if (type.getElementType().isUnsignedInteger(8))
    return NVVM::MMATypes::u8;
  // Accumulator type is signless and implies signed.
  if (type.getElementType().isInteger(32))
    return NVVM::MMATypes::s32;
  llvm_unreachable("Unsupported type");
}

/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. Besides the NVVM op itself, the conversion emits the
/// address computation needed to locate the element being loaded within the
/// source memref.
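///
/// A sketch of the rewrite, with shapes and attribute syntax kept
/// illustrative rather than exact:
///
///   %a = gpu.subgroup_mma_load_matrix %src[%i, %j]
///       {leadDimension = 32 : index}
///       : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// becomes, roughly, a strided element pointer computation followed by:
///
///   %a = nvvm.wmma.load %ptr, %ldm
///       {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>,
///        k = 16 : i32, layout = #nvvm.mma_layout<row>,
///        m = 16 : i32, n = 16 : i32}
///       : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, ...)>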
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape
    // determines which intrinsic this op will be lowered to.
    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics require the full m x n x k dimensions; infer the
    // missing dimension based on the valid intrinsics available.
    if (retType.getOperand() == "AOp") {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand() == "BOp") {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand() == "COp") {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that an intrinsic exists for the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create the nvvm.wmma.load op according to the operand types.
    Value dataPtr = getStridedElementPtr(
        loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);

    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};

/// This class implements the conversion of GPU MMA storeOp to wmma.store op
/// in the NVVM dialect. Besides the NVVM op itself, the conversion emits the
/// code necessary to unpack the source matrix struct into the individual
/// values that the NVVM op expects.
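///
/// A sketch of the rewrite (shapes and attribute syntax illustrative):
///
///   gpu.subgroup_mma_store_matrix %c, %dst[%i, %j]
///       {leadDimension = 32 : index}
///       : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
///
/// becomes, roughly, llvm.extractvalue ops that unpack the accumulator
/// struct into scalars/vectors, followed by:
///
///   nvvm.wmma.store %ptr, %ldm, %v0, %v1, %v2, %v3
///       {eltype = #nvvm.mma_type<f16>, k = 16 : i32,
///        layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}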
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // Get the shape of the MMAMatrix type being stored. The shape
    // determines which intrinsic this op will be lowered to.
    gpu::MMAMatrixType srcType =
        cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse =
          rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
      storeOpOperands.push_back(toUse);
    }

    Value dataPtr = getStridedElementPtr(
        loc,
        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
        adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};

/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
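///
/// A sketch of the rewrite (shapes and attribute syntax illustrative):
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///       : !gpu.mma_matrix<16x16xf16, "AOp">,
///         !gpu.mma_matrix<16x16xf16, "BOp">
///       -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes, roughly, llvm.extractvalue ops that unpack all three operand
/// structs, followed by:
///
///   %d = nvvm.wmma.mma %a0, ..., %b0, ..., %c0, ...
///       {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, ...}
///       : (...) -> !llvm.struct<(vector<2xf16>, ...)>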
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in LLVM requires its operands as individual
    // values, so the individual elements need to be extracted from the
    // lowered struct operands and then passed on to the intrinsic call. Emit
    // LLVM ops to extract the individual values from the lowered operands.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix types being used. The shapes
    // determine which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};

/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
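///
/// For example (types illustrative), lowering
///
///   %m = gpu.subgroup_mma_constant_matrix %cst
///       : !gpu.mma_matrix<16x16xf16, "COp">
///
/// produces an llvm.mlir.undef struct into which the (possibly vectorized)
/// splat value is inserted once per struct field via llvm.insertvalue.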
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
    // If the element type is a vector, create a vector from the operand.
    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
      Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, rewriter.getI32Type(), vecEl);
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};

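/// Creates an LLVM float min/max as a compare-and-select, followed by a
/// second, unordered compare-and-select so that a NaN in either operand
/// yields a quiet NaN result (e.g. min(1.0, NaN) == NaN) rather than
/// whichever operand the first select happened to pick.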
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

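/// Creates the scalar LLVM op corresponding to the elementwise MMA op `op`,
/// applied to `operands` (scalars or vectors, depending on the converted
/// matrix element type).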
static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  default:
    llvm_unreachable("unknown op");
  }
}

/// Convert GPU MMA elementwise ops to extract + op + insert.
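///
/// For example (types illustrative), lowering
///
///   %r = gpu.subgroup_mma_elementwise addf %a, %b
///       : (!gpu.mma_matrix<16x16xf16, "COp">,
///          !gpu.mma_matrix<16x16xf16, "COp">)
///       -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// produces, per struct field, an llvm.extractvalue on each operand, the
/// scalar op (llvm.fadd here), and an llvm.insertvalue into the result.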
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, adaptor.getOperands()[opIdx], i));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
                         extractedOperands);
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

} // namespace

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
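/// For example, a 16x16 f16 "COp" matrix maps (via NVVM::inferMMAType) to a
/// literal struct of four vector<2xf16> fields; the exact field count and
/// element type depend on the fragment kind and element type.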
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  auto nRow = type.getShape()[0];
  auto nCol = type.getShape()[1];
  std::pair<Type, unsigned> typeInfo =
      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

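/// Collects the WMMA lowering patterns defined above. A typical pipeline
/// (a sketch, not the only valid setup) registers these alongside the other
/// GPU-to-NVVM patterns and then runs a partial conversion targeting the
/// NVVM dialect.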
void mlir::populateGpuWMMAToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
               WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
               WmmaElementwiseOpToNVVMLowering>(converter);
}