xref: /llvm-project/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
114d79afeSManish Gupta //===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
214d79afeSManish Gupta //
314d79afeSManish Gupta // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
414d79afeSManish Gupta // See https://llvm.org/LICENSE.txt for license information.
514d79afeSManish Gupta // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
614d79afeSManish Gupta //
714d79afeSManish Gupta //===----------------------------------------------------------------------===//
814d79afeSManish Gupta 
914d79afeSManish Gupta #include <type_traits>
1014d79afeSManish Gupta 
1114d79afeSManish Gupta #include "mlir/Analysis/SliceAnalysis.h"
1214d79afeSManish Gupta #include "mlir/Dialect/Affine/IR/AffineOps.h"
1314d79afeSManish Gupta #include "mlir/Dialect/Func/IR/FuncOps.h"
1414d79afeSManish Gupta #include "mlir/Dialect/GPU/IR/GPUDialect.h"
1514d79afeSManish Gupta #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1614d79afeSManish Gupta #include "mlir/Dialect/Linalg/IR/Linalg.h"
1714d79afeSManish Gupta #include "mlir/Dialect/Linalg/Passes.h"
1814d79afeSManish Gupta #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1914d79afeSManish Gupta #include "mlir/Dialect/MemRef/IR/MemRef.h"
2014d79afeSManish Gupta #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
2114d79afeSManish Gupta #include "mlir/Dialect/SCF/IR/SCF.h"
2214d79afeSManish Gupta #include "mlir/Pass/Pass.h"
2314d79afeSManish Gupta #include "mlir/Pass/PassManager.h"
2414d79afeSManish Gupta #include "mlir/Support/LLVM.h"
2514d79afeSManish Gupta #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2614d79afeSManish Gupta 
2714d79afeSManish Gupta using namespace mlir;
2814d79afeSManish Gupta using namespace mlir::nvgpu;
2914d79afeSManish Gupta 
3014d79afeSManish Gupta namespace {
3114d79afeSManish Gupta 
3214d79afeSManish Gupta struct TestMmaSyncF32ToTF32Patterns
3314d79afeSManish Gupta     : public PassWrapper<TestMmaSyncF32ToTF32Patterns,
3414d79afeSManish Gupta                          OperationPass<func::FuncOp>> {
3514d79afeSManish Gupta   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
3614d79afeSManish Gupta 
3714d79afeSManish Gupta   StringRef getArgument() const final {
3814d79afeSManish Gupta     return "test-nvgpu-mmasync-f32-to-tf32-patterns";
3914d79afeSManish Gupta   }
4014d79afeSManish Gupta   StringRef getDescription() const final {
4114d79afeSManish Gupta     return "Test patterns to convert mma.sync on f32 with tf32 precision";
4214d79afeSManish Gupta   }
4314d79afeSManish Gupta   TestMmaSyncF32ToTF32Patterns() = default;
4414d79afeSManish Gupta   TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
4514d79afeSManish Gupta       : PassWrapper(pass) {}
4614d79afeSManish Gupta 
4714d79afeSManish Gupta   Option<std::string> precision{
4814d79afeSManish Gupta       *this, "precision",
4914d79afeSManish Gupta       llvm::cl::desc(
5014d79afeSManish Gupta           "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
5114d79afeSManish Gupta       llvm::cl::init("tf32")};
5214d79afeSManish Gupta 
5314d79afeSManish Gupta   MmaSyncF32Lowering tf32Precision =
5414d79afeSManish Gupta       llvm::StringSwitch<MmaSyncF32Lowering>(precision)
5514d79afeSManish Gupta           .Case("tf32", MmaSyncF32Lowering::TF32)
5614d79afeSManish Gupta           .Case("tf32x3", MmaSyncF32Lowering::TF32x3)
5714d79afeSManish Gupta           .Default(MmaSyncF32Lowering::Unkown);
5814d79afeSManish Gupta 
5914d79afeSManish Gupta   void runOnOperation() override {
6014d79afeSManish Gupta     RewritePatternSet patterns(&getContext());
6114d79afeSManish Gupta 
6214d79afeSManish Gupta     populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
63*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
6414d79afeSManish Gupta   }
6514d79afeSManish Gupta };
6614d79afeSManish Gupta 
6714d79afeSManish Gupta } // namespace
6814d79afeSManish Gupta 
6914d79afeSManish Gupta namespace mlir {
7014d79afeSManish Gupta namespace test {
71baa5beecStyb0807 void registerTestNVGPULowerings() {
7214d79afeSManish Gupta   PassRegistration<TestMmaSyncF32ToTF32Patterns>();
7314d79afeSManish Gupta }
7414d79afeSManish Gupta 
7514d79afeSManish Gupta } // namespace test
7614d79afeSManish Gupta } // namespace mlir
77