// xref: /llvm-project/mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir (revision 13bd41096286305ee603428f6adf161f52981827)
// Test driver: applies scf-parallel-loop-tiling to every func.func with tile
// sizes (1, 4) and no-min-max-bounds enabled. The input is split at the
// separator comments so each test case is processed on its own, and the
// printed IR is verified by FileCheck against the check lines in this file.
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-loop-tiling{parallel-loop-tile-sizes=1,4 no-min-max-bounds=true}))' -split-input-file | FileCheck %s

// Two-dimensional scf.parallel whose lower bounds, upper bounds and steps are
// all function arguments. The loop body adds the elements of %B and %C and
// stores the sum into %result; %A is not referenced by the body.
//
// The expected output is an outer scf.parallel stepping by (step * tile size)
// and an inner scf.parallel running from 0 to that product. The original body
// ends up inside an scf.if whose condition is the conjunction of one
// `arith.cmpi ult` per dimension against the original upper bound.
func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
                    %arg3 : index, %arg4 : index, %arg5 : index,
                    %A: memref<?x?xf32>, %B: memref<?x?xf32>,
                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
  scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
    %B_elem = memref.load %B[%i0, %i1] : memref<?x?xf32>
    %C_elem = memref.load %C[%i0, %i1] : memref<?x?xf32>
    %sum_elem = arith.addf %B_elem, %C_elem : f32
    memref.store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
  }
  return
}

// Note: the captures ARG1..ARG10 below are numbered from one, so ARG1 binds
// the first function argument (%arg0 in the input), ARG3/ARG4 bind the upper
// bounds and ARG5/ARG6 bind the steps.
// CHECK-LABEL:   func @parallel_loop(
// CHECK-SAME:                        [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: memref<?x?xf32>, [[ARG8:%.*]]: memref<?x?xf32>, [[ARG9:%.*]]: memref<?x?xf32>, [[ARG10:%.*]]: memref<?x?xf32>) {
// CHECK:           [[C0:%.*]] = arith.constant 0 : index
// CHECK:           [[C1:%.*]] = arith.constant 1 : index
// CHECK:           [[C4:%.*]] = arith.constant 4 : index
// CHECK:           [[V1:%.*]] = arith.muli [[ARG5]], [[C1]] : index
// CHECK:           [[V2:%.*]] = arith.muli [[ARG6]], [[C4]] : index
// CHECK:           scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[ARG1]], [[ARG2]]) to ([[ARG3]], [[ARG4]]) step ([[V1]], [[V2]]) {
// CHECK:             scf.parallel ([[V7:%.*]], [[V8:%.*]]) = ([[C0]], [[C0]]) to ([[V1]], [[V2]]) step ([[ARG5]], [[ARG6]]) {
// CHECK:               [[V9:%.*]] = arith.addi [[V7]], [[V3]] : index
// CHECK:               [[V10:%.*]] = arith.addi [[V8]], [[V4]] : index
// The boolean constant is captured instead of matching a literal SSA name, so
// the test does not depend on how the printer names that value.
// CHECK:               [[TRUE:%.*]] = arith.constant true
// CHECK:               [[V11:%.*]] = arith.muli [[V7]], [[ARG5]] : index
// CHECK:               [[V12:%.*]] = arith.addi [[V11]], [[V3]] : index
// CHECK:               [[V13:%.*]] = arith.cmpi ult, [[V12]], [[ARG3]] : index
// CHECK:               [[V14:%.*]] = arith.andi [[TRUE]], [[V13]] : i1
// CHECK:               [[V15:%.*]] = arith.muli [[V8]], [[ARG6]] : index
// CHECK:               [[V16:%.*]] = arith.addi [[V15]], [[V4]] : index
// CHECK:               [[V17:%.*]] = arith.cmpi ult, [[V16]], [[ARG4]] : index
// CHECK:               [[V18:%.*]] = arith.andi [[V14]], [[V17]] : i1
// CHECK:               scf.if [[V18]] {
// CHECK:                 [[V19:%.*]] = memref.load [[ARG8]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
// CHECK:                 [[V20:%.*]] = memref.load [[ARG9]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
// CHECK:                 [[V21:%.*]] = arith.addf [[V19]], [[V20]] : f32
// CHECK:                 memref.store [[V21]], [[ARG10]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           return

// -----

// scf.parallel with constant bounds (0 to 22 and 0 to 24) and a constant step
// of 3 in both dimensions, with an empty body. The expected output is the
// tiled outer/inner loop pair with no scf.if guard in the inner loop.
// NOTE(review): presumably no guard is needed because both trip counts
// (ceil(22/3) = 8 and ceil(24/3) = 8) are multiples of the tile sizes 1 and 4;
// confirm against the pass implementation.
func.func @static_loop_with_step() {
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %c22 = arith.constant 22 : index
  %c24 = arith.constant 24 : index
  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c22, %c24) step (%c3, %c3) {
  }
  return
}

// CHECK-LABEL:   func @static_loop_with_step() {
// CHECK:           [[C0:%.*]] = arith.constant 0 : index
// CHECK:           [[C3:%.*]] = arith.constant 3 : index
// CHECK:           [[C22:%.*]] = arith.constant 22 : index
// CHECK:           [[C24:%.*]] = arith.constant 24 : index
// CHECK:           [[C0_1:%.*]] = arith.constant 0 : index
// CHECK:           [[C1:%.*]] = arith.constant 1 : index
// CHECK:           [[C4:%.*]] = arith.constant 4 : index
// CHECK:           [[V1:%.*]] = arith.muli [[C3]], [[C1]] : index
// CHECK:           [[V2:%.*]] = arith.muli [[C3]], [[C4]] : index
// CHECK:           scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[C0]], [[C0]]) to ([[C22]], [[C24]]) step ([[V1]], [[V2]]) {
// CHECK:             scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V1]], [[V2]]) step ([[C3]], [[C3]]) {
// CHECK-NOT:           scf.if
// CHECK:               = arith.addi [[V5]], [[V3]] : index
// CHECK:               = arith.addi [[V6]], [[V4]] : index
// A negative check only covers the text between its neighbouring positive
// matches. The other test cases in this file expect the guard after the two
// arith.addi ops, so the negative check is repeated here to cover the region
// up to the closing brace of the inner loop.
// CHECK-NOT:           scf.if
// CHECK:             }
// CHECK:           }
// CHECK:           return

// -----

// A scf.parallel nested directly inside another scf.parallel, followed by a
// second, standalone scf.parallel; all bodies are empty. The expected output
// keeps the enclosing loop of the nest with its original bounds and step,
// tiles the loop nested inside it, and tiles the standalone loop as well.
// Both tiled inner loops are expected to contain an scf.if.
func.func @tile_nested_innermost() {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    }
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
  }
  return
}

// CHECK-LABEL:   func @tile_nested_innermost() {
// CHECK:           [[C2:%.*]] = arith.constant 2 : index
// CHECK:           [[C0:%.*]] = arith.constant 0 : index
// CHECK:           [[C1:%.*]] = arith.constant 1 : index
// CHECK:           scf.parallel ([[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:             [[C0_1:%.*]] = arith.constant 0 : index
// CHECK:             [[C1_1:%.*]] = arith.constant 1 : index
// CHECK:             [[C4:%.*]] = arith.constant 4 : index
// CHECK:             [[V3:%.*]] = arith.muli [[C1]], [[C1_1]] : index
// CHECK:             [[V4:%.*]] = arith.muli [[C1]], [[C4]] : index
// CHECK:             scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V3]], [[V4]]) {
// CHECK:               scf.parallel ([[V8:%.*]], [[V9:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V3]], [[V4]]) step ([[C1]], [[C1]]) {
// CHECK:                 = arith.addi [[V8]], [[V5]] : index
// CHECK:                 = arith.addi [[V9]], [[V6]] : index
// CHECK:                 scf.if
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:           [[C0_2:%.*]] = arith.constant 0 : index
// CHECK:           [[C1_2:%.*]] = arith.constant 1 : index
// CHECK:           [[C4_1:%.*]] = arith.constant 4 : index
// CHECK:           [[V10:%.*]] = arith.muli [[C1]], [[C1_2]] : index
// CHECK:           [[V11:%.*]] = arith.muli [[C1]], [[C4_1]] : index
// CHECK:           scf.parallel ([[V12:%.*]], [[V13:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V10]], [[V11]]) {
// CHECK:             scf.parallel ([[V15:%.*]], [[V16:%.*]]) = ([[C0_2]], [[C0_2]]) to ([[V10]], [[V11]]) step ([[C1]], [[C1]]) {
// CHECK:               = arith.addi [[V15]], [[V12]] : index
// CHECK:               = arith.addi [[V16]], [[V13]] : index
// CHECK:               scf.if
// CHECK:             }
// CHECK:           }
// CHECK:           return
// CHECK:         }

// -----

// A scf.parallel with an empty body nested inside two scf.for loops. The
// expected output still contains an scf.parallel nested in another
// scf.parallel inside the two scf.for loops, i.e. the parallel loop is tiled
// even though its enclosing loops are not parallel loops.
func.func @tile_nested_in_non_ploop() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  scf.for %i = %c0 to %c2 step %c1 {
    scf.for %j = %c0 to %c2 step %c1 {
      scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      }
    }
  }
  return
}

// CHECK-LABEL: func @tile_nested_in_non_ploop
// CHECK:         scf.for
// CHECK:           scf.for
// CHECK:             scf.parallel
// CHECK:               scf.parallel
// CHECK:               }
// CHECK:             }
// CHECK:           }
// CHECK:         }
// CHECK:       }
