// RUN: mlir-opt %s -test-tensor-copy-insertion=allow-return-allocs-from-loops -allow-unregistered-dialect -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-tensor-copy-insertion="allow-return-allocs-from-loops bufferize-function-boundaries" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC

// Both iter_args are yielded unchanged, so only the two loop-init copies are
// inserted; the yield passes the iter_args straight through.
// CHECK-LABEL: func @scf_for(
//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>, %[[B:.*]]: tensor<?xf32>
func.func @scf_for(%A : tensor<?xf32>, %B : tensor<?xf32>,
                   %lb : index, %ub : index, %step : index)
  -> (tensor<?xf32>, tensor<?xf32>)
{
  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) : tensor<?xf32>
  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) : tensor<?xf32>
  // CHECK:   %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[A_copy]], %[[iter2:.*]] = %[[B_copy]])
  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
      -> (tensor<?xf32>, tensor<?xf32>)
  {
    // CHECK: scf.yield %[[iter1]], %[[iter2]]
    scf.yield %tA, %tB : tensor<?xf32>, tensor<?xf32>
  }

  return %r0#0, %r0#1 : tensor<?xf32>, tensor<?xf32>
}

// -----

// The yield swaps the two iter_args, so in addition to the loop-init copies,
// copies of both yielded values are inserted inside the loop body.
// CHECK-LABEL: func @scf_for_swapping_yields(
//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>, %[[B:.*]]: tensor<?xf32>
func.func @scf_for_swapping_yields(%A : tensor<?xf32>, %B : tensor<?xf32>,
                                   %lb : index, %ub : index, %step : index)
  -> (tensor<?xf32>, tensor<?xf32>)
{
  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) : tensor<?xf32>
  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) : tensor<?xf32>
  // CHECK:   %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[A_copy]], %[[iter2:.*]] = %[[B_copy]])
  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
      -> (tensor<?xf32>, tensor<?xf32>)
  {
    // Yield tensors in different order.
    // CHECK-DAG: %[[yield1:.*]] = bufferization.alloc_tensor() copy(%[[iter2]]) : tensor<?xf32>
    // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[iter1]]) : tensor<?xf32>
    // CHECK: scf.yield %[[yield1]], %[[yield2]]
    scf.yield %tB, %tA : tensor<?xf32>, tensor<?xf32>
  }

  return %r0#0, %r0#1 : tensor<?xf32>, tensor<?xf32>
}

// -----

// Both condition and body forward their operands in order, so copies are
// inserted only for the while-loop init values.
// CHECK-LABEL: func @scf_while(
//  CHECK-SAME:     %[[A:.*]]: tensor<5xi1>, %[[B:.*]]: tensor<5xi1>
func.func @scf_while(%A: tensor<5xi1>, %B: tensor<5xi1>, %idx: index)
  -> (tensor<5xi1>, tensor<5xi1>)
{
  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) : tensor<5xi1>
  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) : tensor<5xi1>
  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[A_copy]], %[[w1:.*]] = %[[B_copy]]) {{.*}} {
  %r0, %r1 = scf.while (%w0 = %A, %w1 = %B)
      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
    // CHECK: %[[condition:.*]] = tensor.extract %[[w0]]
    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
    // CHECK: scf.condition(%[[condition]]) %[[w0]], %[[w1]]
    scf.condition(%condition) %w0, %w1 : tensor<5xi1>, tensor<5xi1>
  } do {
  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
    // CHECK: } do {
    // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
    // CHECK: scf.yield %[[b0]], %[[b1]]
    // CHECK: }
    scf.yield %b0, %b1 : tensor<5xi1>, tensor<5xi1>
  }

  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
}

// -----

// The condition and the body yield their operands in swapped order, so extra
// copies are inserted for the values forwarded by scf.condition.
// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body(
//  CHECK-SAME:     %[[A:.*]]: tensor<5xi1>, %[[B:.*]]: tensor<5xi1>
func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>,
                                                  %B: tensor<5xi1>,
                                                  %idx: index)
  -> (tensor<5xi1>, tensor<5xi1>)
{
  // CHECK: %[[A_copy:.*]] = bufferization.alloc_tensor() copy(%[[A]]) : tensor<5xi1>
  // CHECK: %[[B_copy:.*]] = bufferization.alloc_tensor() copy(%[[B]]) : tensor<5xi1>
  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[A_copy]], %[[w1:.*]] = %[[B_copy]]) {{.*}} {
  %r0, %r1 = scf.while (%w0 = %A, %w1 = %B)
      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
    // CHECK: %[[condition:.*]] = tensor.extract %[[w0]]
    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
    // Yield tensors in different order.
    // CHECK-DAG: %[[yield0:.*]] = bufferization.alloc_tensor() copy(%[[w1]]) : tensor<5xi1>
    // CHECK-DAG: %[[yield1:.*]] = bufferization.alloc_tensor() copy(%[[w0]]) : tensor<5xi1>
    // CHECK: scf.condition(%[[condition]]) %[[yield0]], %[[yield1]]
    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
  } do {
  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
    // CHECK: } do {
    // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
    // CHECK: scf.yield %[[b1]], %[[b0]]
    // CHECK: }
    scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1>
  }

  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
}

// -----

// A copy is inserted for the shared_outs operand; with function-boundary
// bufferization (CHECK-FUNC) no alloc_tensor is expected.
// CHECK-LABEL: func @scf_forall_out_of_place(
//  CHECK-SAME:     %[[arg0:.*]]: tensor<100xf32>, %[[arg1:.*]]: tensor<100xf32>
// CHECK-FUNC-LABEL: func @scf_forall_out_of_place(
func.func @scf_forall_out_of_place(%in: tensor<100xf32>,
                                           %out: tensor<100xf32>) {
  %c1 = arith.constant 1 : index
  %num_threads = arith.constant 100 : index

  // CHECK-FUNC-NOT: alloc_tensor
  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[arg1]]) : tensor<100xf32>
  // CHECK: scf.forall {{.*}} shared_outs(%[[o:.*]] = %[[alloc]])
  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> {
      // CHECK: tensor.extract_slice
      // CHECK: scf.forall.in_parallel
      // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[o]]
      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
          tensor<1xf32> into tensor<100xf32>
      }
  // CHECK: } {mapping = [#gpu.thread<x>]}
  } {mapping = [#gpu.thread<x>]}
  return
}