// xref: /llvm-project/mlir/test/Dialect/GPU/shuffle-rewrite.mlir (revision 330a232ae76139c3970df5ccaf1b51640cbd4d66)
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s

module {
  // Test case: a `gpu.shuffle xor` on an f64 value inside a `gpu.launch` body.
  // The FileCheck directives below expect the 64-bit shuffle to be replaced by
  // two 32-bit shuffles whose results are recombined into a 64-bit value.
  //
  // The signature directives bind the function arguments (value, offset, width
  // and the destination memref) so the body directives can refer to them.
  // CHECK-LABEL: func.func @shuffleF64
  // CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: f64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<f64, 1>) {
  func.func @shuffleF64(%sz : index, %value: f64, %offset: i32, %width: i32, %mem: memref<f64, 1>) {
    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
      // Expected output, in order and on consecutive lines:
      //   1. the f64 operand is bitcast to i64;
      //   2. the low half is obtained with a truncation to i32;
      //   3. the high half is obtained by a logical right shift followed by a
      //      truncation (the shift amount is only captured as C32 here; it is
      //      presumably the constant 32, but its definition is not matched);
      //   4. each half is shuffled separately with the original offset/width;
      //   5. both shuffled halves are zero-extended to i64, the high half is
      //      shifted left by the same C32 value, and the two are OR-ed together;
      //   6. the i64 result is bitcast back to f64.
      // CHECK: %[[INTVAL:.*]] = arith.bitcast %[[VALUE]] : f64 to i64
      // CHECK-NEXT: %[[LO:.*]] = arith.trunci %[[INTVAL]] : i64 to i32
      // CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[INTVAL]], %[[C32:.*]] : i64
      // CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
      // CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle  xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
      // CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle  xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
      // CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
      // CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
      // CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
      // CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
      // NOTE(review): the final directive does not bind the bitcast result, so
      // the use of the rebuilt f64 by the store below is not verified.
      // CHECK-NEXT:  = arith.bitcast %[[SHFLINT]] : i64 to f64
      %shfl, %pred = gpu.shuffle xor %value, %offset, %width : f64
      // Store the shuffled value so the shuffle result has a user.
      memref.store %shfl, %mem[]  : memref<f64, 1>
      gpu.terminator
    }
    return
  }
}

// -----

module {
  // Test case: a `gpu.shuffle xor` on an i64 value inside a `gpu.launch` body.
  // The FileCheck directives below expect the 64-bit shuffle to be replaced by
  // two 32-bit shuffles whose results are recombined into an i64. Unlike a
  // floating-point operand, no bitcast is expected before or after.
  //
  // The signature directives bind the function arguments (value, offset, width
  // and the destination memref) so the body directives can refer to them.
  // CHECK-LABEL: func.func @shuffleI64
  // CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: i64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<i64, 1>) {
  func.func @shuffleI64(%sz : index, %value: i64, %offset: i32, %width: i32, %mem: memref<i64, 1>) {
    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
      // Expected output, in order and on consecutive lines:
      //   1. the low half is obtained by truncating the i64 operand to i32;
      //   2. the high half is obtained by a logical right shift followed by a
      //      truncation (the shift amount is only captured as C32 here; it is
      //      presumably the constant 32, but its definition is not matched);
      //   3. each half is shuffled separately with the original offset/width;
      //   4. both shuffled halves are zero-extended to i64, the high half is
      //      shifted left by the same C32 value, and the two are OR-ed together.
      // NOTE(review): the use of the OR result by the store below is not
      // verified by any directive.
      // CHECK: %[[LO:.*]] = arith.trunci %[[VALUE]] : i64 to i32
      // CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[VALUE]], %[[C32:.*]] : i64
      // CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
      // CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle  xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
      // CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle  xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
      // CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
      // CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
      // CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
      // CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
      %shfl, %pred = gpu.shuffle xor %value, %offset, %width : i64
      // Store the shuffled value so the shuffle result has a user.
      memref.store %shfl, %mem[]  : memref<i64, 1>
      gpu.terminator
    }
    return
  }
}