xref: /llvm-project/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma.sync
// operations on the f32 input datatype.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/MathExtras.h"
24 
25 using namespace mlir;
26 using namespace mlir::nvgpu;
27 
28 namespace {
29 
30 struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
31 
32   using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
33 
MmaSyncF32ToTF32Pattern__anondff23b310111::MmaSyncF32ToTF32Pattern34   MmaSyncF32ToTF32Pattern(MLIRContext *context,
35                           nvgpu::MmaSyncF32Lowering precision)
36       : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
37         precision(precision) {}
38 
matchAndRewrite__anondff23b310111::MmaSyncF32ToTF32Pattern39   LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
40                                 PatternRewriter &rewriter) const override {
41     Location location = op->getLoc();
42 
43     if (op->hasAttr(op.getTf32EnabledAttrName()) ||
44         !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
45       return failure();
46 
47     if (precision == MmaSyncF32Lowering::Unkown)
48       return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
49                                  "unknown precision level");
50 
51     if (precision == MmaSyncF32Lowering::TF32x3)
52       return emitError(location, "TF32x3 is not supported at the moment "
53                                  "for nvgpu.mma.sync on f32 datatype");
54 
55     if (precision == MmaSyncF32Lowering::TF32) {
56       rewriter.modifyOpInPlace(
57           op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
58     }
59 
60     return success();
61   }
62 
63 private:
64   /// Precision for F32 Tensor Cores (TF32 or TF32x3)
65   nvgpu::MmaSyncF32Lowering precision;
66 };
67 
68 } // namespace
69 
populateMmaSyncF32ToTF32Patterns(RewritePatternSet & patterns,nvgpu::MmaSyncF32Lowering precision)70 void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
71     RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
72 
73   patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
74 }
75