// RUN: mlir-opt %s -one-shot-bufferize="dialect-filter=scf,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -split-input-file | FileCheck %s

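// Test that an scf.if yielding a tensor bufferizes to an scf.if yielding a
// memref, with each branch yielding the corresponding buffer.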
// CHECK-LABEL:   func @if(
// CHECK-SAME:             %[[PRED:.*]]: i1,
// CHECK-SAME:             %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:             %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-DAG:       %[[TRUE_MEMREF:.*]] = bufferization.to_memref %[[TRUE_TENSOR]] : tensor<?xf32> to memref<?xf32>
// CHECK-DAG:       %[[FALSE_MEMREF:.*]] = bufferization.to_memref %[[FALSE_TENSOR]] : tensor<?xf32> to memref<?xf32>
// CHECK:           %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
// CHECK:             scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
// CHECK:           } else {
// CHECK:             scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
// CHECK:           }
// CHECK:           %[[RESULT_TENSOR:.*]] = bufferization.to_tensor %[[RESULT_MEMREF]] : memref<?xf32>
// CHECK:           return %[[RESULT_TENSOR]] : tensor<?xf32>
// CHECK:         }
func.func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
  %0 = scf.if %pred -> (tensor<?xf32>) {
    scf.yield %true_val : tensor<?xf32>
  } else {
    scf.yield %false_val : tensor<?xf32>
  }
  return %0 : tensor<?xf32>
}

// -----

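// Test that an scf.for with a tensor iter_arg bufferizes to an scf.for on
// memrefs; with copy-before-write, the iter_arg buffer is allocated and
// copied before entering the loop.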
// CHECK-LABEL:   func @for(
// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:              %[[LB:.*]]: index, %[[UB:.*]]: index,
// CHECK-SAME:              %[[STEP:.*]]: index) -> tensor<f32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<f32> to memref<f32>
// Note: scf.for iter_args always bufferize to a memory write. This could be
// optimized by analyzing the loop body.
// CHECK:           %[[MEMREF_COPY:.*]] = memref.alloc()
// CHECK:           memref.copy %[[MEMREF]], %[[MEMREF_COPY]]
// CHECK:           %[[RESULT_MEMREF:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF_COPY]]) -> (memref<f32>) {
// CHECK:             scf.yield %[[ITER]] : memref<f32>
// CHECK:           } {some_attr}
// CHECK:           %[[RESULT_TENSOR:.*]] = bufferization.to_tensor %[[RESULT_MEMREF]] : memref<f32>
// CHECK:           return %[[RESULT_TENSOR]] : tensor<f32>
// CHECK:         }
func.func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
    scf.yield %iter : tensor<f32>
  } {some_attr}
  return %ret : tensor<f32>
}

// -----

// Check that this converts at all: the "test.munge_tensor" ops inside the
// scf.if regions are not bufferized and must be handled correctly.
//
// It would previously fail altogether.
// CHECK-LABEL:   func @if_correct_recursive_legalization_behavior
// CHECK: "test.munge_tensor"
func.func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>) -> tensor<f32> {
  %0 = scf.if %pred -> (tensor<f32>) {
    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  } else {
    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  }
  return %0 : tensor<f32>
}

// -----

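// Test that an op that is not bufferized here ("test.munge_tensor") inside an
// scf.for body is bridged with bufferization.to_tensor / to_memref around it.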
// CHECK-LABEL:   func @for_correct_recursive_legalization_behavior(
// CHECK-SAME:                                                      %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:                                                      %[[INDEX:.*]]: index) -> tensor<f32> {
// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<f32> to memref<f32>
// Note: scf.for iter_args always bufferize to a memory write. This could be
// optimized by analyzing the loop body.
// CHECK:           %[[MEMREF_COPY:.*]] = memref.alloc()
// CHECK:           memref.copy %[[MEMREF]], %[[MEMREF_COPY]]
// CHECK:           %[[RESULT:.*]] = scf.for %{{.*}} = %[[INDEX]] to %[[INDEX]] step %[[INDEX]] iter_args(%[[MEMREF_ITER:.*]] = %[[MEMREF_COPY]]) -> (memref<f32>) {
// CHECK:             %[[TENSOR_ITER:.*]] = bufferization.to_tensor %[[MEMREF_ITER]] : memref<f32>
// CHECK:             %[[TENSOR_MUNGED:.*]] = "test.munge_tensor"(%[[TENSOR_ITER]]) : (tensor<f32>) -> tensor<f32>
// CHECK:             %[[MEMREF_MUNGED:.*]] = bufferization.to_memref %[[TENSOR_MUNGED]] : tensor<f32> to memref<f32>
// CHECK:             scf.yield %[[MEMREF_MUNGED]] : memref<f32>
// CHECK:           }
// CHECK:           %[[TENSOR:.*]] = bufferization.to_tensor %[[RESULT]] : memref<f32>
// CHECK:           return %[[TENSOR]] : tensor<f32>
// CHECK:         }
func.func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
  %ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor<f32> {
    %0 = "test.munge_tensor"(%iter) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %0 : tensor<f32>
  }
  return %ret : tensor<f32>
}

// -----

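// Test that an scf.while carrying a tensor bufferizes to an scf.while carrying
// a memref; the i64 operands are passed through unchanged.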
// CHECK-LABEL:   func @bufferize_while(
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64, %[[ARG2:.*]]: tensor<f32>
// CHECK: %[[M:.*]] = bufferization.to_memref %[[ARG2]] : tensor<f32> to memref<f32>
// Note: scf.while iter_args always bufferize to a memory write. This could be
// optimized by analyzing the loop body.
// CHECK: %[[MEMREF_COPY:.*]] = memref.alloc()
// CHECK: memref.copy %[[M]], %[[MEMREF_COPY]]
// CHECK: %[[RES1:.*]]:3 = scf.while (%{{.*}} = %[[ARG0]], %[[ITER:.*]] = %[[MEMREF_COPY]]) : (i64, memref<f32>) -> (i64, i64, memref<f32>)
// CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}}, %[[ITER]] : i64, i64, memref<f32>
// CHECK: ^bb0(%{{.*}}: i64, %{{.*}}: i64, %{{.*}}: memref<f32>):
// CHECK: scf.yield %{{.*}}, %{{.*}} : i64, memref<f32>
// CHECK:  %[[RES2:.*]] = bufferization.to_tensor %[[RES1]]#2 : memref<f32>
// CHECK:  return %[[RES1]]#1, %[[RES2]] : i64, tensor<f32>
func.func @bufferize_while(%arg0: i64, %arg1: i64, %arg2: tensor<f32>) -> (i64, tensor<f32>) {
  %c2_i64 = arith.constant 2 : i64
  %0:3 = scf.while (%arg3 = %arg0, %arg4 = %arg2) : (i64, tensor<f32>) -> (i64, i64, tensor<f32>) {
    %1 = arith.cmpi slt, %arg3, %arg1 : i64
    scf.condition(%1) %arg3, %arg3, %arg4 : i64, i64, tensor<f32>
  } do {
  ^bb0(%arg5: i64, %arg6: i64, %arg7: tensor<f32>):
    %1 = arith.muli %arg6, %c2_i64 : i64
    scf.yield %1, %arg7 : i64, tensor<f32>
  }
  return %0#1, %0#2 : i64, tensor<f32>
}