// RUN: mlir-opt %s -test-vector-to-vector-lowering -split-input-file | FileCheck %s

// Most cases in this file check that ops whose vector types carry leading
// unit (size-1) dimensions are rewritten to operate on rank-reduced vectors,
// with the original result type restored by vector.broadcast.

// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// All three contraction operands carry the leading unit dim `l`: each one is
// rank-reduced with vector.extract [0], the contraction is performed on the
// rank-2 values, and the result is broadcast back to vector<1x16x16xf32>.
// CHECK-LABEL: cast_away_contraction_leading_one_dims
// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK-NEXT: return %[[R4]] : vector<1x16x16xf32>

#contraction_accesses0 = [
  affine_map<(l, i, j, k) -> (l, i, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait0 = {
  indexing_maps = #contraction_accesses0,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
  %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
  return %0: vector<1x16x16xf32>
}

// -----
// Map aliases are matched order-independently, as in every other split.
// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// Same contraction wrapped in vector.mask with a constant mask: the expected
// output rebuilds the mask as a rank-3 vector.constant_mask [15, 15, 8] and
// keeps the rank-reduced contraction under it.
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask
// CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK: return %[[RES]] : vector<1x16x16xf32>

#contraction_accesses0 = [
  affine_map<(l, i, j, k) -> (l, i, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait0 = {
  indexing_maps = #contraction_accesses0,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
  %0 = vector.mask %mask {
    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
  return %0 : vector<1x16x16xf32>
}

// -----
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// Same contraction under a mask passed in as a function argument: the mask is
// rank-reduced with vector.extract, like the other operands.
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask
// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>

#contraction_accesses0 = [
  affine_map<(l, i, j, k) -> (l, i, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait0 = {
  indexing_maps = #contraction_accesses0,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dim_under_mask(
  %arg0: vector<1x16x8xf32>,
  %arg1: vector<1x8x16xf32>,
  %arg2: vector<1x16x16xf32>,
  %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
  %0 = vector.mask %mask {
    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
  return %0: vector<1x16x16xf32>
}

// -----

// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>

// The lhs map is (i, l, k) and both `l` and `i` have size 1, so two unit dims
// are dropped: lhs and acc are extracted at [0, 0], the contraction becomes
// vector x matrix (kind = mul is preserved), and the result is broadcast back
// in two steps.
// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded
// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0, 0] : vector<16xf32> from vector<1x1x16xf32>
// CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"], kind = #vector.kind<mul>}
// CHECK-SAME: %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32>
// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32>
// CHECK-NEXT: %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32>
// CHECK-NEXT: return %[[R5]] : vector<1x1x16xf32>

#contraction_accesses1 = [
  affine_map<(l, i, j, k) -> (i, l, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait1 = {
  indexing_maps = #contraction_accesses1,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  kind = #vector.kind<mul>
}

func.func @cast_away_contraction_leading_one_dims_transposeneeded(%arg0: vector<1x1x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x1x16xf32>) -> vector<1x1x16xf32> {
  %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2 : vector<1x1x8xf32>, vector<1x8x16xf32> into vector<1x1x16xf32>
  return %0: vector<1x1x16xf32>
}

// -----
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// Here `l` is the middle dim of the lhs (k, l, j) and the last dim of the rhs
// (i, k, l), so both are transposed to make `l` leading before it is dropped.
// The transpose patterns spell out the ", " separator the op prints between
// its operand and permutation.
// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2
// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}, [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
// CHECK-NEXT: %[[R1:.+]] = vector.extract %[[R0]][0] : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK-NEXT: %[[R2:.+]] = vector.transpose %{{.*}}, [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<2x8xf32> from vector<1x2x8xf32>
// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0] : vector<2x16xf32> from vector<1x2x16xf32>
// CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
// CHECK-NEXT: return %[[R6]] : vector<1x2x16xf32>

#contraction_accesses2 = [
  affine_map<(l, i, j, k) -> (k, l, j)>,
  affine_map<(l, i, j, k) -> (i, k, l)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#contraction_trait2 = {
  indexing_maps = #contraction_accesses2,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims_transposeneeded2(%arg0: vector<8x1x16xf32>, %arg1: vector<2x8x1xf32>, %arg2: vector<1x2x16xf32>) -> vector<1x2x16xf32> {
  %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2 : vector<8x1x16xf32>, vector<2x8x1xf32> into vector<1x2x16xf32>
  return %0: vector<1x2x16xf32>
}

// -----
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// Rank-4 variant: the leading unit dim `m` is dropped first with
// vector.extract, then lhs/rhs are transposed so that the non-leading unit
// dim `l` becomes leading and can be extracted as well.
// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4
// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<8x1x16xf32> from vector<1x8x1x16xf32>
// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<2x8x1xf32> from vector<1x2x8x1xf32>
// CHECK-NEXT: %[[R2:.+]] = vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<8x16xf32> from vector<1x8x16xf32>
// CHECK-NEXT: %[[R4:.+]] = vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
// CHECK-NEXT: %[[R5:.+]] = vector.extract %[[R4]][0] : vector<2x8xf32> from vector<1x2x8xf32>
// CHECK-NEXT: %[[R6:.+]] = vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
// CHECK-NEXT: %[[R7:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
// CHECK-NEXT: %[[R8:.+]] = vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32>
// CHECK-NEXT: %[[R9:.+]] = vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
// CHECK-NEXT: return %[[R9]] : vector<1x1x2x16xf32>

#contraction_accesses2 = [
  affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
  affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
  affine_map<(m, l, i, j, k) -> (m, l, i, j)>
]
#contraction_trait2 = {
  indexing_maps = #contraction_accesses2,
  iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
  %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2 : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
  return %0: vector<1x1x2x16xf32>
}

// -----
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// Like the previous case, but the acc map is (l, m, i, j): lhs and rhs are
// transposed on their full rank-4 shapes so that both unit dims lead, then
// every operand is extracted at [0, 0].
// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose
// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32>
// CHECK-NEXT: %[[R1:.+]] = vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32>
// CHECK-NEXT: %[[R2:.+]] = vector.extract %[[R0]][0, 0] : vector<8x16xf32> from vector<1x1x8x16xf32>
// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R1]][0, 0] : vector<2x8xf32> from vector<1x1x2x8xf32>
// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
// CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
// CHECK-NEXT: %[[R7:.+]] = vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
// CHECK-NEXT: return %[[R7]] : vector<1x1x2x16xf32>

#contraction_accesses3 = [
  affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
  affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
  affine_map<(m, l, i, j, k) -> (l, m, i, j)>
]
#contraction_trait3 = {
  indexing_maps = #contraction_accesses3,
  iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
  %0 = vector.contract #contraction_trait3 %arg0, %arg1, %arg2 : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
  return %0: vector<1x1x2x16xf32>
}

// -----

// Every operand's unit dims are already leading here, so the expected output
// reaches the contraction without emitting any vector.transpose.
// CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims
// CHECK-NOT: vector.transpose
// CHECK: vector.contract
func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(
    %a: vector<1x1x8xi32>, %b: vector<1x8x8xi32>, %c: vector<1x8xi32>) -> vector<1x8xi32> {
  %contracted = vector.contract {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"],
      kind = #vector.kind<add>}
    %a, %b, %c : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
  return %contracted : vector<1x8xi32>
}

// -----

// The source loses its leading unit dim through vector.extract [0]; the slice
// is taken on the rank-2 value and broadcast back to rank 3.
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
// CHECK: %[[SRC_2D:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
// CHECK: %[[SLICE:.+]] = vector.extract_strided_slice %[[SRC_2D]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
// CHECK: %[[RESULT:.+]] = vector.broadcast %[[SLICE]] : vector<1x8xf16> to vector<1x1x8xf16>
// CHECK: return %[[RESULT]]
func.func @cast_away_extract_strided_slice_leading_one_dims(%input: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
  %slice = vector.extract_strided_slice %input
      {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]}
      : vector<1x8x8xf16> to vector<1x1x8xf16>
  return %slice : vector<1x1x8xf16>
}

// Scalable counterpart of the case above: the trailing scalable dim is
// carried through unchanged.
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
// CHECK: %[[SRC_2D:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
// CHECK: %[[SLICE:.+]] = vector.extract_strided_slice %[[SRC_2D]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
// CHECK: %[[RESULT:.+]] = vector.broadcast %[[SLICE]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
// CHECK: return %[[RESULT]]
func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%input: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
  %slice = vector.extract_strided_slice %input
      {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]}
      : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
  return %slice : vector<1x1x[8]xf16>
}

// Both the inserted value and the destination lose their leading unit dim;
// the insert happens at rank 2 and the result is broadcast back.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
// CHECK: %[[SRC_1D:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
// CHECK: %[[DST_2D:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
// CHECK: %[[INSERTED:.+]] = vector.insert_strided_slice %[[SRC_1D]], %[[DST_2D]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
// CHECK: %[[RESULT:.+]] = vector.broadcast %[[INSERTED]] : vector<8x8xf16> to vector<1x8x8xf16>
// CHECK: return %[[RESULT]]
func.func @cast_away_insert_strided_slice_leading_one_dims(%piece: vector<1x8xf16>, %dest: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
  %updated = vector.insert_strided_slice %piece, %dest
      {offsets = [0, 0, 0], strides = [1, 1]}
      : vector<1x8xf16> into vector<1x8x8xf16>
  return %updated : vector<1x8x8xf16>
}

// Scalable counterpart of the insert case above.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
// CHECK: %[[SRC_1D:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
// CHECK: %[[DST_2D:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
// CHECK: %[[INSERTED:.+]] = vector.insert_strided_slice %[[SRC_1D]], %[[DST_2D]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
// CHECK: %[[RESULT:.+]] = vector.broadcast %[[INSERTED]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
// CHECK: return %[[RESULT]]
func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%piece: vector<1x[8]xf16>, %dest: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
  %updated = vector.insert_strided_slice %piece, %dest
      {offsets = [0, 0, 0], strides = [1, 1]}
      : vector<1x[8]xf16> into vector<1x8x[8]xf16>
  return %updated : vector<1x8x[8]xf16>
}

// Every dim is unit-sized: the expected output reduces to extracting the
// inserted value at [0] and broadcasting it to the destination type.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
// CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1xf16> from vector<1x1xf16>
  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
  // CHECK: return %[[B]]
  return %0: vector<1x1x1xf16>
}

// Scalable counterpart of the one-element insert case above.
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable
// CHECK-SAME: %[[ARG0:.+]]: vector<1x[1]xf16>, %{{.+}}: vector<1x1x[1]xf16>
func.func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable(%arg0: vector<1x[1]xf16>, %arg1: vector<1x1x[1]xf16>) -> vector<1x1x[1]xf16> {
  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<[1]xf16> from vector<1x[1]xf16>
  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<[1]xf16> to vector<1x1x[1]xf16>
  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[1]xf16> into vector<1x1x[1]xf16>
  // CHECK: return %[[B]]
  return %0: vector<1x1x[1]xf16>
}

// The unit result dim is dropped from the read (in_bounds shrinks to a single
// entry) and the rank-1 read is broadcast back to vector<1x4xf16>.
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
  %f0 = arith.constant 0. : f16
  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<1x4xf16>
}

// Masked variant: the mask is rank-reduced with vector.extract alongside the
// read itself.
// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
  %f0 = arith.constant 0. : f16
  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<1x4xf16>
}

// One-element read: only the final broadcast from vector<1xf16> is pinned.
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0. : f16
  // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
  return %0: vector<1x1xf16>
}

// -----

// With a permuting map, the mask is rank-reduced with vector.shape_cast and
// the read's permutation map shrinks to (d0, d1, d2) -> (d1).
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
  %f0 = arith.constant 0. : f16
  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
  // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
                             permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<1x1x4xf16>
}

// -----

// A transfer_read nested in vector.mask: the expected output still performs
// the masked read on vector<1x4xf16> and returns that value directly.
// CHECK-LABEL: func @not_insert_cast_for_transfer_read_under_mask
// CHECK: %[[MASK:.+]] = vector.constant_mask
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
// CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
// CHECK: return %[[RET]] : vector<1x4xf16>
func.func @not_insert_cast_for_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0. : f16
  %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
  %ret = vector.mask %mask {
    vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
  } : vector<1x4xi1> -> vector<1x4xf16>
  return %ret: vector<1x4xf16>
}

// -----

// The value being written is rank-reduced with vector.extract and stored
// with a single in_bounds entry.
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>

  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
  return
}

// Masked variant: both the stored value and the mask are rank-reduced with
// vector.extract.
// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>

  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
  return
}

// One-element write: only the rank-reducing extract of the stored value is
// pinned.
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
  %c0 = arith.constant 0 : index
  // CHECK: vector.extract %{{.+}}[0] : vector<1xf16> from vector<1x1xf16>
  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
  return
}

// -----

// A transfer_write nested in vector.mask: the expected output still performs
// the masked write on vector<1x4xf16>.
// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
// CHECK: %[[MASK:.+]] = vector.constant_mask
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
// CHECK: vector.mask %[[CASTED_MASK]] {
// CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
// CHECK: return
func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) {
  %c0 = arith.constant 0 : index
  %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
  vector.mask %mask {
    vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16>
  } : vector<1x4xi1>
  return
}

// -----

// Write counterpart of the permuting-map read: the stored value is extracted
// at [0, 0], the mask goes through vector.shape_cast, and the permutation map
// shrinks to (d0, d1, d2) -> (d1).
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
  // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>

  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
                                                            permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
  return
}

// -----

// Elementwise ops (arith.addf, arith.cmpf, and arith.select with either a
// vector or a scalar condition) are applied to rank-reduced operands, and
// each result is broadcast back to its original type.
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
func.func @cast_away_elementwise_leading_one_dims(
  %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
  %arg3: vector<1x4xf32>, %arg4: i1) ->
  (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
  // CHECK: vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
  // CHECK: vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
  // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
  // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
  %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
  // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
  // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
  %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
  // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
  // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
  %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
  // The scalar i1 condition is used as-is. The select operands are bound by
  // capture instead of a hard-coded SSA number, so the match does not depend
  // on how many values the earlier rewrites produced.
  // CHECK: %[[SEL_TRUE:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK: %[[SEL_FALSE:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK: select %arg4, %[[SEL_TRUE]], %[[SEL_FALSE]] : vector<4xf32>
  // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
  %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}

// -----

// Scalar vector.insert: the destination's leading unit dims are dropped, the
// scalar goes in at the remaining position [0], and the result is broadcast
// back.
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
// CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
  %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
  return %0: vector<1x1x4xf32>
}

// -----

// Scalable counterpart of the scalar insert case above.
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_scalable(
// CHECK-SAME: %[[S:.*]]: f32,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<[4]xf32> from vector<1x1x[4]xf32>
// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
  %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
  return %0: vector<1x1x[4]xf32>
}

// -----

// A scalable unit dim ([1]) is kept: only the fixed leading unit dim is
// dropped from the destination.
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
// CHECK-SAME: %[[S:.*]]: f32,
// CHECK-SAME: %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<[1]x4xf32> from vector<1x[1]x4xf32>
// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
// CHECK: return %[[BCAST]] : vector<1x[1]x4xf32>
  %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
  return %0: vector<1x[1]x4xf32>
}

// -----

// The inserted rank-1 vector fills the whole destination, so the insert
// becomes a broadcast of the inserted value.
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
// CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
  %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
  return %0: vector<1x1x4xf32>
}

// -----

// Scalable counterpart of the rank-1 insert case above.
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank1_scalable(
// CHECK-SAME: %[[S:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
  %0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
  return %0: vector<1x1x[4]xf32>
}

// -----

// The inserted rank-2 vector also fills the destination: it loses its leading
// unit dim and is broadcast to the destination type.
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
  %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
  return %0: vector<1x1x4xf32>
}

// -----

// Scalable counterpart of the rank-2 insert case above.
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
  %0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
  return %0: vector<1x1x[4]xf32>
}

// -----

// Both operands lose their leading unit dim, so the position [0, 1] becomes
// [1, 0] on the rank-reduced destination before the result is broadcast back.
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<2x1x4xf32> from vector<1x2x1x4xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
// CHECK: return %[[BCAST]]
func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
  %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
  return %0: vector<1x2x1x4xf32>
}

// -----

// Scalable counterpart of the rank-2 one-dest insert case above.
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<2x1x[4]xf32> from vector<1x2x1x[4]xf32>
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
// CHECK: return %[[BCAST]] : vector<1x2x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
  %0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
  return %0: vector<1x2x1x[4]xf32>
}

// -----

// The destination has no leading unit dim, so only the inserted value is
// rank-reduced; the insert position gains a trailing 0 and nothing is
// broadcast.
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
// CHECK: return %[[INSERT]]
func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
  %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
  return %0: vector<8x1x4xf32>
}

// -----

// Scalable counterpart of the non-one-dest insert case above.
// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>,
// CHECK-SAME: %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
// CHECK: return %[[INSERT]] : vector<8x1x[4]xf32>
func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
  %0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
  return %0: vector<8x1x[4]xf32>
}

// -----

// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
632// CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>) 633// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1> 634// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<8x1x8xi1> from vector<1x1x8x1x8xi1> 635// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1> 636// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1> 637// CHECK: return %[[BCAST]] 638func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> { 639 %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1> 640 return %0: vector<1x1x8x1x8xi1> 641} 642 643// ----- 644 645// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable( 646// CHECK-SAME: %[[S:.*]]: vector<1x[8]xi1>, 647// CHECK-SAME: %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> { 648// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[8]xi1> from vector<1x[8]xi1> 649// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<8x1x[8]xi1> from vector<1x1x8x1x[8]xi1> 650// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1> 651// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1> 652// CHECK: return %[[BCAST]] : vector<1x1x8x1x[8]xi1> 653func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> { 654 %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1> 655 return %0: vector<1x1x8x1x[8]xi1> 656} 657 658// ----- 659 660// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> { 661// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1> 662// CHECK: %[[BCAST:.*]] = 
vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1> 663// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1> 664func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> { 665 %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1> 666 return %0: vector<1x1x8x2x1xi1> 667} 668 669// ----- 670 671// CHECK-LABEL: func.func @drop_unit_dims_scalar_cond_select( 672// CHECK: arith.select {{.*}} : vector<16xi1> 673func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> { 674 %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1> 675 return %sel : vector<1x16xi1> 676} 677