// RUN: mlir-opt -test-tensor-transform-patterns=test-tracking-listener \
// RUN:     -split-input-file -verify-diagnostics %s

// Tests for the tracking listener exercised by the test pass in the RUN line.
// Conventions used by every case below:
//   * The op carrying the `replaced` unit attribute is the op being replaced.
//   * An op carrying `replacement_<i>` supplies the replacement for result <i>
//     of the replaced op (the attribute's integer value presumably selects one
//     of that op's own results -- TODO confirm against the test pass).
//   * On success, a "replacement found" remark is expected on the op the
//     listener settles on (which may sit behind cast-like ops). On failure,
//     an error is expected on the replaced op.

// A replacement op of the same type is accepted.
func.func @replace_op_with_op_of_same_type() {
  %old = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %new = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>)
  return
}

// -----

// A replacement op of a different type (tensor.empty -> test.foo) is rejected.
func.func @replace_op_with_op_of_different_type() {
  // expected-error @below {{listener could not find replacement op}}
  %old = tensor.empty() {replaced} : tensor<5xf32>
  %new = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>)
  return
}

// -----

// Both results are replaced by results of one single op: accepted.
func.func @multi_result_replacement() {
  %old:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>)
  // expected-remark @below {{replacement found}}
  %new:2 = "test.foo"() {replacement_0 = 0, replacement_1 = 1}
      : () -> (tensor<5xf32>, tensor<6xf32>)
  return
}

// -----

// The two results are replaced by values coming from two different ops:
// rejected.
func.func @multi_result_replacement_with_multiple_ops() {
  // expected-error @below {{listener could not find replacement op}}
  %old:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>)
  %first:2 = "test.foo"() {replacement_0 = 0}
      : () -> (tensor<5xf32>, tensor<6xf32>)
  %second:2 = "test.foo"() {replacement_1 = 1}
      : () -> (tensor<5xf32>, tensor<6xf32>)
  return
}

// -----

// The replacement value is a tensor.cast; the listener looks through it and
// reports the op producing the cast's source.
func.func @replacement_wrapped_in_cast() {
  %old = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %dyn = "test.foo"() : () -> (tensor<?xf32>)
  %static = tensor.cast %dyn {replacement_0 = 0}
      : tensor<?xf32> to tensor<5xf32>
  return
}

// -----

// Same as above, but the source is hidden behind a chain of three casts.
func.func @replacement_wrapped_in_chain_of_casts() {
  %old = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %src = "test.foo"() : () -> (tensor<?xf32>)
  %cast0 = tensor.cast %src : tensor<?xf32> to tensor<5xf32>
  %cast1 = tensor.cast %cast0 : tensor<5xf32> to tensor<?xf32>
  %cast2 = tensor.cast %cast1 {replacement_0 = 0}
      : tensor<?xf32> to tensor<5xf32>
  return
}

// -----

// The inserted source covers the whole destination (only a unit dim is added),
// so the insert_slice is cast-like and is looked through.
func.func @cast_like_insert_slice(%t: tensor<1x5xf32>) {
  %old = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %src = "test.foo"() : () -> (tensor<5xf32>)
  %full = tensor.insert_slice %src into %t[0, 0][1, 5][1, 1] {replacement_0 = 0}
      : tensor<5xf32> into tensor<1x5xf32>
  return
}

// -----

// The inserted source covers only part of the destination: rejected.
func.func @non_cast_like_insert_slice(%t: tensor<7xf32>) {
  // expected-error @below {{listener could not find replacement op}}
  %old = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  %src = "test.foo"() : () -> (tensor<5xf32>)
  // Elements of %t survive into %partial, so this insert_slice is not a
  // cast-like op.
  %partial = tensor.insert_slice %src into %t[0][5][1] {replacement_0 = 0}
      : tensor<5xf32> into tensor<7xf32>
  return
}

// -----

// Dynamic-shape variant: a rank-reducing extract_slice feeds a tensor.insert.
// The result is then written back by a rank-expanding insert_slice.
func.func @cast_like_insert_slice_dynamic(
    %t: tensor<1x?x1xf32>, %f: f32, %pos: index) {
  %c0 = arith.constant 0 : index
  %old = tensor.insert %f into %t[%c0, %pos, %c0] {replaced} : tensor<1x?x1xf32>

  // Rank reduction: slice %t down to a 1-D tensor.
  %c1 = arith.constant 1 : index
  %len = tensor.dim %t, %c1 : tensor<1x?x1xf32>
  %flat = tensor.extract_slice %t[0, 0, 0][1, %len, 1][1, 1, 1]
      : tensor<1x?x1xf32> to tensor<?xf32>
  // expected-remark @below {{replacement found}}
  %updated = tensor.insert %f into %flat[%c0] : tensor<?xf32>
  // Rank expansion back to 1x?x1.
  // The slice size deliberately comes from a fresh tensor.dim of %flat
  // rather than reusing %len. The two SSA values differ but hold the same
  // runtime value.
  %len2 = tensor.dim %flat, %c0 : tensor<?xf32>
  %restored = tensor.insert_slice %updated into %t[0, 0, 0][1, %len2, 1][1, 1, 1]
      {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
  return
}

// -----

// The slice takes the full extent and drops only unit dims, so the
// extract_slice is cast-like.
func.func @cast_like_extract_slice() {
  %old = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %src = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
  %squeezed = tensor.extract_slice %src[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32>
  return
}

// -----

// Same as above, with the dynamic size obtained via tensor.dim.
func.func @cast_like_extract_slice_dynamic() {
  %old = "test.foo"() {replaced} : () -> (tensor<?xf32>)
  // expected-remark @below {{replacement found}}
  %src = "test.foo"() : () -> (tensor<1x?x1x1xf32>)
  %c1 = arith.constant 1 : index
  %len = tensor.dim %src, %c1 : tensor<1x?x1x1xf32>
  %squeezed = tensor.extract_slice %src[0, 0, 0, 0][1, %len, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor<?xf32>
  return
}

// -----

// Only 3 of the 5 elements along dim 1 are extracted: not cast-like, rejected.
func.func @non_cast_like_extract_slice() {
  // expected-error @below {{listener could not find replacement op}}
  %old = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  %src = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
  %part = tensor.extract_slice %src[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
  return
}

// -----

// Dim 1 has size 5 in the source, but it is sliced to 1 and then dropped.
// Dropping a dim that is non-unit in the source is not cast-like: rejected.
func.func @non_cast_like_extract_slice_drop_non_unit_dim() {
  // expected-error @below {{listener could not find replacement op}}
  %old = "test.foo"() {replaced} : () -> (tensor<f32>)
  %src = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
  %scalar = tensor.extract_slice %src[0, 0, 0, 0][1, 1, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<f32>
  return
}