1// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-drop-redundant-insert-slice-rank-expansion %s | FileCheck %s 2 3// CHECK-LABEL: func @test_drop_rank_expansion( 4// CHECK-SAME: %[[src:.*]]: tensor<128x480xf32>, 5// CHECK: %[[extract:.*]] = tensor.extract_slice %[[src]][0, 0] [123, 456] [1, 1] : tensor<128x480xf32> to tensor<123x456xf32> 6// CHECK: return %[[extract]] 7func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> tensor<123x456xf32> { 8 %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32> 9 %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32> 10 return %extracted_slice : tensor<123x456xf32> 11} 12 13// ----- 14 15func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> { 16 %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32> 17 %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32> 18 return %inserted_slice : tensor<8x1x8xf32> 19} 20// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice( 21// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32> 22// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] 23// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<8x1x8xf32> 24// CHECK: return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32> 25 26// ----- 27 28func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> { 29 %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32> 30 %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32> 31 return %inserted_slice : tensor<1x4x8xf32> 32} 33// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice( 34// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32> 35// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] 36// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<1x4x8xf32> 37// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32> 38 39// ----- 40 41func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> { 42 %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32> 43 %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32> 44 return %inserted_slice : tensor<1x1x8x8xf32> 45} 46// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice( 47// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x8xf32> 48// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]] 49// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] 50// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32> 51 52// ----- 53 54func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> { 55 %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32> 56 %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32> 57 return %inserted_slice : tensor<1x4x4xf32> 58} 59// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice( 60// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32> 61// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]] 62// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] 63// CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32> 64 65// ----- 66 67func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> { 68 %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32> 69 %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32> 70 return %inserted_slice : tensor<2x8x8xf32> 71} 72// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice( 73// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8x8xf32> 74// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]] 75// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] 76// CHECK: return %[[INSERTED_SLICE]] : tensor<2x8x8xf32> 77