// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only" -split-input-file | FileCheck %s

// CHECK-LABEL: @elementwise_no_conflict
func.func @elementwise_no_conflict(%a: tensor<5xf32>,
                                   %b: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
      outs(%a : tensor<5xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: @elementwise_no_conflict_2
func.func @elementwise_no_conflict_2(%a: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %a : tensor<5xf32>, tensor<5xf32>)
      outs(%a : tensor<5xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: @elementwise_no_conflict_3
func.func @elementwise_no_conflict_3(%a: tensor<5xf32>) -> tensor<5xf32> {
  %c1f = arith.constant 1.0 : f32
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "true"], fun = #linalg.binary_fn<add>}
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %c1f : tensor<5xf32>, f32)
      outs(%a : tensor<5xf32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: @not_elementwise
func.func @not_elementwise(%a: tensor<5x6xf32>) -> tensor<5x6xf32> {
  %cst = arith.constant 5.0 : f32
  // CHECK: tensor.extract_slice
  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
  %b = tensor.extract_slice %a[0, 0] [1, 6] [1, 1]
      : tensor<5x6xf32> to tensor<6xf32>
  // CHECK: linalg.generic
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
  %0 = linalg.generic
    { iterator_types = ["parallel", "parallel"],
      indexing_maps = [ affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>] }
    ins(%b: tensor<6xf32>) outs(%a: tensor<5x6xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      %r = arith.addf %arg0, %arg1 : f32
      linalg.yield %r : f32
  } -> tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>

// CHECK-LABEL: @elementwise_no_conflict_4
func.func @elementwise_no_conflict_4(%arg0: tensor<8x32x32x32xf32>, %arg1: tensor<32x32x32xf32>) -> tensor<8x32x32x32xf32> {
  %cst = arith.constant dense<3.000000e-02> : tensor<32x32x32xf32>
  %cst_0 = arith.constant dense<6.000000e-01> : tensor<32xf32>
  %cst_1 = arith.constant 0.000000e+00 : f32
  %r = scf.forall (%arg2, %arg3) in (8, 32) shared_outs(%arg4 = %arg0) -> (tensor<8x32x32x32xf32>) {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "none"]}
    %extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32xf32>

    // CHECK: linalg.fill
    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
    %4 = linalg.fill ins(%cst_1 : f32) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>

    // CHECK: linalg.batch_reduce_matmul
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
    %5 = linalg.batch_reduce_matmul ins(%arg1, %cst : tensor<32x32x32xf32>, tensor<32x32x32xf32>) outs(%4 : tensor<32x32xf32>) -> tensor<32x32xf32>
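    // Writing %4 in place is safe: the matmul's ins (%arg1 and %cst) do not
    // alias the output buffer, and the later users of that buffer (%6 and %7
    // below) access it element-wise.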

    // CHECK: linalg.generic
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
    // %cst_0 has a non-identity indexing map, but %5 and %extracted_slice
    // still bufferize to element-wise access.
    %6 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %cst_0 : tensor<32x32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
    ^bb0(%in: f32, %in_4: f32, %out: f32):
      %8 = arith.addf %in, %in_4 : f32
      linalg.yield %8 : f32
    } -> tensor<32x32xf32>

    // CHECK: linalg.generic
    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
    // %6 and %extracted_slice are different SSA values, but they bufferize to
    // equivalent buffers.
    %7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%6 : tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = arith.maximumf %in, %cst_1 : f32
      linalg.yield %8 : f32
    } -> tensor<32x32xf32>
    scf.forall.in_parallel {
      // CHECK: tensor.parallel_insert_slice
      // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
      tensor.parallel_insert_slice %7 into %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x32x32x32xf32>
    }
  }
  return %r : tensor<8x32x32x32xf32>
}

// -----

// CHECK-LABEL: func @elementwise_access_regression(
// CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]}
// CHECK: linalg.map
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
// CHECK: linalg.map
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
func.func private @f(%arg: tensor<32x1xf32>) -> ()
func.func @elementwise_access_regression(%arg0: i32, %arg2: tensor<32x1xf32>, %arg3: tensor<32x1xf32>) {
  %cst_0 = arith.constant 0.000000e+00 : f32
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %0 = tensor.empty() : tensor<32x1xf32>

  // This op must bufferize out-of-place so that the filled tensor is not
  // overwritten by the ops inside of the loop.
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<32x1xf32>) -> tensor<32x1xf32>

  scf.for %arg1 = %c0_i32 to %arg0 step %c1_i32 : i32 {
    %2 = linalg.map { arith.subf } ins(%1, %arg2 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%0 : tensor<32x1xf32>)
    %3 = tensor.empty() : tensor<32x1xf32>
    %4 = linalg.map { arith.subf } ins(%2, %arg3 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%3 : tensor<32x1xf32>)
    func.call @f(%4) : (tensor<32x1xf32>) -> ()
  }
  return
}
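
// -----

// Added example (a sketch, not part of the original test): the element-wise
// exception only covers reads made by the writing op itself. When a *later*
// op reads %a, the in-place write would clobber a value with a pending read,
// so the outs operand is expected to bufferize out-of-place (same RaW-conflict
// reasoning as @elementwise_access_regression above). The function name and
// the expected attributes are this sketch's assumptions and have not been
// re-verified by running mlir-opt.

// CHECK-LABEL: @elementwise_raw_conflict
func.func @elementwise_raw_conflict(%a: tensor<5xf32>,
                                    %b: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "false"], fun = #linalg.binary_fn<add>}
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
      outs(%a : tensor<5xf32>) -> tensor<5xf32>
  // %a is read again here, after the write above, so %0 cannot reuse %a's
  // buffer. Writing %b in place is fine: its only other read is by this op
  // itself, which is element-wise.
  // CHECK: linalg.elemwise_binary
  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<mul>}
  %1 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
      ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
      outs(%b : tensor<5xf32>) -> tensor<5xf32>
  return %0, %1 : tensor<5xf32>, tensor<5xf32>
}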