// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32" -split-input-file | FileCheck %s

// The test pass runs with precision=tf32. Each f32 nvgpu.mma.sync below is
// expected to come out carrying the `tf32Enabled` attribute; the f16 case must
// keep its op but must not receive that attribute.

// f32 operands, mmaShape [16, 8, 4].
// CHECK-LABEL: @m16n8k4_tf32
func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // CHECK: nvgpu.mma.sync
  // CHECK-SAME: tf32Enabled
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
  return %d : vector<2x2xf32>
}

// -----

// f32 operands, mmaShape [16, 8, 8].
// CHECK-LABEL: @m16n8k8_tf32
func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // CHECK: nvgpu.mma.sync
  // CHECK-SAME: tf32Enabled
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
  return %d : vector<2x2xf32>
}

// -----

// Negative test for the non-f32 case: the f16 op must survive the pass
// (the positive match below keeps the NOT check from passing vacuously if the
// op were dropped) and must not gain the tf32 attribute anywhere on its line.
// CHECK-LABEL: @mma_sync_f16
// CHECK: nvgpu.mma.sync
// CHECK-NOT: tf32Enabled
// CHECK: return
func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
  return %d : vector<2x2xf16>
}