// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm="enable-x86vector" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
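// The RUN pipeline lowers SCF, vector and memref ops to the LLVM dialect,
// translates the module to LLVM IR, and executes it with lli (with the
// required AVX512 features enabled); FileCheck verifies the printed results.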

// This test shows how to implement a sparse vector-vector dot product with
// AVX512. It uses vp2intersect, mask.compress and vector.contract to compute
// the dot product of two sparse HW vectors of 8 float64 elements each (a
// "segment"). Each sparse vector is represented by an index memref (A or C)
// and by a data memref (B or D), containing M or N elements respectively.
//
// There are four different implementations:
// * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
// * `memref_dot_optimized`: An optimized O(N*M) version of the previous
//   implementation, where the second for loop skips over some elements.
// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a
//   single while loop, coiterating over both vectors.
// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that
//   consists of a single while loop and has no branches within the loop.
//
// Output of llvm-mca:
// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841

#contraction_accesses = [
 affine_map<(i) -> (i)>,
 affine_map<(i) -> (i)>,
 affine_map<(i) -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["reduction"]
}
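// With these affine maps and a single "reduction" iterator, vector.contract
// computes acc + sum_i(lhs[i] * rhs[i]), i.e., an ordinary 1-D dot product.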

// Sparse vector dot product of two vectors.
func.func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
                      %v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
  // Compute intersection of indices.
  %k0, %k1 = x86vector.avx512.vp2intersect %v_A, %v_C : vector<8xi64>
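  // %k0 masks the lanes of %v_A whose index also occurs in %v_C; %k1 masks
  // the lanes of %v_C whose index also occurs in %v_A.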

  // Filter out values without match and compress vector.
  %p0 = x86vector.avx512.mask.compress %k0, %v_B : vector<8xf64>
  %p1 = x86vector.avx512.mask.compress %k1, %v_D : vector<8xf64>
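  // The matching elements are compressed into the low lanes of the result;
  // the remaining lanes are zero-filled, so they contribute nothing to the
  // dot product below.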

  // Dense vector dot product.
  %acc = arith.constant 0.0 : f64
  %r = vector.contract #contraction_trait %p0, %p1, %acc
      : vector<8xf64>, vector<8xf64> into f64

  return %r : f64
}

// Fill input memrefs with all zeros, so that they can be used with arbitrary
// input sizes up to 128 elements per sparse vector.
func.func @init_input(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                      %m_C : memref<?xi64>, %m_D : memref<?xf64>) {
  %c0 = arith.constant 0 : index
  %v_data = arith.constant dense<0.0> : vector<128xf64>
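  // 9223372036854775807 is 2^63 - 1, the largest i64 value. It never occurs
  // as a real index, so padded lanes cannot produce an intersection.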
  %v_index = arith.constant dense<9223372036854775807> : vector<128xi64>

  vector.transfer_write %v_index, %m_A[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_B[%c0] : vector<128xf64>, memref<?xf64>
  vector.transfer_write %v_index, %m_C[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_D[%c0] : vector<128xf64>, memref<?xf64>

  return
}

func.func @fill_input_1(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                        %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  func.call @init_input(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = arith.constant 0 : index

  %v_A = arith.constant dense<[0,  1,  10, 12, 13, 17, 18, 21,
                               51, 52, 57, 61, 62, 82, 98, 99]> : vector<16xi64>
  %v_B = arith.constant dense<[1., 5., 8., 3., 2., 1., 0., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = arith.constant dense<[1,  2,  5,  10, 11, 12, 47, 48,
                               67, 68, 69, 70, 71, 72, 77, 78,
                               79, 82, 83, 84, 85, 90, 91, 98]> : vector<24xi64>
  %v_D = arith.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.,
                               2., 9., 8., 7., 2., 0., 0., 4.]> : vector<24xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<24xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<24xf64>, memref<?xf64>

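  // The indices shared by A and C above are 1, 10, 12, 82 and 98, so the
  // expected dot product is 5*1 + 8*3 + 3*1 + 2*9 + 9*4 = 86 (the CHECK value
  // in @entry).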
  %M = arith.constant 16 : index
  %N = arith.constant 24 : index

  return %M, %N : index, index
}

func.func @fill_input_2(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                        %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  func.call @init_input(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = arith.constant 0 : index

  %v_A = arith.constant dense<[0,  1,  3,  5,  6,  7,  8,  9,
                               51, 52, 57, 61, 62, 63, 65, 66]> : vector<16xi64>
  %v_B = arith.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = arith.constant dense<[6,  7,  11, 12, 15, 17, 19, 21,
                               30, 31, 33, 34, 37, 39, 40, 41,
                               42, 44, 45, 46, 47, 48, 49, 50,
                               62, 63, 64, 65, 66, 67, 68, 69,
                               70, 77, 78, 79, 81, 82, 89, 99]> : vector<40xi64>
  %v_D = arith.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.,
                               2., 9., 8., 7., 2., 1., 2., 4.,
                               4., 5., 8., 8., 2., 3., 5., 1.,
                               8., 6., 6., 4., 3., 8., 9., 2.]> : vector<40xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<40xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<40xf64>, memref<?xf64>

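  // The indices shared by A and C above are 6, 7, 62, 63, 65 and 66, so the
  // expected dot product is 2*1 + 1*5 + 5*4 + 2*5 + 9*8 + 1*2 = 111 (the
  // CHECK value in @entry).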
  %M = arith.constant 16 : index
  %N = arith.constant 40 : index

  return %M, %N : index, index
}

// Simple vector dot product implementation: Intersect every segment of size 8
// in (%m_A, %m_B) with every segment of size 8 in (%m_C, %m_D).
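// With a segment size of 8 and M, N multiples of 8, this performs
// (M/8) * (N/8) calls to @vector_dot.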
func.func @memref_dot_simple(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                             %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                             %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.

  %r0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero) -> (f64) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_B = vector.transfer_read %m_B[%a], %data_zero
        : memref<?xf64>, vector<8xf64>

    %r1 = scf.for %b = %c0 to %N step %c8
        iter_args(%sum1 = %sum0) -> (f64) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %v_D = vector.transfer_read %m_D[%b], %data_zero
          : memref<?xf64>, vector<8xf64>

      %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
          : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) -> f64
      %r2 = arith.addf %sum1, %subresult : f64
      scf.yield %r2 : f64
    }

    scf.yield %r1 : f64
  }

  return %r0 : f64
}

// Optimized vector dot product implementation: Taking advantage of the fact
// that the indices in %m_A and %m_C are sorted in ascending order, skip over
// segments in (%m_C, %m_D) that are known to have no intersection with the
// current segment from (%m_A, %m_B).
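// E.g., once the largest index of a (%m_C, %m_D) segment is smaller than the
// smallest index of the current (%m_A, %m_B) segment, that segment cannot
// intersect any later (%m_A, %m_B) segment either (indices only grow), so the
// inner loop starts after it from then on.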
func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                                %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                                %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %i0 = arith.constant 0 : i32
  %i7 = arith.constant 7 : i32
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.
  // %b_start is the index at which the inner for loop starts iterating. It
  // keeps increasing as earlier segments of (%m_C, %m_D) are known to be no
  // longer needed.

  %r0, %t0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>

    %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
        iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
      %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64

      %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
        // %v_C segment is done, no need to examine this one again (ever).
        %next_b_start2 = arith.addi %b_start1, %c8 : index
        scf.yield %sum1, %next_b_start2 : f64, index
      } else {
        %v_B = vector.transfer_read %m_B[%a], %data_zero
            : memref<?xf64>, vector<8xf64>
        %v_D = vector.transfer_read %m_D[%b], %data_zero
            : memref<?xf64>, vector<8xf64>

        %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
                -> f64
        %r3 = arith.addf %sum1, %subresult : f64
        scf.yield %r3, %b_start1 : f64, index
      }

      scf.yield %r2, %next_b_start1 : f64, index
    }

    scf.yield %r1, %next_b_start0 : f64, index
  }

  return %r0 : f64
}

// Vector dot product with a while loop. Implemented as follows:
//
// r = 0.0, a = 0, b = 0
// while (a < M && b < N) {
//   segA = A[a:a+8], segB = C[b:b+8]
//   if   (segB[7] < segA[0]) b += 8
//   elif (segA[7] < segB[0]) a += 8
//   else {
//     r += vector_dot(...)
//     if   (segA[7] < segB[7]) a += 8
//     elif (segB[7] < segA[7]) b += 8
//     else                     a += 8, b += 8
//   }
// }
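// Every iteration advances a and/or b by 8, so the loop terminates after at
// most M/8 + N/8 iterations.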
func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                            %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                            %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %i0 = arith.constant 0 : i32
  %i7 = arith.constant 7 : i32
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
      : (f64, index, index) -> (f64, index, index) {
    %cond_i = arith.cmpi "slt", %a1, %M : index
    %cond_j = arith.cmpi "slt", %b1, %N : index
    %cond = arith.andi %cond_i, %cond_j : i1
    scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
  } do {
  ^bb0(%r1 : f64, %a1 : index, %b1 : index):
    // v_A, v_B, seg*_* could be part of the loop state to avoid a few
    // redundant reads.
    %v_A = vector.transfer_read %m_A[%a1], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_C = vector.transfer_read %m_C[%b1], %index_padding
        : memref<?xi64>, vector<8xi64>

    %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
    %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
    %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64>
    %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>

    %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
    %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
      %b3 = arith.addi %b1, %c8 : index
      scf.yield %r1, %a1, %b3 : f64, index, index
    } else {
      %seg0_done = arith.cmpi "slt", %segA_max, %segB_min : i64
      %r4, %a4, %b4 = scf.if %seg0_done -> (f64, index, index) {
        %a5 = arith.addi %a1, %c8 : index
        scf.yield %r1, %a5, %b1 : f64, index, index
      } else {
        %v_B = vector.transfer_read %m_B[%a1], %data_zero
            : memref<?xf64>, vector<8xf64>
        %v_D = vector.transfer_read %m_D[%b1], %data_zero
            : memref<?xf64>, vector<8xf64>

        %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
                -> f64
        %r6 = arith.addf %r1, %subresult : f64

        %incr_a = arith.cmpi "slt", %segA_max, %segB_max : i64
        %a6, %b6 = scf.if %incr_a -> (index, index) {
          %a7 = arith.addi %a1, %c8 : index
          scf.yield %a7, %b1 : index, index
        } else {
          %incr_b = arith.cmpi "slt", %segB_max, %segA_max : i64
          %a8, %b8 = scf.if %incr_b -> (index, index) {
            %b9 = arith.addi %b1, %c8 : index
            scf.yield %a1, %b9 : index, index
          } else {
            %a10 = arith.addi %a1, %c8 : index
            %b10 = arith.addi %b1, %c8 : index
            scf.yield %a10, %b10 : index, index
          }
          scf.yield %a8, %b8 : index, index
        }
        scf.yield %r6, %a6, %b6 : f64, index, index
      }
      scf.yield %r4, %a4, %b4 : f64, index, index
    }
    scf.yield %r2, %a2, %b2 : f64, index, index
  }

  return %r0 : f64
}

// Vector dot product with a while loop that has no branches (apart from the
// while loop itself). Implemented as follows:
//
// r = 0.0, a = 0, b = 0
// while (a < M && b < N) {
//   segA = A[a:a+8], segB = C[b:b+8]
//   r += vector_dot(...)
//   a += (segA[7] <= segB[7]) * 8
//   b += (segB[7] <= segA[7]) * 8
// }
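// At least one of segA[7] <= segB[7] and segB[7] <= segA[7] always holds, so
// each iteration advances a and/or b. Segment pairs that cannot intersect
// still go through vector_dot; their contribution is simply 0.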
func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                                       %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                                       %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %i7 = arith.constant 7 : i32
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
      : (f64, index, index) -> (f64, index, index) {
    %cond_i = arith.cmpi "slt", %a1, %M : index
    %cond_j = arith.cmpi "slt", %b1, %N : index
    %cond = arith.andi %cond_i, %cond_j : i1
    scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
  } do {
  ^bb0(%r1 : f64, %a1 : index, %b1 : index):
    // v_A, v_B, seg*_* could be part of the loop state to avoid a few
    // redundant reads.
    %v_A = vector.transfer_read %m_A[%a1], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_B = vector.transfer_read %m_B[%a1], %data_zero
        : memref<?xf64>, vector<8xf64>
    %v_C = vector.transfer_read %m_C[%b1], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_D = vector.transfer_read %m_D[%b1], %data_zero
        : memref<?xf64>, vector<8xf64>

    %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
        : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
            -> f64
    %r2 = arith.addf %r1, %subresult : f64

    %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
    %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>

    %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64
    %cond_a_i64 = arith.extui %cond_a : i1 to i64
    %cond_a_idx = arith.index_cast %cond_a_i64 : i64 to index
    %incr_a = arith.muli %cond_a_idx, %c8 : index
    %a2 = arith.addi %a1, %incr_a : index

    %cond_b = arith.cmpi "sle", %segB_max, %segA_max : i64
    %cond_b_i64 = arith.extui %cond_b : i1 to i64
    %cond_b_idx = arith.index_cast %cond_b_i64 : i64 to index
    %incr_b = arith.muli %cond_b_idx, %c8 : index
    %b2 = arith.addi %b1, %incr_b : index

    scf.yield %r2, %a2, %b2 : f64, index, index
  }

  return %r0 : f64
}

func.func @entry() -> i32 {
  // Initialize large buffers that can be used for multiple test cases of
  // different sizes.
  %b_A = memref.alloc() : memref<128xi64>
  %b_B = memref.alloc() : memref<128xf64>
  %b_C = memref.alloc() : memref<128xi64>
  %b_D = memref.alloc() : memref<128xf64>

  %m_A = memref.cast %b_A : memref<128xi64> to memref<?xi64>
  %m_B = memref.cast %b_B : memref<128xf64> to memref<?xf64>
  %m_C = memref.cast %b_C : memref<128xi64> to memref<?xi64>
  %m_D = memref.cast %b_D : memref<128xf64> to memref<?xf64>

  // --- Test case 1 ---
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
  %M1, %N1 = func.call @fill_input_1(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
          -> (index, index)

  %r0 = func.call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r0 : f64
  // CHECK: 86

  %r1 = func.call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r1 : f64
  // CHECK: 86

  %r2 = func.call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r2 : f64
  // CHECK: 86

  %r6 = func.call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r6 : f64
  // CHECK: 86

  // --- Test case 2 ---
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
  %M2, %N2 = func.call @fill_input_2(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
          -> (index, index)

  %r3 = func.call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r3 : f64
  // CHECK: 111

  %r4 = func.call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r4 : f64
  // CHECK: 111

  %r5 = func.call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r5 : f64
  // CHECK: 111

  %r7 = func.call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r7 : f64
  // CHECK: 111

  // Release all resources.
  memref.dealloc %b_A : memref<128xi64>
  memref.dealloc %b_B : memref<128xf64>
  memref.dealloc %b_C : memref<128xi64>
  memref.dealloc %b_D : memref<128xf64>

  %r = arith.constant 0 : i32
  return %r : i32
}