xref: /llvm-project/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 using namespace mlir;
28 using namespace mlir::nvgpu;
29 
30 namespace {
31 
32 struct TestMmaSyncF32ToTF32Patterns
33     : public PassWrapper<TestMmaSyncF32ToTF32Patterns,
34                          OperationPass<func::FuncOp>> {
35   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
36 
37   StringRef getArgument() const final {
38     return "test-nvgpu-mmasync-f32-to-tf32-patterns";
39   }
40   StringRef getDescription() const final {
41     return "Test patterns to convert mma.sync on f32 with tf32 precision";
42   }
43   TestMmaSyncF32ToTF32Patterns() = default;
44   TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
45       : PassWrapper(pass) {}
46 
47   Option<std::string> precision{
48       *this, "precision",
49       llvm::cl::desc(
50           "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
51       llvm::cl::init("tf32")};
52 
53   MmaSyncF32Lowering tf32Precision =
54       llvm::StringSwitch<MmaSyncF32Lowering>(precision)
55           .Case("tf32", MmaSyncF32Lowering::TF32)
56           .Case("tf32x3", MmaSyncF32Lowering::TF32x3)
57           .Default(MmaSyncF32Lowering::Unkown);
58 
59   void runOnOperation() override {
60     RewritePatternSet patterns(&getContext());
61 
62     populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
63     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
64   }
65 };
66 
67 } // namespace
68 
69 namespace mlir {
70 namespace test {
71 void registerTestNVGPULowerings() {
72   PassRegistration<TestMmaSyncF32ToTF32Patterns>();
73 }
74 
75 } // namespace test
76 } // namespace mlir
77