xref: /llvm-project/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir (revision 9e8200c7184431e0dd0e235b70cabfbe8bfe351d)
1// RUN: mlir-opt %s -affine-expand-index-ops -split-input-file | FileCheck %s
2
3// CHECK-LABEL: @delinearize_static_basis
4//  CHECK-SAME:    (%[[IDX:.+]]: index)
5//   CHECK-DAG:   %[[C224:.+]] = arith.constant 224 : index
6//   CHECK-DAG:   %[[C50176:.+]] = arith.constant 50176 : index
7//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
8//       CHECK:   %[[N:.+]] = arith.floordivsi %[[IDX]], %[[C50176]]
9//   CHECK-DAG:   %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[C50176]]
10//   CHECK-DAG:   %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
11//   CHECK-DAG:   %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]]
12//   CHECK-DAG:   %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
13//       CHECK:   %[[P:.+]] = arith.divsi %[[P_MOD]], %[[C224]]
14//   CHECK-DAG:   %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[C224]]
15//   CHECK-DAG:   %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
16//   CHECK-DAG:   %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]]
17//       CHECK:   %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
18//       CHECK:   return %[[N]], %[[P]], %[[Q]]
// Test input: delinearize %linear_index with a fully static three-entry basis
// (16, 224, 224). The directives above expect divisions by 50176 (= 224 * 224)
// and by 224; none of them references the leading basis entry 16.
19func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) {
20  %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
  // Return all three results, in order, so each one can be matched separately.
21  return %1#0, %1#1, %1#2 : index, index, index
22}
23
24// -----
25
26// CHECK-LABEL: @delinearize_dynamic_basis
27//  CHECK-SAME:    (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref
28//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
29//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
30//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
31//       CHECK:  %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
32//       CHECK:  %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
33//       CHECK:  %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]]
34//       CHECK:  %[[N:.+]] = arith.floordivsi %[[IDX]], %[[STRIDE1]]
35//   CHECK-DAG:  %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[STRIDE1]]
36//   CHECK-DAG:  %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
37//   CHECK-DAG:  %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]]
38//   CHECK-DAG:  %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
39//       CHECK:  %[[P:.+]] = arith.divsi %[[P_MOD]], %[[DIM2]]
40//   CHECK-DAG:  %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]]
41//   CHECK-DAG:  %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
42//   CHECK-DAG:  %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]]
43//       CHECK:  %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
44//       CHECK:   return %[[N]], %[[P]], %[[Q]]
// Test input: delinearize %linear_index with a basis made of runtime values,
// namely the sizes of dimensions 1 and 2 of the rank-3 dynamic memref %src.
45func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
46  %c1 = arith.constant 1 : index
47  %c2 = arith.constant 2 : index
  // %b1 and %b2 are the dynamic sizes used as the two basis entries below.
48  %b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
49  %b2 = memref.dim %src, %c2 : memref<?x?x?xf32>
50  // Note: no outer bound: the basis has two entries (%b1, %b2) for three results.
51  %1:3 = affine.delinearize_index %linear_index into (%b1, %b2) : index, index, index
52  return %1#0, %1#1, %1#2 : index, index, index
53}
54
55// -----
56
57// CHECK-LABEL: @linearize_static
58// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
59// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
60// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index
61// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]]
62// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]]
63// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
64// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
65// CHECK: return %[[val_1]]
// Test input: linearize three indices with a fully static basis (2, 3, 5).
// The directives above expect %arg0 to be scaled by 15 (= 3 * 5), %arg1 by 5,
// and %arg2 to be added unscaled.
66func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
67  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
68  func.return %0 : index
69}
70
71// -----
72
73// CHECK-LABEL: @linearize_dynamic
74// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
75// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]]
76// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]]
77// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]]
78// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
79// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
80// CHECK: return %[[val_1]]
// Test input: linearize three indices with a basis made of the runtime values
// %arg3 and %arg4. The directives above expect %arg0 to be scaled by
// %arg4 * %arg3, %arg1 by %arg4, and %arg2 to be added unscaled.
81func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
82  // Note: no outer bounds: the basis has two entries (%arg3, %arg4) for three indices.
83  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
84  func.return %0 : index
85}
86
87// -----
88
89// CHECK-LABEL: @linearize_sort_adds
90// CHECK-SAME: (%[[arg0:.+]]: memref<?xi32>, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
91// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
92// CHECK: scf.for %[[arg3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} {
93// CHECK: scf.for %[[arg4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} {
94// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]]
95// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]]
96// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]]
97// Note: even though %arg3 has a lower stride, we add it first
98// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]]
99// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]]
100// CHECK: memref.store %{{.*}}, %[[arg0]][%[[val_1]]]
// Test input: a linearize_index nested inside two scf.for loops, mixing a
// function argument (%arg1) with both loop induction variables. The directives
// above pin the order of the emitted additions: the unscaled %arg3 term is
// added before the scaled %arg4 term.
101func.func @linearize_sort_adds(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
102  %c0 = arith.constant 0 : index
103  %c1 = arith.constant 1 : index
104  %c4 = arith.constant 4 : index
  // Value written by the store below.
105  %c0_i32 = arith.constant 0 : i32
  // Outer loop: %arg3 runs from 0 to %arg2 in steps of 1.
106  scf.for %arg3 = %c0 to %arg2 step %c1 {
    // Inner loop: %arg4 runs from 0 to 4 in steps of 1.
107    scf.for %arg4 = %c0 to %c4 step %c1 {
      // The outer induction variable %arg3 is the last index and the inner one
      // (%arg4) is the middle index; the basis (4, %arg2) has two entries for
      // three indices, and the op carries the `disjoint` keyword.
108      %idx = affine.linearize_index disjoint [%arg1, %arg4, %arg3] by (4, %arg2) : index
109      memref.store %c0_i32, %arg0[%idx] : memref<?xi32>
110    }
111  }
112  return
113}
114