1// Check that the wide integer emulation produces the same result as wide
2// calculations. Emulate i16 ops with i8 ops.
3
4// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
5// RUN:             --convert-vector-to-scf --convert-scf-to-cf --convert-cf-to-llvm \
6// RUN:             --convert-vector-to-llvm --convert-func-to-llvm --convert-arith-to-llvm \
7// RUN:             --reconcile-unrealized-casts | \
8// RUN:   mlir-runner -e entry -entry-point-result=void \
9// RUN:      --shared-libs="%mlir_c_runner_utils,%mlir_runner_utils" | \
10// RUN:   FileCheck %s
11
12// CHECK-NOT: Mismatch
13
14//===----------------------------------------------------------------------===//
15// Common Utility Functions
16//===----------------------------------------------------------------------===//
17
18// Prints both binary op operands and the first result. If the second result
19// does not match, prints the second result and a 'Mismatch' message.
func.func @check_results(%lhs : i16, %rhs : i16, %res0 : i16, %res1 : i16) -> () {
  // Pack both operands into one vector so they are printed on a single line.
  %zeros = arith.constant dense<0> : vector<2xi16>
  %with_lhs = vector.insert %lhs, %zeros[0] : i16 into vector<2xi16>
  %pair = vector.insert %rhs, %with_lhs[1] : i16 into vector<2xi16>
  vector.print %pair : vector<2xi16>
  // The first result is always printed.
  vector.print %res0 : i16
  // Only when the second result disagrees is it printed, followed by the
  // 'Mismatch' marker that the file's CHECK-NOT directive rejects.
  %differs = arith.cmpi ne, %res0, %res1 : i16
  scf.if %differs {
    vector.print %res1 : i16
    vector.print str "Mismatch\n"
  }
  func.return
}
33
// Folds the high byte of `%i` into its low byte: returns i ^ (i >>u 8).
func.func @xorshift(%i : i16) -> (i16) {
  %c8_i16 = arith.constant 8 : i16
  %high_byte = arith.shrui %i, %c8_i16 : i16
  %mixed = arith.xori %high_byte, %i : i16
  func.return %mixed : i16
}
40
// Returns a hash of the input number. This is used when we want to sample a
// bunch of i16 inputs with a close-to-uniform distribution but without fixed
// offsets between consecutive samples.
func.func @xhash(%i : i16) -> (i16) {
  // 0x5555: alternating one and zero bits.
  %alternating = arith.constant 21845 : i16
  // A large prime that fits in i16.
  %big_prime = arith.constant 25867 : i16
  // Two rounds of xorshift, each followed by a wrapping i16 multiply.
  %round0 = func.call @xorshift(%i) : (i16) -> (i16)
  %scaled = arith.muli %round0, %alternating : i16
  %round1 = func.call @xorshift(%scaled) : (i16) -> (i16)
  %hash = arith.muli %round1, %big_prime : i16
  func.return %hash : i16
}
53
54//===----------------------------------------------------------------------===//
55// Test arith.addi
56//===----------------------------------------------------------------------===//
57
58// Ops in this function will be emulated using i8 ops.
func.func @emulate_addi(%lhs : i16, %rhs : i16) -> (i16) {
  // This i16 add is the op the RUN line's emulation pass rewrites into i8 ops.
  %sum = arith.addi %lhs, %rhs : i16
  func.return %sum : i16
}
63
// Performs both wide and emulated `arith.addi`, and checks that the results
// match.
func.func @check_addi(%lhs : i16, %rhs : i16) -> () {
  // Reference value from a plain i16 add in this (non-emulated) function.
  %reference = arith.addi %lhs, %rhs : i16
  // Value from the function whose ops get emulated.
  %actual = func.call @emulate_addi(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %actual) : (i16, i16, i16, i16) -> ()
  func.return
}
72
73// Checks that `arith.addi` is emulated properly by sampling the input space.
74// In total, this test function checks 500 * 500 = 250k input pairs.
func.func @test_addi() -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %num_samples = arith.constant 500 : index

  scf.for %i = %c0 to %num_samples step %c1 {
    // The loop index (0..499) is the seed for the hashed left operand.
    %seed_lhs = arith.index_cast %i : index to i16
    %a = func.call @xhash(%seed_lhs) : (i16) -> (i16)

    scf.for %j = %c0 to %num_samples step %c1 {
      // Same seeding scheme for the right operand.
      %seed_rhs = arith.index_cast %j : index to i16
      %b = func.call @xhash(%seed_rhs) : (i16) -> (i16)
      func.call @check_addi(%a, %b) : (i16, i16) -> ()
    }
  }

  func.return
}
100
101//===----------------------------------------------------------------------===//
102// Test arith.muli
103//===----------------------------------------------------------------------===//
104
105// Ops in this function will be emulated using i8 ops.
func.func @emulate_muli(%lhs : i16, %rhs : i16) -> (i16) {
  // This i16 multiply is the op the RUN line's emulation pass rewrites into i8 ops.
  %product = arith.muli %lhs, %rhs : i16
  func.return %product : i16
}
110
111// Performs both wide and emulated `arith.muli`, and checks that the results
112// match.
func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
  // Reference value from a plain i16 multiply in this (non-emulated) function.
  %reference = arith.muli %lhs, %rhs : i16
  // Value from the function whose ops get emulated.
  %actual = func.call @emulate_muli(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %actual) : (i16, i16, i16, i16) -> ()
  func.return
}
119
120// Checks that `arith.muli` is emulated properly by sampling the input space.
121// In total, this test function checks 500 * 500 = 250k input pairs.
func.func @test_muli() -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %num_samples = arith.constant 500 : index

  scf.for %i = %c0 to %num_samples step %c1 {
    // The loop index (0..499) is the seed for the hashed left operand.
    %seed_lhs = arith.index_cast %i : index to i16
    %a = func.call @xhash(%seed_lhs) : (i16) -> (i16)

    scf.for %j = %c0 to %num_samples step %c1 {
      // Same seeding scheme for the right operand.
      %seed_rhs = arith.index_cast %j : index to i16
      %b = func.call @xhash(%seed_rhs) : (i16) -> (i16)
      func.call @check_muli(%a, %b) : (i16, i16) -> ()
    }
  }

  func.return
}
147
148//===----------------------------------------------------------------------===//
149// Test arith.shli
150//===----------------------------------------------------------------------===//
151
152// Ops in this function will be emulated using i8 ops.
func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
  // This i16 left shift is the op the RUN line's emulation pass rewrites into i8 ops.
  %shifted = arith.shli %lhs, %rhs : i16
  func.return %shifted : i16
}
157
158// Performs both wide and emulated `arith.shli`, and checks that the results
159// match.
func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
  // Reference value from a plain i16 shift in this (non-emulated) function.
  %reference = arith.shli %lhs, %rhs : i16
  // Value from the function whose ops get emulated.
  %actual = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %actual) : (i16, i16, i16, i16) -> ()
  func.return
}
166
167// Checks that `arith.shli` is emulated properly by sampling the input space.
168// Checks all valid shift amounts for i16: 0 to 15.
169// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shli() -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %num_shifts = arith.constant 16 : index
  %num_samples = arith.constant 100 : index

  scf.for %i = %c0 to %num_samples step %c1 {
    // The loop index (0..99) is the seed for the hashed value being shifted.
    %seed = arith.index_cast %i : index to i16
    %value = func.call @xhash(%seed) : (i16) -> (i16)

    // Shift amounts are not hashed: every amount in [0, 16) is tried.
    scf.for %j = %c0 to %num_shifts step %c1 {
      %amount = arith.index_cast %j : index to i16
      func.call @check_shli(%value, %amount) : (i16, i16) -> ()
    }
  }

  func.return
}
194
195//===----------------------------------------------------------------------===//
196// Test arith.shrsi
197//===----------------------------------------------------------------------===//
198
199// Ops in this function will be emulated using i8 ops.
func.func @emulate_shrsi(%lhs : i16, %rhs : i16) -> (i16) {
  // This i16 arithmetic right shift is the op the RUN line's emulation pass rewrites into i8 ops.
  %shifted = arith.shrsi %lhs, %rhs : i16
  func.return %shifted : i16
}
204
205// Performs both wide and emulated `arith.shrsi`, and checks that the results
206// match.
func.func @check_shrsi(%lhs : i16, %rhs : i16) -> () {
  // Reference value from a plain i16 shift in this (non-emulated) function.
  %reference = arith.shrsi %lhs, %rhs : i16
  // Value from the function whose ops get emulated.
  %actual = func.call @emulate_shrsi(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %actual) : (i16, i16, i16, i16) -> ()
  func.return
}
213
// Checks that `arith.shrsi` is emulated properly by sampling the input space.
215// Checks all valid shift amounts for i16: 0 to 15.
216// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shrsi() -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %num_shifts = arith.constant 16 : index
  %num_samples = arith.constant 100 : index

  scf.for %i = %c0 to %num_samples step %c1 {
    // The loop index (0..99) is the seed for the hashed value being shifted.
    %seed = arith.index_cast %i : index to i16
    %value = func.call @xhash(%seed) : (i16) -> (i16)

    // Shift amounts are not hashed: every amount in [0, 16) is tried.
    scf.for %j = %c0 to %num_shifts step %c1 {
      %amount = arith.index_cast %j : index to i16
      func.call @check_shrsi(%value, %amount) : (i16, i16) -> ()
    }
  }

  func.return
}
241
242//===----------------------------------------------------------------------===//
243// Test arith.shrui
244//===----------------------------------------------------------------------===//
245
246// Ops in this function will be emulated using i8 ops.
func.func @emulate_shrui(%lhs : i16, %rhs : i16) -> (i16) {
  // This i16 logical right shift is the op the RUN line's emulation pass rewrites into i8 ops.
  %shifted = arith.shrui %lhs, %rhs : i16
  func.return %shifted : i16
}
251
252// Performs both wide and emulated `arith.shrui`, and checks that the results
253// match.
func.func @check_shrui(%lhs : i16, %rhs : i16) -> () {
  // Reference value from a plain i16 shift in this (non-emulated) function.
  %reference = arith.shrui %lhs, %rhs : i16
  // Value from the function whose ops get emulated.
  %actual = func.call @emulate_shrui(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %actual) : (i16, i16, i16, i16) -> ()
  func.return
}
260
261// Checks that `arith.shrui` is emulated properly by sampling the input space.
262// Checks all valid shift amounts for i16: 0 to 15.
263// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shrui() -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %num_shifts = arith.constant 16 : index
  %num_samples = arith.constant 100 : index

  scf.for %i = %c0 to %num_samples step %c1 {
    // The loop index (0..99) is the seed for the hashed value being shifted.
    %seed = arith.index_cast %i : index to i16
    %value = func.call @xhash(%seed) : (i16) -> (i16)

    // Shift amounts are not hashed: every amount in [0, 16) is tried.
    scf.for %j = %c0 to %num_shifts step %c1 {
      %amount = arith.index_cast %j : index to i16
      func.call @check_shrui(%value, %amount) : (i16, i16) -> ()
    }
  }

  func.return
}
288
289//===----------------------------------------------------------------------===//
290// Entry Point
291//===----------------------------------------------------------------------===//
292
// Entry point named by the RUN line (`-e entry`): runs every op test in order.
func.func @entry() {
  // Arithmetic ops.
  func.call @test_addi() : () -> ()
  func.call @test_muli() : () -> ()
  // Shift ops.
  func.call @test_shli() : () -> ()
  func.call @test_shrsi() : () -> ()
  func.call @test_shrui() : () -> ()
  func.return
}
301