1 //===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <type_traits> 10 11 #include "mlir/Analysis/SliceAnalysis.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" 21 #include "mlir/Dialect/SCF/IR/SCF.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Pass/PassManager.h" 24 #include "mlir/Support/LLVM.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 using namespace mlir; 28 using namespace mlir::nvgpu; 29 30 namespace { 31 32 struct TestMmaSyncF32ToTF32Patterns 33 : public PassWrapper<TestMmaSyncF32ToTF32Patterns, 34 OperationPass<func::FuncOp>> { 35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns) 36 37 StringRef getArgument() const final { 38 return "test-nvgpu-mmasync-f32-to-tf32-patterns"; 39 } 40 StringRef getDescription() const final { 41 return "Test patterns to convert mma.sync on f32 with tf32 precision"; 42 } 43 TestMmaSyncF32ToTF32Patterns() = default; 44 TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass) 45 : PassWrapper(pass) {} 46 47 Option<std::string> precision{ 48 *this, "precision", 49 llvm::cl::desc( 50 "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"), 51 llvm::cl::init("tf32")}; 52 53 MmaSyncF32Lowering tf32Precision = 54 llvm::StringSwitch<MmaSyncF32Lowering>(precision) 55 .Case("tf32", MmaSyncF32Lowering::TF32) 56 .Case("tf32x3", MmaSyncF32Lowering::TF32x3) 57 .Default(MmaSyncF32Lowering::Unkown); 58 59 void runOnOperation() override { 60 RewritePatternSet patterns(&getContext()); 61 62 populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision); 63 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 64 } 65 }; 66 67 } // namespace 68 69 namespace mlir { 70 namespace test { 71 void registerTestNVGPULowerings() { 72 PassRegistration<TestMmaSyncF32ToTF32Patterns>(); 73 } 74 75 } // namespace test 76 } // namespace mlir 77