// RUN: mlir-opt %s -one-shot-bufferize="dialect-filter=scf,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -split-input-file | FileCheck %s

// CHECK-LABEL: func @if(
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-DAG: %[[TRUE_MEMREF:.*]] = bufferization.to_memref %[[TRUE_TENSOR]] : tensor<?xf32> to memref<?xf32>
// CHECK-DAG: %[[FALSE_MEMREF:.*]] = bufferization.to_memref %[[FALSE_TENSOR]] : tensor<?xf32> to memref<?xf32>
// CHECK: %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
// CHECK: scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
// CHECK: } else {
// CHECK: scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
// CHECK: }
// CHECK: %[[RESULT_TENSOR:.*]] = bufferization.to_tensor %[[RESULT_MEMREF:.*]] : memref<?xf32>
// CHECK: return %[[RESULT_TENSOR]] : tensor<?xf32>
// CHECK: }
func.func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
  %0 = scf.if %pred -> (tensor<?xf32>) {
    scf.yield %true_val : tensor<?xf32>
  } else {
    scf.yield %false_val : tensor<?xf32>
  }
  return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @for(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index,
// CHECK-SAME: %[[STEP:.*]]: index) -> tensor<f32> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<f32> to memref<f32>
// Note: scf.for iter_args always bufferize to a memory write. This could be
// optimized by analyzing the loop body.
// CHECK: %[[MEMREF_COPY:.*]] = memref.alloc()
// CHECK: memref.copy %[[MEMREF]], %[[MEMREF_COPY]]
// CHECK: %[[RESULT_MEMREF:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF_COPY]]) -> (memref<f32>) {
// CHECK: scf.yield %[[ITER]] : memref<f32>
// CHECK: } {some_attr}
// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[RESULT_MEMREF]] : memref<f32>
// CHECK: return %[[VAL_8]] : tensor<f32>
// CHECK: }
func.func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
    scf.yield %iter : tensor<f32>
  } {some_attr}
  return %ret : tensor<f32>
}

// -----

// Check whether this converts at all.
//
// It would previously fail altogether.
// CHECK-LABEL: func @if_correct_recursive_legalization_behavior
// CHECK: "test.munge_tensor"
func.func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>) -> tensor<f32> {
  %0 = scf.if %pred -> (tensor<f32>) {
    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  } else {
    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  }
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: func @for_correct_recursive_legalization_behavior(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME: %[[INDEX:.*]]: index) -> tensor<f32> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<f32> to memref<f32>
// Note: scf.for iter_args always bufferize to a memory write. This could be
// optimized by analyzing the loop body.
// CHECK: %[[MEMREF_COPY:.*]] = memref.alloc()
// CHECK: memref.copy %[[MEMREF]], %[[MEMREF_COPY]]
// CHECK: %[[RESULT:.*]] = scf.for %{{.*}} = %[[INDEX]] to %[[INDEX]] step %[[INDEX]] iter_args(%[[MEMREF_ITER:.*]] = %[[MEMREF_COPY]]) -> (memref<f32>) {
// CHECK: %[[TENSOR_ITER:.*]] = bufferization.to_tensor %[[MEMREF_ITER]] : memref<f32>
// CHECK: %[[TENSOR_MUNGED:.*]] = "test.munge_tensor"(%[[TENSOR_ITER]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[MEMREF_MUNGED:.*]] = bufferization.to_memref %[[TENSOR_MUNGED]] : tensor<f32> to memref<f32>
// CHECK: scf.yield %[[MEMREF_MUNGED]] : memref<f32>
// CHECK: }
// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[RESULT]] : memref<f32>
// CHECK: return %[[TENSOR]] : tensor<f32>
// CHECK: }
func.func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
  %ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor<f32> {
    %0 = "test.munge_tensor"(%iter) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %0 : tensor<f32>
  }
  return %ret : tensor<f32>
}

// -----

// CHECK-LABEL: func @bufferize_while(
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64, %[[ARG2:.*]]: tensor<f32>
// CHECK: %[[M:.*]] = bufferization.to_memref %[[ARG2]] : tensor<f32> to memref<f32>
// Note: scf.while iter_args always bufferize to a memory write. This could be
// optimized by analyzing the loop body.
// CHECK: %[[MEMREF_COPY:.*]] = memref.alloc()
// CHECK: memref.copy %[[M]], %[[MEMREF_COPY]]
// CHECK: %[[RES1:.*]]:3 = scf.while (%{{.*}} = %[[ARG0]], %[[ITER:.*]] = %[[MEMREF_COPY]]) : (i64, memref<f32>) -> (i64, i64, memref<f32>)
// CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}}, %[[ITER]] : i64, i64, memref<f32>
// CHECK: ^bb0(%{{.*}}: i64, %{{.*}}: i64, %{{.*}}: memref<f32>):
// CHECK: scf.yield %{{.*}}, %{{.*}} : i64, memref<f32>
// CHECK: %[[RES2:.*]] = bufferization.to_tensor %[[RES1]]#2 : memref<f32>
// CHECK: return %[[RES1]]#1, %[[RES2]] : i64, tensor<f32>
func.func @bufferize_while(%arg0: i64, %arg1: i64, %arg2: tensor<f32>) -> (i64, tensor<f32>) {
  %c2_i64 = arith.constant 2 : i64
  %0:3 = scf.while (%arg3 = %arg0, %arg4 = %arg2) : (i64, tensor<f32>) -> (i64, i64, tensor<f32>) {
    %1 = arith.cmpi slt, %arg3, %arg1 : i64
    scf.condition(%1) %arg3, %arg3, %arg4 : i64, i64, tensor<f32>
  } do {
  ^bb0(%arg5: i64, %arg6: i64, %arg7: tensor<f32>):
    %1 = arith.muli %arg6, %c2_i64 : i64
    scf.yield %1, %arg7 : i64, tensor<f32>
  }
  return %0#1, %0#2 : i64, tensor<f32>
}