114d79afeSManish Gupta //===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===// 214d79afeSManish Gupta // 314d79afeSManish Gupta // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 414d79afeSManish Gupta // See https://llvm.org/LICENSE.txt for license information. 514d79afeSManish Gupta // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 614d79afeSManish Gupta // 714d79afeSManish Gupta //===----------------------------------------------------------------------===// 814d79afeSManish Gupta 914d79afeSManish Gupta #include <type_traits> 1014d79afeSManish Gupta 1114d79afeSManish Gupta #include "mlir/Analysis/SliceAnalysis.h" 1214d79afeSManish Gupta #include "mlir/Dialect/Affine/IR/AffineOps.h" 1314d79afeSManish Gupta #include "mlir/Dialect/Func/IR/FuncOps.h" 1414d79afeSManish Gupta #include "mlir/Dialect/GPU/IR/GPUDialect.h" 1514d79afeSManish Gupta #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1614d79afeSManish Gupta #include "mlir/Dialect/Linalg/IR/Linalg.h" 1714d79afeSManish Gupta #include "mlir/Dialect/Linalg/Passes.h" 1814d79afeSManish Gupta #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 1914d79afeSManish Gupta #include "mlir/Dialect/MemRef/IR/MemRef.h" 2014d79afeSManish Gupta #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" 2114d79afeSManish Gupta #include "mlir/Dialect/SCF/IR/SCF.h" 2214d79afeSManish Gupta #include "mlir/Pass/Pass.h" 2314d79afeSManish Gupta #include "mlir/Pass/PassManager.h" 2414d79afeSManish Gupta #include "mlir/Support/LLVM.h" 2514d79afeSManish Gupta #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 2614d79afeSManish Gupta 2714d79afeSManish Gupta using namespace mlir; 2814d79afeSManish Gupta using namespace mlir::nvgpu; 2914d79afeSManish Gupta 3014d79afeSManish Gupta namespace { 3114d79afeSManish Gupta 3214d79afeSManish Gupta struct TestMmaSyncF32ToTF32Patterns 3314d79afeSManish Gupta : public PassWrapper<TestMmaSyncF32ToTF32Patterns, 3414d79afeSManish Gupta OperationPass<func::FuncOp>> { 3514d79afeSManish Gupta MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns) 3614d79afeSManish Gupta 3714d79afeSManish Gupta StringRef getArgument() const final { 3814d79afeSManish Gupta return "test-nvgpu-mmasync-f32-to-tf32-patterns"; 3914d79afeSManish Gupta } 4014d79afeSManish Gupta StringRef getDescription() const final { 4114d79afeSManish Gupta return "Test patterns to convert mma.sync on f32 with tf32 precision"; 4214d79afeSManish Gupta } 4314d79afeSManish Gupta TestMmaSyncF32ToTF32Patterns() = default; 4414d79afeSManish Gupta TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass) 4514d79afeSManish Gupta : PassWrapper(pass) {} 4614d79afeSManish Gupta 4714d79afeSManish Gupta Option<std::string> precision{ 4814d79afeSManish Gupta *this, "precision", 4914d79afeSManish Gupta llvm::cl::desc( 5014d79afeSManish Gupta "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"), 5114d79afeSManish Gupta llvm::cl::init("tf32")}; 5214d79afeSManish Gupta 5314d79afeSManish Gupta MmaSyncF32Lowering tf32Precision = 5414d79afeSManish Gupta llvm::StringSwitch<MmaSyncF32Lowering>(precision) 5514d79afeSManish Gupta .Case("tf32", MmaSyncF32Lowering::TF32) 5614d79afeSManish Gupta .Case("tf32x3", MmaSyncF32Lowering::TF32x3) 5714d79afeSManish Gupta .Default(MmaSyncF32Lowering::Unkown); 5814d79afeSManish Gupta 5914d79afeSManish Gupta void runOnOperation() override { 6014d79afeSManish Gupta RewritePatternSet patterns(&getContext()); 6114d79afeSManish Gupta 6214d79afeSManish Gupta populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision); 63*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 6414d79afeSManish Gupta } 6514d79afeSManish Gupta }; 6614d79afeSManish Gupta 6714d79afeSManish Gupta } // namespace 6814d79afeSManish Gupta 6914d79afeSManish Gupta namespace mlir { 7014d79afeSManish Gupta namespace test { 71baa5beecStyb0807 void registerTestNVGPULowerings() { 7214d79afeSManish Gupta PassRegistration<TestMmaSyncF32ToTF32Patterns>(); 7314d79afeSManish Gupta } 7414d79afeSManish Gupta 7514d79afeSManish Gupta } // namespace test 7614d79afeSManish Gupta } // namespace mlir 77