// RUN: mlir-opt -outline-shape-computation -test-print-shape-mapping -split-input-file %s 2>%t | FileCheck %s
// RUN: cat %t | FileCheck %s --check-prefix SHAPE

// Two dynamic shapes: one is a direct shape.shape_of(arg), the other is
// computed from extents of that same shape.
func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
  // SHAPE-DAG: Shape for {{.*}} = "test.abs"({{.*}}> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
  // SHAPE-DAG: Shape for {{.*}} = "test.concat"({{.*}}> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
  %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
  %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
  %4 = shape.value_of %3 : tensor<?x4x?xf32>
  %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
  %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
  %7 = arith.addi %6, %c2 : index
  %8 = shape.from_extents %7, %c4, %1 : index, index, index
  %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
  %10 = shape.value_of %9 : tensor<?x4x?xf32>
  return %10 : tensor<?x4x?xf32>
}

// CHECK-LABEL: func.func @two_dynamic_one_direct_shape
// CHECK-NEXT: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT: return %1 : tensor<?x4x?xf32>

// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
// CHECK-DAG: return %[[V4]] : !shape.shape

// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
// CHECK-DAG: %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
// CHECK-DAG: return %0 : tensor<3xindex>

// -----

// Two dynamic shapes that share the same outlined shape.func.
func.func @two_dynamic_share_same_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
  %3 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
  %4 = arith.addi %3, %c2 : index
  %5 = shape.from_extents %4, %c4, %1 : index, index, index
  %6 = shape.with_shape %2, %5 : tensor<?x4x?xf32>, !shape.shape
  %7 = shape.value_of %6 : tensor<?x4x?xf32>
  %8 = "test.abs"(%7) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
  %9 = shape.with_shape %8, %5 : tensor<?x4x?xf32>, !shape.shape
  %10 = shape.value_of %9 : tensor<?x4x?xf32>
  return %10 : tensor<?x4x?xf32>
}
// CHECK-LABEL: func.func @two_dynamic_share_same_shape
// CHECK-NEXT: %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT: return %1 : tensor<?x4x?xf32>

// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
// CHECK-DAG: return %[[V4]] : !shape.shape
// CHECK-NOT: shape_cal_1

// -----

// There's an internal dynamic shape source, and two other dynamic shapes share it.
func.func @internal_dynamic_shape_source_shared(%arg0: tensor<?x4xf32>) -> tensor<?xi32> {
  %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
  %1 = shape.shape_of %0 : tensor<?xi32> -> tensor<1xindex>
  %2 = shape.with_shape %0, %1 : tensor<?xi32>, tensor<1xindex>
  %3 = shape.value_of %2 : tensor<?xi32>
  %4 = "test.abs"(%3) : (tensor<?xi32>) -> tensor<?xi32>
  %5 = shape.with_shape %4, %1 : tensor<?xi32>, tensor<1xindex>
  %6 = shape.value_of %5 : tensor<?xi32>
  %7 = "test.negate"(%6) : (tensor<?xi32>) -> tensor<?xi32>
  %8 = shape.with_shape %7, %1 : tensor<?xi32>, tensor<1xindex>
  %9 = shape.value_of %8 : tensor<?xi32>
  return %9 : tensor<?xi32>
}
// CHECK-LABEL: func.func @internal_dynamic_shape_source_shared
// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK-NEXT: %2 = "test.negate"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK-NEXT: return %2 : tensor<?xi32>

// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?xi32>) -> tensor<1xindex> {
// CHECK-NEXT: %0 = shape_of %arg0 : tensor<?xi32> -> tensor<1xindex>
// CHECK-NEXT: return %0 : tensor<1xindex>
// CHECK-NOT: shape_cal_1

// -----

// The constructed shape.func contains only a return op.
func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
  %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
  %1 = shape.with_shape %0, %arg1 : tensor<?xi32>, tensor<1xindex>
  %2 = shape.value_of %1 : tensor<?xi32>
  return %2 : tensor<?xi32>
}
// CHECK-LABEL: func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
// CHECK-NEXT: return %0 : tensor<?xi32>

// CHECK: shape.func private @shape_cal_0(%arg0: tensor<1xindex>) -> tensor<1xindex> {
// CHECK-NEXT: return %arg0 : tensor<1xindex>

// -----

// The shape computation is interleaved with the general computation.
func.func @interleaved_shape_computation(%arg0: tensor<?x4x5xf32>, %arg1: tensor<?x4x5xf32>, %arg2: tensor<?x4x5xf32>) -> (tensor<?x4x5xf32>, index) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %0 = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
  %1 = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
  %2 = shape.shape_of %arg2 : tensor<?x4x5xf32> -> tensor<3xindex>
  %3 = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
  %4 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
  %5 = shape.get_extent %1, %c0 : tensor<3xindex>, index -> index
  %6 = shape.get_extent %2, %c0 : tensor<3xindex>, index -> index
  %7 = arith.addi %4, %5 : index
  %8 = arith.addi %7, %6 : index
  %9 = shape.from_extents %8, %c4, %c5 : index, index, index
  %10 = shape.with_shape %3, %9 : tensor<?x4x5xf32>, !shape.shape
  %11 = shape.value_of %10 : tensor<?x4x5xf32>
  return %11, %7 : tensor<?x4x5xf32>, index
}
// CHECK-LABEL: func.func @interleaved_shape_computation
// CHECK-DAG: %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
// CHECK-DAG: %[[V1:.*]] = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
// CHECK-DAG: %[[V3:.*]] = shape.get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
// CHECK-DAG: %[[V4:.*]] = shape.get_extent %[[V1]], %c0 : tensor<3xindex>, index -> index
// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : index
// CHECK-DAG: return %[[V2]], %[[V5]] : tensor<?x4x5xf32>, index

// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x5xf32>, %arg1: index, %arg2: index) -> !shape.shape {
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %arg1 : tensor<3xindex>, index -> index
// CHECK-DAG: %[[V2:.*]] = arith.addi %arg2, %[[V1]] : index
// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %c5 : index, index, index
// CHECK-DAG: return %[[V3]] : !shape.shape

// -----

// There are multiple reused shape computations.
func.func @multiple_reused(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
  %1 = shape.shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %3 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %4 = shape.get_extent %0, %c0 : tensor<2xindex>, index -> index
  %5 = shape.get_extent %1, %c0 : tensor<2xindex>, index -> index
  %6 = arith.addi %4, %5 : index
  %7 = shape.from_extents %6, %c4 : index, index
  %8 = shape.with_shape %2, %7 : tensor<?x4xf32>, !shape.shape
  %9 = shape.with_shape %3, %7 : tensor<?x4xf32>, !shape.shape
  %10 = shape.value_of %8 : tensor<?x4xf32>
  %11 = shape.value_of %9 : tensor<?x4xf32>
  %12 = "test.concat"(%arg0, %2) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %13 = "test.concat"(%arg0, %3) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %14 = arith.addi %6, %4 : index
  %15 = shape.from_extents %14, %c4 : index, index
  %16 = shape.with_shape %12, %15 : tensor<?x4xf32>, !shape.shape
  %17 = shape.with_shape %13, %15 : tensor<?x4xf32>, !shape.shape
  %18 = shape.value_of %16 : tensor<?x4xf32>
  %19 = shape.value_of %17 : tensor<?x4xf32>
  return %10, %11, %18, %19 : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
}
// CHECK-LABEL: func.func @multiple_reused
// CHECK-DAG: %[[V0:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG: %[[V1:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %[[V0]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG: %[[V3:.*]] = "test.concat"(%arg0, %[[V1]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG: return %[[V0]], %[[V1]], %[[V2]], %[[V3]] : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>

// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V4]], %[[V2]] : index
// CHECK-DAG: %[[V6:.*]] = from_extents %[[V5]], %c4 : index, index
// CHECK-DAG: return %[[V6]] : !shape.shape

// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
// CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
// CHECK-DAG: return %[[V5]] : !shape.shape

// -----

// Make sure the redundant with_shape is removed when the with_shape input is !shape.value_shape.
func.func @value_shape_with_shape(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> tensor<?xf32> {
  %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
  %1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape
  %2 = shape.value_of %1 : tensor<?xf32>
  return %2 : tensor<?xf32>
}
// After outlining, the function body is expected to hold only the value_of on
// %arg1 followed by the return; the shape_of/with_shape pair must not appear.
// CHECK-LABEL: func.func @value_shape_with_shape
// CHECK-NEXT: %0 = shape.value_of %arg1 : tensor<?xf32>
// CHECK-NEXT: return %0 : tensor<?xf32>