// Check that the wide integer emulation produces the same result as wide
// calculations. Emulate i16 ops with i8 ops.

// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
// RUN: --convert-vector-to-scf --convert-scf-to-cf --convert-cf-to-llvm \
// RUN: --convert-vector-to-llvm --convert-func-to-llvm --convert-arith-to-llvm \
// RUN: --reconcile-unrealized-casts | \
// RUN: mlir-runner -e entry -entry-point-result=void \
// RUN: --shared-libs="%mlir_c_runner_utils,%mlir_runner_utils" | \
// RUN: FileCheck %s

// CHECK-NOT: Mismatch

//===----------------------------------------------------------------------===//
// Common Utility Functions
//===----------------------------------------------------------------------===//

// Prints both binary op operands and the first result. If the second result
// does not match, prints the second result and a 'Mismatch' message.
func.func @check_results(%lhs : i16, %rhs : i16, %res0 : i16, %res1 : i16) -> () {
  // Pack both operands into one vector so they are printed on a single line.
  %vec_zero = arith.constant dense<0> : vector<2xi16>
  %ins0 = vector.insert %lhs, %vec_zero[0] : i16 into vector<2xi16>
  %operands = vector.insert %rhs, %ins0[1] : i16 into vector<2xi16>
  vector.print %operands : vector<2xi16>
  vector.print %res0 : i16
  %mismatch = arith.cmpi ne, %res0, %res1 : i16
  scf.if %mismatch -> () {
    vector.print %res1 : i16
    vector.print str "Mismatch\n"
  }
  return
}

// Returns `%i` XOR-ed with itself logically shifted right by 8 bits, i.e. the
// high byte is folded into the low byte while the high byte is kept as is.
func.func @xorshift(%i : i16) -> (i16) {
  %cst8 = arith.constant 8 : i16
  %shifted = arith.shrui %i, %cst8 : i16
  %res = arith.xori %i, %shifted : i16
  return %res : i16
}

// Returns a hash of the input number. This is used when we want to sample a
// bunch of i16 inputs with close to uniform distribution but without fixed
// offsets between each sample.
func.func @xhash(%i : i16) -> (i16) {
  %pattern = arith.constant 21845 : i16 // Alternating ones and zeros (0x5555).
  %prime = arith.constant 25867 : i16 // Large i16 prime.
  %xi = func.call @xorshift(%i) : (i16) -> (i16)
  %inner = arith.muli %xi, %pattern : i16
  %xinner = func.call @xorshift(%inner) : (i16) -> (i16)
  %res = arith.muli %xinner, %prime : i16
  return %res : i16
}

//===----------------------------------------------------------------------===//
// Test arith.addi
//===----------------------------------------------------------------------===//

// Ops in this function will be emulated using i8 ops.
func.func @emulate_addi(%lhs : i16, %rhs : i16) -> (i16) {
  %res = arith.addi %lhs, %rhs : i16
  return %res : i16
}

// Performs both wide and emulated `arith.addi`, and checks that the results
// match.
func.func @check_addi(%lhs : i16, %rhs : i16) -> () {
  %wide = arith.addi %lhs, %rhs : i16
  %emulated = func.call @emulate_addi(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.addi` is emulated properly by sampling the input space.
// In total, this test function checks 500 * 500 = 250k input pairs.
func.func @test_addi() -> () {
  // Both operands walk over xhash(0), xhash(1), ..., xhash(499); the seeds are
  // carried as i16 loop-carried values.
  %one = arith.constant 1 : i16
  %zero = arith.constant 0 : i16

  %begin = arith.constant 0 : index
  %step = arith.constant 1 : index
  %end = arith.constant 500 : index

  scf.for %i = %begin to %end step %step iter_args(%a_seed = %zero) -> (i16) {
    %a = func.call @xhash(%a_seed) : (i16) -> (i16)

    scf.for %j = %begin to %end step %step iter_args(%b_seed = %zero) -> (i16) {
      %b = func.call @xhash(%b_seed) : (i16) -> (i16)
      func.call @check_addi(%a, %b) : (i16, i16) -> ()

      %b_seed_next = arith.addi %b_seed, %one : i16
      scf.yield %b_seed_next : i16
    }

    %a_seed_next = arith.addi %a_seed, %one : i16
    scf.yield %a_seed_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.muli
//===----------------------------------------------------------------------===//

// Multiplies the two i16 operands with a single `arith.muli`; ops in this
// function will be emulated using i8 ops.
func.func @emulate_muli(%lhs : i16, %rhs : i16) -> (i16) {
  %product = arith.muli %lhs, %rhs : i16
  return %product : i16
}

// Computes `arith.muli` both natively and through @emulate_muli, then hands
// the operands and both results to @check_results for comparison.
func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
  %reference = arith.muli %lhs, %rhs : i16
  %candidate = func.call @emulate_muli(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %candidate) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.muli` is emulated properly by sampling the input space.
// In total, this test function checks 500 * 500 = 250k input pairs.
func.func @test_muli() -> () {
  // Both operands walk over xhash(0), xhash(1), ..., xhash(499); the seeds are
  // carried as i16 loop-carried values.
  %one = arith.constant 1 : i16
  %zero = arith.constant 0 : i16

  %begin = arith.constant 0 : index
  %step = arith.constant 1 : index
  %end = arith.constant 500 : index

  scf.for %i = %begin to %end step %step iter_args(%a_seed = %zero) -> (i16) {
    %a = func.call @xhash(%a_seed) : (i16) -> (i16)

    scf.for %j = %begin to %end step %step iter_args(%b_seed = %zero) -> (i16) {
      %b = func.call @xhash(%b_seed) : (i16) -> (i16)
      func.call @check_muli(%a, %b) : (i16, i16) -> ()

      %b_seed_next = arith.addi %b_seed, %one : i16
      scf.yield %b_seed_next : i16
    }

    %a_seed_next = arith.addi %a_seed, %one : i16
    scf.yield %a_seed_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.shli
//===----------------------------------------------------------------------===//

// Shifts `%lhs` left by `%rhs` with a single `arith.shli`; ops in this
// function will be emulated using i8 ops.
func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
  %shifted = arith.shli %lhs, %rhs : i16
  return %shifted : i16
}

// Computes `arith.shli` both natively and through @emulate_shli, then hands
// the operands and both results to @check_results for comparison.
func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
  %reference = arith.shli %lhs, %rhs : i16
  %candidate = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %candidate) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.shli` is emulated properly by sampling the input space.
// Checks all valid shift amounts for i16: 0 to 15.
// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shli() -> () {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx16 = arith.constant 16 : index
  %idx100 = arith.constant 100 : index

  %cst0 = arith.constant 0 : i16
  %cst1 = arith.constant 1 : i16

  // Outer loop: 100 hashed values xhash(0) .. xhash(99) for the shifted operand.
  scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)

    // Inner loop: the shift amount is passed unhashed so that every amount in
    // [0, 16) is exercised exactly once per outer value.
    scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
      func.call @check_shli(%arg_lhs, %rhs) : (i16, i16) -> ()
      %rhs_next = arith.addi %rhs, %cst1 : i16
      scf.yield %rhs_next : i16
    }

    %lhs_next = arith.addi %lhs, %cst1 : i16
    scf.yield %lhs_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.shrsi
//===----------------------------------------------------------------------===//

// Ops in this function will be emulated using i8 ops.
func.func @emulate_shrsi(%lhs : i16, %rhs : i16) -> (i16) {
  %res = arith.shrsi %lhs, %rhs : i16
  return %res : i16
}

// Performs both wide and emulated `arith.shrsi`, and checks that the results
// match.
func.func @check_shrsi(%lhs : i16, %rhs : i16) -> () {
  %wide = arith.shrsi %lhs, %rhs : i16
  %emulated = func.call @emulate_shrsi(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.shrsi` is emulated properly by sampling the input space.
// Checks all valid shift amounts for i16: 0 to 15.
// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shrsi() -> () {
  // 100 hashed values for the shifted operand, each paired with every shift
  // amount in [0, 16) (amounts are not hashed).
  %one = arith.constant 1 : i16
  %zero = arith.constant 0 : i16

  %begin = arith.constant 0 : index
  %step = arith.constant 1 : index
  %num_values = arith.constant 100 : index
  %num_shifts = arith.constant 16 : index

  scf.for %i = %begin to %num_values step %step iter_args(%seed = %zero) -> (i16) {
    %value = func.call @xhash(%seed) : (i16) -> (i16)

    scf.for %j = %begin to %num_shifts step %step iter_args(%amount = %zero) -> (i16) {
      func.call @check_shrsi(%value, %amount) : (i16, i16) -> ()
      %amount_next = arith.addi %amount, %one : i16
      scf.yield %amount_next : i16
    }

    %seed_next = arith.addi %seed, %one : i16
    scf.yield %seed_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.shrui
//===----------------------------------------------------------------------===//

// Shifts `%lhs` right (logically) by `%rhs` with a single `arith.shrui`; ops
// in this function will be emulated using i8 ops.
func.func @emulate_shrui(%lhs : i16, %rhs : i16) -> (i16) {
  %shifted = arith.shrui %lhs, %rhs : i16
  return %shifted : i16
}

// Computes `arith.shrui` both natively and through @emulate_shrui, then hands
// the operands and both results to @check_results for comparison.
func.func @check_shrui(%lhs : i16, %rhs : i16) -> () {
  %reference = arith.shrui %lhs, %rhs : i16
  %candidate = func.call @emulate_shrui(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %reference, %candidate) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.shrui` is emulated properly by sampling the input space.
// Checks all valid shift amounts for i16: 0 to 15.
// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shrui() -> () {
  // 100 hashed values for the shifted operand, each paired with every shift
  // amount in [0, 16) (amounts are not hashed).
  %one = arith.constant 1 : i16
  %zero = arith.constant 0 : i16

  %begin = arith.constant 0 : index
  %step = arith.constant 1 : index
  %num_values = arith.constant 100 : index
  %num_shifts = arith.constant 16 : index

  scf.for %i = %begin to %num_values step %step iter_args(%seed = %zero) -> (i16) {
    %value = func.call @xhash(%seed) : (i16) -> (i16)

    scf.for %j = %begin to %num_shifts step %step iter_args(%amount = %zero) -> (i16) {
      func.call @check_shrui(%value, %amount) : (i16, i16) -> ()
      %amount_next = arith.addi %amount, %one : i16
      scf.yield %amount_next : i16
    }

    %seed_next = arith.addi %seed, %one : i16
    scf.yield %seed_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//

// Runs each sampled comparison in a fixed order: addi, muli, shli, shrsi,
// shrui. Any divergence between a native op and its emulated counterpart
// prints a "Mismatch" line.
func.func @entry() -> () {
  func.call @test_addi() : () -> ()
  func.call @test_muli() : () -> ()
  func.call @test_shli() : () -> ()
  func.call @test_shrsi() : () -> ()
  func.call @test_shrui() : () -> ()
  return
}