1 //===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma.sync
// operations on the f32 input datatype.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
15
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/MathExtras.h"
24
25 using namespace mlir;
26 using namespace mlir::nvgpu;
27
28 namespace {
29
30 struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
31
32 using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
33
MmaSyncF32ToTF32Pattern__anondff23b310111::MmaSyncF32ToTF32Pattern34 MmaSyncF32ToTF32Pattern(MLIRContext *context,
35 nvgpu::MmaSyncF32Lowering precision)
36 : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
37 precision(precision) {}
38
matchAndRewrite__anondff23b310111::MmaSyncF32ToTF32Pattern39 LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
40 PatternRewriter &rewriter) const override {
41 Location location = op->getLoc();
42
43 if (op->hasAttr(op.getTf32EnabledAttrName()) ||
44 !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
45 return failure();
46
47 if (precision == MmaSyncF32Lowering::Unkown)
48 return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
49 "unknown precision level");
50
51 if (precision == MmaSyncF32Lowering::TF32x3)
52 return emitError(location, "TF32x3 is not supported at the moment "
53 "for nvgpu.mma.sync on f32 datatype");
54
55 if (precision == MmaSyncF32Lowering::TF32) {
56 rewriter.modifyOpInPlace(
57 op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
58 }
59
60 return success();
61 }
62
63 private:
64 /// Precision for F32 Tensor Cores (TF32 or TF32x3)
65 nvgpu::MmaSyncF32Lowering precision;
66 };
67
68 } // namespace
69
populateMmaSyncF32ToTF32Patterns(RewritePatternSet & patterns,nvgpu::MmaSyncF32Lowering precision)70 void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
71 RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
72
73 patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
74 }
75