xref: /llvm-project/mlir/test/Dialect/LLVMIR/nvvm.mlir (revision d4159e2a1d1d640077b2e5cde66b0a284049955f)
1// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
2
// Parses and re-prints the tid / ntid / ctaid / nctaid special-register reads
// for the x, y and z components; each read must print unchanged as an i32 op.
// CHECK-LABEL: @nvvm_special_regs
func.func @nvvm_special_regs() -> i32 {
  // CHECK: nvvm.read.ptx.sreg.tid.x : i32
  %tid_x = nvvm.read.ptx.sreg.tid.x : i32
  // CHECK: nvvm.read.ptx.sreg.tid.y : i32
  %tid_y = nvvm.read.ptx.sreg.tid.y : i32
  // CHECK: nvvm.read.ptx.sreg.tid.z : i32
  %tid_z = nvvm.read.ptx.sreg.tid.z : i32
  // CHECK: nvvm.read.ptx.sreg.ntid.x : i32
  %ntid_x = nvvm.read.ptx.sreg.ntid.x : i32
  // CHECK: nvvm.read.ptx.sreg.ntid.y : i32
  %ntid_y = nvvm.read.ptx.sreg.ntid.y : i32
  // CHECK: nvvm.read.ptx.sreg.ntid.z : i32
  %ntid_z = nvvm.read.ptx.sreg.ntid.z : i32
  // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
  %ctaid_x = nvvm.read.ptx.sreg.ctaid.x : i32
  // CHECK: nvvm.read.ptx.sreg.ctaid.y : i32
  %ctaid_y = nvvm.read.ptx.sreg.ctaid.y : i32
  // CHECK: nvvm.read.ptx.sreg.ctaid.z : i32
  %ctaid_z = nvvm.read.ptx.sreg.ctaid.z : i32
  // CHECK: nvvm.read.ptx.sreg.nctaid.x : i32
  %nctaid_x = nvvm.read.ptx.sreg.nctaid.x : i32
  // CHECK: nvvm.read.ptx.sreg.nctaid.y : i32
  %nctaid_y = nvvm.read.ptx.sreg.nctaid.y : i32
  // CHECK: nvvm.read.ptx.sreg.nctaid.z : i32
  %nctaid_z = nvvm.read.ptx.sreg.nctaid.z : i32
  // Only the first read is returned; the others exist purely to be printed.
  llvm.return %tid_x : i32
}
31
// Round-trip of nvvm.rcp.approx.ftz.f applied to the f32 function argument.
// CHECK-LABEL: @nvvm_rcp
func.func @nvvm_rcp(%arg0: f32) -> f32 {
  // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
  %rcp = nvvm.rcp.approx.ftz.f %arg0 : f32
  llvm.return %rcp : f32
}
38
// Round-trip of the operand-less nvvm.barrier0 op.
// CHECK-LABEL: @llvm_nvvm_barrier0
func.func @llvm_nvvm_barrier0() {
  // CHECK: nvvm.barrier0
  nvvm.barrier0
  llvm.return
}
45
// Round-trip of nvvm.barrier in its three spellings: no operands, `id` only,
// and `id` together with `number_of_threads`. The FileCheck variables bind the
// printed argument names so each operand is checked in the right position.
// CHECK-LABEL: @llvm_nvvm_barrier
// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
llvm.func @llvm_nvvm_barrier(%id : i32, %threads : i32) {
  // CHECK: nvvm.barrier
  nvvm.barrier
  // CHECK: nvvm.barrier id = %[[barId]]
  nvvm.barrier id = %id
  // CHECK: nvvm.barrier id = %[[barId]] number_of_threads = %[[numberOfThreads]]
  nvvm.barrier id = %id number_of_threads = %threads
  llvm.return
}
57
// Round-trip of nvvm.barrier.arrive with and without the optional `id`
// operand; `number_of_threads` is present in both forms.
// CHECK-LABEL: @llvm_nvvm_barrier_arrive
// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
llvm.func @llvm_nvvm_barrier_arrive(%id : i32, %threads : i32) {
  // CHECK: nvvm.barrier.arrive number_of_threads = %[[numberOfThreads]]
  nvvm.barrier.arrive number_of_threads = %threads
  // CHECK: nvvm.barrier.arrive id = %[[barId]] number_of_threads = %[[numberOfThreads]]
  nvvm.barrier.arrive id = %id number_of_threads = %threads
  llvm.return
}
67
// Round-trip of nvvm.cluster.arrive, plain and with the `aligned` unit
// attribute in its attribute dictionary.
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
func.func @llvm_nvvm_cluster_arrive() {
  // Plain form.
  // CHECK: nvvm.cluster.arrive
  nvvm.cluster.arrive

  // Form carrying {aligned}.
  // CHECK: nvvm.cluster.arrive {aligned}
  nvvm.cluster.arrive {aligned}
  llvm.return
}
76
// Round-trip of nvvm.cluster.arrive.relaxed, plain and with the `aligned`
// unit attribute.
// CHECK-LABEL: @llvm_nvvm_cluster_arrive_relaxed
func.func @llvm_nvvm_cluster_arrive_relaxed() {
  // Plain form.
  // CHECK: nvvm.cluster.arrive.relaxed
  nvvm.cluster.arrive.relaxed

  // Form carrying {aligned}.
  // CHECK: nvvm.cluster.arrive.relaxed {aligned}
  nvvm.cluster.arrive.relaxed {aligned}
  llvm.return
}
85
// Round-trip of nvvm.cluster.wait, plain and with the `aligned` unit
// attribute.
// CHECK-LABEL: @llvm_nvvm_cluster_wait
func.func @llvm_nvvm_cluster_wait() {
  // Plain form.
  // CHECK: nvvm.cluster.wait
  nvvm.cluster.wait

  // Form carrying {aligned}.
  // CHECK: nvvm.cluster.wait {aligned}
  nvvm.cluster.wait {aligned}
  llvm.return
}
94
// Round-trip of the operand-less nvvm.fence.sc.cluster op.
// CHECK-LABEL: @llvm_nvvm_fence_sc_cluster
func.func @llvm_nvvm_fence_sc_cluster() {
  // CHECK: nvvm.fence.sc.cluster
  nvvm.fence.sc.cluster
  llvm.return
}
101
// Round-trip of nvvm.shfl.sync for the bfly, up, down and idx kinds. The bfly
// kind is exercised with both an i32 and an f32 value operand; the remaining
// kinds use the f32 operand only.
// CHECK-LABEL: @nvvm_shfl
func.func @nvvm_shfl(%i0 : i32, %i1 : i32, %i2 : i32,
                     %i3 : i32, %f0 : f32) -> i32 {
  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 -> i32
  %bfly_i = nvvm.shfl.sync bfly %i0, %i3, %i1, %i2 : i32 -> i32
  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
  %bfly_f = nvvm.shfl.sync bfly %i0, %f0, %i1, %i2 : f32 -> f32
  // CHECK: nvvm.shfl.sync up %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
  %up_f = nvvm.shfl.sync up %i0, %f0, %i1, %i2 : f32 -> f32
  // CHECK: nvvm.shfl.sync down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
  %down_f = nvvm.shfl.sync down %i0, %f0, %i1, %i2 : f32 -> f32
  // CHECK: nvvm.shfl.sync idx %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
  %idx_f = nvvm.shfl.sync idx %i0, %f0, %i1, %i2 : f32 -> f32
  llvm.return %bfly_i : i32
}
118
// Round-trip of nvvm.shfl.sync carrying the `return_value_and_is_valid` unit
// attribute; the result type is then a struct pairing the value with an i1.
// CHECK-LABEL: @nvvm_shfl_pred
func.func @nvvm_shfl_pred(%i0 : i32, %i1 : i32, %i2 : i32,
                          %i3 : i32, %f0 : f32) -> !llvm.struct<(i32, i1)> {
  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
  %pair_i = nvvm.shfl.sync bfly %i0, %i3, %i1, %i2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
  %pair_f = nvvm.shfl.sync bfly %i0, %f0, %i1, %i2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
  llvm.return %pair_i : !llvm.struct<(i32, i1)>
}
129
// Round-trip of nvvm.vote.ballot.sync taking an i32 and an i1 operand and
// yielding an i32.
// CHECK-LABEL: @nvvm_vote(
func.func @nvvm_vote(%word : i32, %bit : i1) -> i32 {
  // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
  %ballot = nvvm.vote.ballot.sync %word, %bit : i32
  llvm.return %ballot : i32
}
136
// Round-trip of nvvm.bar.warp.sync with its single i32 operand.
// CHECK-LABEL: @llvm_nvvm_bar_warp_sync
func.func @llvm_nvvm_bar_warp_sync(%m : i32) {
  // CHECK: nvvm.bar.warp.sync %{{.*}}
  nvvm.bar.warp.sync %m : i32
  llvm.return
}
143
// mma.sync, shape m8n8k4, row/col layouts: vector<2xf16> A and B fragments,
// eight f32 C operands, result is a struct of eight f32. Only the op name is
// matched on the printed output.
// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
func.func @nvvm_mma_m8n8k4_row_col_f32_f32(
    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
    %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
  // CHECK: nvvm.mma.sync
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = #nvvm.shape<m = 8, n = 8, k = 4>}
    : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
  llvm.return %res : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
154
// mma.sync, shape m8n8k4, row/col layouts: every fragment (A, B, C and the
// four result elements) is vector<2xf16>.
// CHECK-LABEL: @nvvm_mma_m8n8k4_f16_f16
func.func @nvvm_mma_m8n8k4_f16_f16(
    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
    %c0 : vector<2xf16>, %c1 : vector<2xf16>,
    %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}]
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = #nvvm.shape<m = 8, n = 8, k = 4>}
    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
  llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
}
165
// mma.sync, shape m16n8k8, with both multiplicand PTX types set to bf16: the
// A/B fragments are passed as i32 values, C and the result elements are f32.
// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
func.func @nvvm_mma_m16n8k8_bf16_bf16(
    %a0 : i32, %a1 : i32,
    %b0 : i32,
    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<bf16>,
     multiplicandBPtxType = #nvvm.mma_type<bf16>,
     shape = #nvvm.shape<m = 16, n = 8, k = 8>}
    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
}
176
// mma.sync, shape m16n8k16, with both multiplicand PTX types set to bf16:
// four i32 A operands, two i32 B operands, four f32 C operands.
// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
func.func @nvvm_mma_m16n8k16_bf16_bf16(
    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
    %b0 : i32, %b1 : i32,
    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  %res = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<bf16>,
     multiplicandBPtxType = #nvvm.mma_type<bf16>,
     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
}
188
// mma.sync, shape m8n8k16, s8 x s8 multiplicands with `wrapped` integer
// overflow behavior. The attribute dictionary below is written in the same
// order in which the CHECK line expects it to be printed.
// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
func.func @nvvm_mma_m8n8k16_s8_s8(
    %a0 : i32,
    %b0 : i32,
    %c0 : i32, %c1 : i32) {
  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 8, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
  %res = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
    {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
     layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<s8>,
     multiplicandBPtxType = #nvvm.mma_type<s8>,
     shape = #nvvm.shape<m = 8, n = 8, k = 16>}
    : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
  llvm.return %res : !llvm.struct<(i32, i32)>
}
200
// mma.sync, shape m16n8k8, all fragments vector<2xf16>. The CHECK line pins
// the operand counts and the type signature but wildcards the attribute
// dictionary.
// CHECK-LABEL: @nvvm_mma_m16n8k8_f16_f16
func.func @nvvm_mma_m16n8k8_f16_f16(
    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
    %b0 : vector<2xf16>,
    %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
  // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = #nvvm.shape<m = 16, n = 8, k = 8>}
    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
  llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
211
// mma.sync, shape m16n8k16: vector<2xf16> A, B and C fragments, result is a
// struct of two vector<2xf16>.
// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f16
func.func @nvvm_mma_m16n8k16_f16_f16(
    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
    %a2 : vector<2xf16>, %a3 : vector<2xf16>,
    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
    %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
  %res = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
  llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
223
// mma.sync, shape m16n8k16: vector<2xf16> A, B and C fragments, result is a
// struct of four f32.
// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
func.func @nvvm_mma_m16n8k16_f32_f16(
    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
    %a2 : vector<2xf16>, %a3 : vector<2xf16>,
    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
    %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
  %res = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
  llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
}
235
// mma.sync, shape m16n8k16: vector<2xf16> A and B fragments, four f32 C
// operands, result is a struct of two vector<2xf16>.
// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32
func.func @nvvm_mma_m16n8k16_f16_f32(
    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
    %a2 : vector<2xf16>, %a3 : vector<2xf16>,
    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
  %res = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
    : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
  llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
247
// mma.sync, shape m16n8k16: vector<2xf16> A and B fragments, four f32 C
// operands, result is a struct of four f32.
// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32
func.func @nvvm_mma_m16n8k16_f32_f32(
    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
    %a2 : vector<2xf16>, %a3 : vector<2xf16>,
    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  %res = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
    : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
}
259
// mma.sync, shape m16n8k4, with both multiplicand PTX types set to tf32: the
// A/B fragments are passed as i32 values, C and the result elements are f32.
// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32
func.func @nvvm_mma_m16n8k4_tf32_f32(
    %a0 : i32, %a1 : i32,
    %b0 : i32,
    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 4>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<tf32>,
     multiplicandBPtxType = #nvvm.mma_type<tf32>,
     shape = #nvvm.shape<m = 16, n = 8, k = 4>}
    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
  llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
}
271
// mma.sync, shape m16n8k16, s8 x s8 multiplicands with `wrapped` integer
// overflow behavior; all operands and result elements are i32.
// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8
func.func @nvvm_mma_m16n8k16_s8_s8(
    %a0 : i32, %a1 : i32,
    %b0 : i32,
    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
     layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<s8>,
     multiplicandBPtxType = #nvvm.mma_type<s8>,
     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
}
283
// mma.sync, shape m16n8k16, mixed s8 (A) x u8 (B) multiplicands with
// `satfinite` integer overflow behavior; all operands and results are i32.
// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8
func.func @nvvm_mma_m16n8k16_s8_u8(
    %a0 : i32, %a1 : i32,
    %b0 : i32,
    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
     layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<s8>,
     multiplicandBPtxType = #nvvm.mma_type<u8>,
     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
}
296
// mma.sync, shape m16n8k256, b1 x b1 multiplicands combined with the
// xor_popc b1 operation; four i32 A operands, two i32 B operands and four i32
// C operands.
//
// The pattern below lists one wildcard per operand. The B list carries two
// wildcards because the op is built with two B operands; a single wildcard
// would also match, but would not notice a dropped or merged operand.
// CHECK-LABEL: @nvvm_mma_m16n8k256_b1_b1
func.func @nvvm_mma_m16n8k256_b1_b1(
    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
    %b0 : i32, %b1 : i32,
    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = #nvvm.shape<m = 16, n = 8, k = 256>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  %res = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
     b1Op = #nvvm.mma_b1op<xor_popc>,
     shape = #nvvm.shape<m = 16, n = 8, k = 256>}
    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
}
308
// mma.sync, shape m16n8k128, b1 x b1 multiplicands combined with the
// xor_popc b1 operation; all operands and result elements are i32.
// CHECK-LABEL: @nvvm_mma_m16n8k128_b1_b1
func.func @nvvm_mma_m16n8k128_b1_b1(
    %a0 : i32, %a1 : i32,
    %b0 : i32,
    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {b1Op = #nvvm.mma_b1op<xor_popc>,
     layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<b1>,
     multiplicandBPtxType = #nvvm.mma_type<b1>,
     shape = #nvvm.shape<m = 16, n = 8, k = 128>}
    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
}
321
// mma.sync, shape m8n8k128, b1 x b1 multiplicands combined with the xor_popc
// b1 operation; one i32 A operand, one i32 B operand, two i32 C operands.
// CHECK-LABEL: @nvvm_mma_m8n8k128_b1_b1
func.func @nvvm_mma_m8n8k128_b1_b1(
    %a0 : i32,
    %b0 : i32,
    %c0 : i32, %c1 : i32) {
  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = #nvvm.shape<m = 8, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
  %res = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
    {b1Op = #nvvm.mma_b1op<xor_popc>,
     layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<b1>,
     multiplicandBPtxType = #nvvm.mma_type<b1>,
     shape = #nvvm.shape<m = 8, n = 8, k = 128>}
    : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
  llvm.return %res : !llvm.struct<(i32, i32)>
}
333
// mma.sync, shape m16n8k32, s4 x s4 multiplicands with `wrapped` integer
// overflow behavior; all operands and result elements are i32.
// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4
func.func @nvvm_mma_m16n8k32_s4_s4(
    %a0 : i32, %a1 : i32,
    %b0 : i32,
    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  %res = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
     layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<s4>,
     multiplicandBPtxType = #nvvm.mma_type<s4>,
     shape = #nvvm.shape<m = 16, n = 8, k = 32>}
    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
  llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
}
346
// Round-trip of nvvm.wmma.load for a tf32 `a` fragment, row layout, with
// m = 16, n = 16, k = 8; the result is a struct of four i32.
// CHECK-LABEL: @nvvm_wmma_load_tf32
func.func @nvvm_wmma_load_tf32(%ptr: !llvm.ptr, %val : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
  // CHECK: nvvm.wmma.load {{.*}} {eltype = #nvvm.mma_type<tf32>, frag = #nvvm.mma_frag<a>, k = 8 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}
  %frag = nvvm.wmma.load %ptr, %val
    {eltype = #nvvm.mma_type<tf32>,
     frag = #nvvm.mma_frag<a>,
     k = 8 : i32,
     layout = #nvvm.mma_layout<row>,
     m = 16 : i32,
     n = 16 : i32}
    : (!llvm.ptr) -> !llvm.struct<(i32, i32, i32, i32)>
  llvm.return %frag : !llvm.struct<(i32, i32, i32, i32)>
}
355
// Round-trip of nvvm.wmma.mma with eltypeA = tf32 and eltypeB = f32, row/row
// layouts, m = 16, n = 16, k = 8. The op takes eight i32 operands followed by
// eight f32 operands and yields a struct of eight f32.
// CHECK-LABEL: @nvvm_wmma_mma
func.func @nvvm_wmma_mma(
    %i0 : i32, %i1 : i32, %i2 : i32, %i3 : i32,
    %i4 : i32, %i5 : i32, %i6 : i32, %i7 : i32,
    %f0 : f32, %f1 : f32, %f2 : f32, %f3 : f32,
    %f4 : f32, %f5 : f32, %f6 : f32, %f7 : f32)
    -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
  // CHECK: nvvm.wmma.mma {{.*}} {eltypeA = #nvvm.mma_type<tf32>, eltypeB = #nvvm.mma_type<f32>, k = 8 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}
  %res = nvvm.wmma.mma %i0, %i1, %i2, %i3, %i4, %i5, %i6, %i7, %f0, %f1, %f2, %f3, %f4, %f5, %f6, %f7
    {eltypeA = #nvvm.mma_type<tf32>,
     eltypeB = #nvvm.mma_type<f32>,
     k = 8 : i32,
     layoutA = #nvvm.mma_layout<row>,
     layoutB = #nvvm.mma_layout<row>,
     m = 16 : i32,
     n = 16 : i32}
    : (i32, i32, i32, i32, i32, i32, i32, i32, f32, f32, f32, f32, f32, f32, f32, f32)
    -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
  llvm.return %res : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
368
// Round-trip of the cp.async family: the shared.global copy with a size of 16
// and the `ca` / `cg` cache modifiers, followed by commit.group and
// wait.group 0.
// CHECK-LABEL: @cp_async
llvm.func @cp_async(%p3: !llvm.ptr<3>, %p1: !llvm.ptr<1>) {
  // CHECK:  nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca
  nvvm.cp.async.shared.global %p3, %p1, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
  // CHECK:  nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
  nvvm.cp.async.shared.global %p3, %p1, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
  // CHECK: nvvm.cp.async.commit.group
  nvvm.cp.async.commit.group
  // CHECK: nvvm.cp.async.wait.group 0
  nvvm.cp.async.wait.group 0
  llvm.return
}
381
// Round-trip of nvvm.ldmatrix with num = 1, 2 and 4 (row layout). The result
// is a bare i32 for num = 1 and a struct of `num` i32 values otherwise. The
// attributes are written here in the order the CHECK lines expect on output.
// CHECK-LABEL: llvm.func @ld_matrix
llvm.func @ld_matrix(%ptr: !llvm.ptr<3>) {
  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
  %x1 = nvvm.ldmatrix %ptr {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
  %x2 = nvvm.ldmatrix %ptr {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
  %x4 = nvvm.ldmatrix %ptr {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
  llvm.return
}
392
// Round-trip of nvvm.redux.sync for each reduction kind exercised here:
// add, max, min, umax, umin, and, or, xor. Every instance uses the same two
// i32 operands.
// CHECK-LABEL: llvm.func @redux_sync
llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
  // CHECK: nvvm.redux.sync  add %{{.*}}
  %sum = nvvm.redux.sync add %value, %offset : i32 -> i32
  // CHECK: nvvm.redux.sync  max %{{.*}}
  %smax = nvvm.redux.sync max %value, %offset : i32 -> i32
  // CHECK: nvvm.redux.sync  min %{{.*}}
  %smin = nvvm.redux.sync min %value, %offset : i32 -> i32
  // CHECK: nvvm.redux.sync  umax %{{.*}}
  %umax = nvvm.redux.sync umax %value, %offset : i32 -> i32
  // CHECK: nvvm.redux.sync  umin %{{.*}}
  %umin = nvvm.redux.sync umin %value, %offset : i32 -> i32
  // CHECK: nvvm.redux.sync  and %{{.*}}
  %band = nvvm.redux.sync and %value, %offset : i32 -> i32
  // CHECK: nvvm.redux.sync  or %{{.*}}
  %bor = nvvm.redux.sync or %value, %offset : i32 -> i32
  // CHECK: nvvm.redux.sync  xor %{{.*}}
  %bxor = nvvm.redux.sync xor %value, %offset : i32 -> i32
  llvm.return %sum : i32
}
413
414
415// -----
416
// Negative test: nvvm.kernel on a func.func (rather than an llvm.func) must be
// diagnosed. The directive has to stay on the line directly above the op.
// expected-error@below {{attribute attached to unexpected op}}
func.func private @expected_llvm_func() attributes {nvvm.kernel}
419
420// -----
421
// Round-trip of nvvm.mbarrier.init on a generic-address-space pointer, with
// the count taken from a special-register read. The label anchors the CHECK
// to this function so it cannot be satisfied by a neighbouring one.
// CHECK-LABEL: @mbarrier_init_generic
llvm.func private @mbarrier_init_generic(%barrier: !llvm.ptr) {
  %count = nvvm.read.ptx.sreg.ntid.x : i32
  // CHECK:   nvvm.mbarrier.init %{{.*}}, %{{.*}} : !llvm.ptr, i32
  nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32
  llvm.return
}
428
429
// Round-trip of nvvm.mbarrier.init.shared on an address-space-3 pointer, with
// the count taken from a special-register read. The label anchors the CHECK
// to this function.
// CHECK-LABEL: @mbarrier_init_shared
llvm.func private @mbarrier_init_shared(%barrier: !llvm.ptr<3>) {
  %count = nvvm.read.ptx.sreg.ntid.x : i32
  // CHECK:   nvvm.mbarrier.init.shared %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32
  nvvm.mbarrier.init.shared %barrier, %count : !llvm.ptr<3>, i32
  llvm.return
}
436
437
// Round-trip of nvvm.mbarrier.inval on a generic-address-space pointer. The
// label anchors the CHECK to this function.
// CHECK-LABEL: @mbarrier_inval_generic
llvm.func private @mbarrier_inval_generic(%barrier: !llvm.ptr) {
  // CHECK:   nvvm.mbarrier.inval %{{.*}} : !llvm.ptr
  nvvm.mbarrier.inval %barrier : !llvm.ptr
  llvm.return
}
443
444
// Round-trip of nvvm.mbarrier.inval.shared on an address-space-3 pointer. The
// label anchors the CHECK to this function.
// CHECK-LABEL: @mbarrier_inval_shared
llvm.func private @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
  // CHECK:   nvvm.mbarrier.inval.shared %{{.*}} : !llvm.ptr<3>
  nvvm.mbarrier.inval.shared %barrier : !llvm.ptr<3>
  llvm.return
}
450
// Round-trip of nvvm.mbarrier.arrive on a generic-address-space pointer; the
// op yields an i64 that is left unused. The trailing `(` in the label keeps
// it from matching the longer @mbarrier_arrive_* names.
// CHECK-LABEL: @mbarrier_arrive(
llvm.func private @mbarrier_arrive(%barrier: !llvm.ptr) {
  // CHECK:   nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr
  %token = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64
  llvm.return
}
456
// Round-trip of nvvm.mbarrier.arrive.shared on an address-space-3 pointer;
// the op yields an i64 that is left unused. The label anchors the CHECK to
// this function.
// CHECK-LABEL: @mbarrier_arrive_shared
llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
  // CHECK:   nvvm.mbarrier.arrive.shared %{{.*}} : !llvm.ptr<3>
  %token = nvvm.mbarrier.arrive.shared %barrier : !llvm.ptr<3> -> i64
  llvm.return
}
462
// Round-trip of nvvm.mbarrier.arrive.nocomplete on a generic-address-space
// pointer plus an i32 count; the i64 result is left unused. The trailing `(`
// in the label keeps it from matching @mbarrier_arrive_nocomplete_shared.
// CHECK-LABEL: @mbarrier_arrive_nocomplete(
llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
  %count = nvvm.read.ptx.sreg.ntid.x : i32
  // CHECK:   nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr
  %token = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64
  llvm.return
}
469
// Round-trip of nvvm.mbarrier.arrive.nocomplete.shared on an address-space-3
// pointer plus an i32 count; the i64 result is left unused. The label anchors
// the CHECK to this function.
// CHECK-LABEL: @mbarrier_arrive_nocomplete_shared
llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
  %count = nvvm.read.ptx.sreg.ntid.x : i32
  // CHECK:   nvvm.mbarrier.arrive.nocomplete.shared %{{.*}} : !llvm.ptr<3>
  %token = nvvm.mbarrier.arrive.nocomplete.shared %barrier, %count : !llvm.ptr<3>, i32 -> i64
  llvm.return
}
476
// Round-trip of nvvm.mbarrier.test.wait on a generic-address-space pointer
// and an i64 token; the i1 result is returned. The trailing `(` in the label
// keeps it from matching @mbarrier_test_wait_shared.
// CHECK-LABEL: @mbarrier_test_wait(
llvm.func private @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
  // CHECK:   nvvm.mbarrier.test.wait %{{.*}}
  %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
  llvm.return %isComplete : i1
}
482
// Round-trip of nvvm.mbarrier.test.wait.shared on an address-space-3 pointer
// and an i64 token; the i1 result is left unused. The label anchors the CHECK
// to this function. The special-register read that used to precede the op was
// never consumed and has been dropped.
// CHECK-LABEL: @mbarrier_test_wait_shared
llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
  // CHECK:   nvvm.mbarrier.test.wait.shared %{{.*}}
  %isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : !llvm.ptr<3>, i64 -> i1
  llvm.return
}
489
// Round-trip of the operand-less nvvm.wgmma.fence.aligned op.
// CHECK-LABEL: @wgmma_fence_aligned
func.func @wgmma_fence_aligned() {
  // CHECK: nvvm.wgmma.fence.aligned
  nvvm.wgmma.fence.aligned
  return
}
496
// Round-trip of the operand-less nvvm.wgmma.commit.group.sync.aligned op.
// CHECK-LABEL: @wgmma_commit_group_sync_aligned
func.func @wgmma_commit_group_sync_aligned() {
  // CHECK: nvvm.wgmma.commit.group.sync.aligned
  nvvm.wgmma.commit.group.sync.aligned
  return
}
503
504
// Round-trip of nvvm.wgmma.wait.group.sync.aligned with the literal 0; the
// pattern matches the op name only, not the literal.
// CHECK-LABEL: @wgmma_wait_group_sync_aligned
func.func @wgmma_wait_group_sync_aligned() {
  // CHECK: nvvm.wgmma.wait.group.sync.aligned
  nvvm.wgmma.wait.group.sync.aligned 0
  return
}
511
// Round-trip of the operand-less nvvm.griddepcontrol.wait op. The label
// anchors the CHECK to this function.
// CHECK-LABEL: @griddepcontrol_wait
func.func @griddepcontrol_wait() {
  // CHECK: nvvm.griddepcontrol.wait
  nvvm.griddepcontrol.wait
  return
}
517
// Round-trip of the operand-less nvvm.griddepcontrol.launch.dependents op.
// The label anchors the CHECK to this function.
// CHECK-LABEL: @griddepcontrol_launch_dependents
func.func @griddepcontrol_launch_dependents() {
  // CHECK: nvvm.griddepcontrol.launch.dependents
  nvvm.griddepcontrol.launch.dependents
  return
}
524
525// -----
526
// No output patterns are attached to the gpu.module ops in this section: they
// only have to parse and verify without producing a diagnostic. This one
// carries a single target with chip, features, link and flags all set.
gpu.module @module_1 [
    #nvvm.target<chip = "sm_90", features = "+ptx70", link = ["my_device_lib.bc"], flags = {fast, ftz}>
  ] {
}
530
// A gpu.module carrying three targets that differ only in `chip`; it only has
// to parse and verify without producing a diagnostic.
gpu.module @module_2 [
    #nvvm.target<chip = "sm_90">,
    #nvvm.target<chip = "sm_80">,
    #nvvm.target<chip = "sm_70">
  ] {
}
533
// Positive case for nvvm.grid_constant: a unit attribute on an argument that
// also has llvm.byval, inside a function marked nvvm.kernel. The label matches
// the attribute name as it is printed in the function signature.
// CHECK-LABEL: nvvm.grid_constant
llvm.func @kernel_func(%param: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
  llvm.return
}
538
539// -----
540
// Negative case: nvvm.grid_constant on an argument of a function that is not
// marked nvvm.kernel. The directive must stay directly above the op.
// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
llvm.func @kernel_func(%param: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
  llvm.return
}
545
546// -----
547
// Negative case: nvvm.grid_constant on a kernel argument that lacks
// llvm.byval. The directive must stay directly above the op.
// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
llvm.func @kernel_func(%param: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
  llvm.return
}
552
553// -----
554
// Negative case: nvvm.grid_constant given a value (`= true`) instead of being
// a unit attribute. The directive must stay directly above the op.
// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
llvm.func @kernel_func(%param: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
  llvm.return
}
559