xref: /llvm-project/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir (revision 0a1569a400491e264060b8a6ff7b7f64e1865496)
1// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32" -split-input-file | FileCheck %s
2
// f32 operands with mmaShape = [16, 8, 4]: under precision=tf32 the rewritten
// nvgpu.mma.sync is expected to carry the tf32Enabled attribute.
// CHECK-LABEL: m16n8k4_tf32
func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // CHECK: nvgpu.mma.sync
  // CHECK-SAME: tf32Enabled
  %acc = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]}
      : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
  return %acc : vector<2x2xf32>
}
10
11// -----
12
// f32 operands with mmaShape = [16, 8, 8]: under precision=tf32 the rewritten
// nvgpu.mma.sync is expected to carry the tf32Enabled attribute.
// CHECK-LABEL: m16n8k8_tf32
func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // CHECK: nvgpu.mma.sync
  // CHECK-SAME: tf32Enabled
  %acc = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]}
      : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
  return %acc : vector<2x2xf32>
}
20// -----
21
// Negative test: f16 operands are not part of the f32 -> tf32 rewrite, so the
// nvgpu.mma.sync must still be present and must NOT gain tf32Enabled.
// The positive CHECK on the op guards against the pattern erasing or
// replacing it; the CHECK-NOT range starts right after the op name, so it
// still covers the attribute dictionary where tf32Enabled would be printed.
// CHECK-LABEL: func @mma_sync_f16
// CHECK: nvgpu.mma.sync
// CHECK-NOT: tf32Enabled
// CHECK: return
func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
  return %d : vector<2x2xf16>
}
30