// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | FileCheck %s

// CHECK-DAG: #[[MAP1:map[0-9]*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
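
// With "unroll" enabled, ops on vector<4x2xf32> below are rewritten in terms
// of vector<2x2xf32> tiles: each operand tile is taken out with
// vector.extract_strided_slice, the elementwise op is applied per tile, and
// the results are reassembled with vector.insert_strided_slice.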
// CHECK-LABEL: func @add4x2
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1:.*]] = arith.addf %[[S1]], %[[S2]] : vector<2x2xf32>
// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[A1]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A2:.*]] = arith.addf %[[S3]], %[[S4]] : vector<2x2xf32>
// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[A2]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
// CHECK-NEXT: return %[[VEC1]] : vector<4x2xf32>

func.func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
  %1 = arith.addf %0, %0: vector<4x2xf32>
  return %1: vector<4x2xf32>
}

// Regression test. Previously, this example would trigger
// CastAwayElementwiseLeadingOneDim:
//  * `vector<2x[4]x1xf32>` would be reformulated as
//  * `vector<2x4x1xf32>`.
// With the updated shape, the conversion pattern would incorrectly assume that
// some leading dims had been dropped.
// CHECK-LABEL: func.func @no_change(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x[4]x1xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2x[4]x1xf32>)
// CHECK-NEXT: %[[VAL_2:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : vector<2x[4]x1xf32>
// CHECK-NEXT: return %[[VAL_2]]
func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) -> vector<2x[4]x1xf32> {
  %1 = arith.mulf %arg0, %arg1 : vector<2x[4]x1xf32>
  return %1 : vector<2x[4]x1xf32>
}

// CHECK-LABEL: func.func @cast_away_leading_one_dim(
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32>
// CHECK: vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
  %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
  return %1: vector<1x4x1xf32>
}

// CHECK-LABEL: func.func @cast_away_leading_one_dim_scalable(
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32>
// CHECK: vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
  %1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
  return %1: vector<1x[4]x1xf32>
}

// CHECK-LABEL: func @add4x4
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>

// CHECK-NEXT: %[[A1:.*]] = arith.addf %[[S1]], %[[S2]] : vector<2x2xf32>

// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>

// CHECK-NEXT: %[[A2:.*]] = arith.addf %[[S3]], %[[S4]] : vector<2x2xf32>

// CHECK-NEXT: %[[S5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A3:.*]] = arith.addf %[[S5]], %[[S6]] : vector<2x2xf32>

// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A4:.*]] = arith.addf %[[S7]], %[[S8]] : vector<2x2xf32>

// CHECK-NEXT: %[[S9:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A5:.*]] = arith.addf %[[S9]], %[[A1]] : vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.insert_strided_slice %[[A5]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>

// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A6:.*]] = arith.addf %[[S11]], %[[A2]] : vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.insert_strided_slice %[[A6]], %[[R1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>

// CHECK-NEXT: %[[S13:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A7:.*]] = arith.addf %[[S13]], %[[A3]] : vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.insert_strided_slice %[[A7]], %[[R2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>

// CHECK-NEXT: %[[S15:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A8:.*]] = arith.addf %[[S15]], %[[A4]] : vector<2x2xf32>
// CHECK-NEXT: %[[R4:.*]] = vector.insert_strided_slice %[[A8]], %[[R3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>

// CHECK-NEXT: return %[[R4]] : vector<4x4xf32>

func.func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
  %2 = arith.addf %0, %1: vector<4x4xf32>
  %3 = arith.addf %1, %2: vector<4x4xf32>
  return %3: vector<4x4xf32>
}

// CHECK-LABEL: func @contraction4x4_ikj_xfer_read

// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

// Check that the LHS vector.transfer_read is split for each user.

// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return
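
// #contraction_trait1 below encodes C[i, j] += A[i, k] * B[k, j] with loop
// order (i, k, j), i.e. a 4x2 * 2x4 -> 4x4 matmul that unrolling splits into
// 2x2 tiles.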
#contraction_accesses1 = [
  affine_map<(i, k, j) -> (i, k)>,
  affine_map<(i, k, j) -> (k, j)>,
  affine_map<(i, k, j) -> (i, j)>
]
#contraction_trait1 = {
  indexing_maps = #contraction_accesses1,
  iterator_types = ["parallel", "reduction", "parallel"]
}

func.func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
                                        %arg1 : memref<2x4xf32>,
                                        %arg2 : memref<4x4xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32

  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
    : memref<4x2xf32>, vector<4x2xf32>

  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
    : memref<2x4xf32>, vector<2x4xf32>

  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
    : memref<4x4xf32>, vector<4x4xf32>

  %3 = vector.contract #contraction_trait1 %0, %1, %2
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>

  vector.transfer_write %3, %arg2[%c0, %c0]
    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
    : vector<4x4xf32>, memref<4x4xf32>
  return
}
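
// Each 4x4 transfer and addf below unrolls into four 2x2 pieces, giving
// 8 reads (2 operands x 4 tiles), 4 addf ops, and 4 writes per iteration.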
// TODO: Update test with VTR split transform.
// CHECK-LABEL: func @vector_transfers
// CHECK-COUNT-8: vector.transfer_read
// CHECK-COUNT-4: arith.addf
// CHECK-COUNT-4: vector.transfer_write

func.func @vector_transfers(%arg0: index, %arg1: index) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 2.000000e+00 : f32
  affine.for %arg2 = 0 to %arg0 step 4 {
    affine.for %arg3 = 0 to %arg1 step 4 {
      %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
      %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
      %6 = arith.addf %4, %5 : vector<4x4xf32>
      vector.transfer_write %6, %2[%arg2, %arg3] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref<?x?xf32>
    }
  }
  return
}
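
// The cmpf and select on vector<4x4xf32> are unrolled into four
// vector<2x2xf32> pieces. Because the transfer split pattern only supports a
// single user, the transfer_reads feeding the select are re-materialized
// rather than reused from the cmpf.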
// CHECK-LABEL: func @elementwise_unroll
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[CMP0:.*]] = arith.cmpf ult, %[[VT0]], %[[VT4]] : vector<2x2xf32>
// CHECK: %[[CMP1:.*]] = arith.cmpf ult, %[[VT1]], %[[VT5]] : vector<2x2xf32>
// CHECK: %[[CMP2:.*]] = arith.cmpf ult, %[[VT2]], %[[VT6]] : vector<2x2xf32>
// CHECK: %[[CMP3:.*]] = arith.cmpf ult, %[[VT3]], %[[VT7]] : vector<2x2xf32>
// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[SEL0:.*]] = arith.select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: %[[SEL1:.*]] = arith.select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: %[[SEL2:.*]] = arith.select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: %[[SEL3:.*]] = arith.select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: vector.transfer_write %[[SEL0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK: vector.transfer_write %[[SEL1]], %[[ARG0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK: vector.transfer_write %[[SEL2]], %[[ARG0]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK: vector.transfer_write %[[SEL3]], %[[ARG0]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
func.func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %cond = arith.cmpf ult, %0, %1 : vector<4x4xf32>
  // The vector transfer split pattern only supports a single user right now.
  %2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %4 = arith.select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
  vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
  return
}

// Check that vector.transfer_read/write ops are split based on the contract
// unrolling.
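// Unlike the memref version above, the tensor version must thread the updated
// tensor through the unrolled writes, so each transfer_write consumes the
// previous write's result (VTW0 through VTW3).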
// CHECK-LABEL: func @contraction4x4_ikj_xfer_read_tensor
// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>

func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
                                               %arg1 : tensor<2x4xf32>,
                                               %arg2 : tensor<4x4xf32>) ->
  tensor<4x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 :
    tensor<4x2xf32>, vector<4x2xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 :
    tensor<2x4xf32>, vector<2x4xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0 :
    tensor<4x4xf32>, vector<4x4xf32>
  %3 = vector.contract #contraction_trait1 %0, %1, %2
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
  %r = vector.transfer_write %3, %arg2[%c0, %c0]
    : vector<4x4xf32>, tensor<4x4xf32>
  return %r : tensor<4x4xf32>
}
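
// vector<4xf32> bitcast to vector<8xf16> packs two f16 halves into each f32,
// so f16 index 3 lives in f32 element 1 (3 / 2) at half 1 (3 mod 2), and f16
// index 4 lives in f32 element 2 at half 0. The pattern extracts the single
// f32 element, bitcasts it to vector<2xf16>, and extracts the matching half.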
// CHECK-LABEL: func @bubble_down_bitcast_in_extract
// CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
  %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
  // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : f32 from vector<4xf32>
  // CHECK: %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32>
  // CHECK: %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16>
  // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : f16 from vector<2xf16>
  %1 = vector.extract %0[3] : f16 from vector<8xf16>
  // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : f32 from vector<4xf32>
  // CHECK: %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32>
  // CHECK: %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16>
  // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : f16 from vector<2xf16>
  %2 = vector.extract %0[4] : f16 from vector<8xf16>
  // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
  return %1, %2: f16, f16
}

// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract
// CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func.func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> {
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16>
  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<4xf16>
}

// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim
// CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32>
func.func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> {
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32>
  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16>
  %cast = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<2x4xf16>
}

// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset
func.func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> {
  // CHECK: vector.bitcast
  // CHECK-NEXT: vector.extract_strided_slice
  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
  return %0: vector<4xf16>
}

// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size
func.func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> {
  // CHECK: vector.bitcast
  // CHECK-NEXT: vector.extract_strided_slice
  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16>
  return %0: vector<3xf16>
}
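
// For vector.insert, the bitcast is bubbled up by casting both the inserted
// value and the destination vector; the insertion position stays the same
// because only the innermost dimension is recast (e.g. 32 x i4 becomes
// 16 x i8). Inserts of scalars are left as-is.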
// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i4_i8(
// CHECK-SAME: %[[VAL:.*]]: vector<32xi4>,
// CHECK-SAME: %[[DST:.*]]: vector<8x32xi4>) -> vector<8x16xi8> {
func.func @bubble_up_bitcast_in_insert_i4_i8(%val: vector<32xi4>, %src: vector<8x32xi4>) -> vector<8x16xi8> {
// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<32xi4> to vector<16xi8>
// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x32xi4> to vector<8x16xi8>
// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xi8> into vector<8x16xi8>
  %0 = vector.insert %val, %src[4] : vector<32xi4> into vector<8x32xi4>
  %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
  return %1 : vector<8x16xi8>
}

// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i8_i4(
// CHECK-SAME: %[[VAL:.*]]: vector<16xi8>,
// CHECK-SAME: %[[DST:.*]]: vector<8x16xi8>) -> vector<8x32xi4> {
func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %src: vector<8x16xi8>) -> vector<8x32xi4> {
// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi8> to vector<32xi4>
// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi8> to vector<8x32xi4>
// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<32xi4> into vector<8x32xi4>
  %0 = vector.insert %val, %src[4] : vector<16xi8> into vector<8x16xi8>
  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
  return %1 : vector<8x32xi4>
}

// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i32_f32(
// CHECK-SAME: %[[VAL:.*]]: vector<16xi32>,
// CHECK-SAME: %[[DST:.*]]: vector<8x16xi32>) -> vector<8x16xf32> {
func.func @bubble_up_bitcast_in_insert_i32_f32(%val: vector<16xi32>, %src: vector<8x16xi32>) -> vector<8x16xf32> {
// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi32> to vector<16xf32>
// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi32> to vector<8x16xf32>
// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xf32> into vector<8x16xf32>
  %0 = vector.insert %val, %src[4] : vector<16xi32> into vector<8x16xi32>
  %1 = vector.bitcast %0 : vector<8x16xi32> to vector<8x16xf32>
  return %1 : vector<8x16xf32>
}

// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_scalar(
func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %src: vector<8x16xi8>) -> vector<8x32xi4> {
// CHECK: vector.insert
// CHECK-NEXT: vector.bitcast
  %0 = vector.insert %val, %src[4, 8] : i8 into vector<8x16xi8>
  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
  return %1 : vector<8x32xi4>
}
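
// For insert_strided_slice, both f16 sources and the f16 destination are
// bitcast to f32 vectors, and the static offsets are divided by the 2:1
// element ratio (offset 4 in f16 becomes offset 2 in f32).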
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
// CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
func.func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
  // CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32>
  // CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32>
  // CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32>
  // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
  // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
  %0 = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
  %1 = vector.insert_strided_slice %src2, %0 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
  %cast = vector.bitcast %1: vector<8xf16> to vector<4xf32>
  // CHECK: return %[[INSERT2]]
  return %cast: vector<4xf32>
}

// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset
func.func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16>
  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
  return %cast: vector<4xf32>
}

// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank
func.func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16>
  %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
  return %cast: vector<16x4x4xf32>
}

// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape
func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
  %cast = vector.bitcast %0: vector<2xf16> to vector<1xf32>
  return %cast: vector<1xf32>
}

// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape
func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16>
  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
  return %cast: vector<4xf32>
}

// Make sure we do not crash on 0-D vectors.
// CHECK-LABEL: func.func @vec_0D
// CHECK-NEXT: vector.bitcast
func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
  %0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
  return %0 : vector<i32>
}

// Make sure we do not crash on a dynamically indexed `vector.extract`:
func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
  %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
  %1 = vector.extract %0[%index] : i16 from vector<8xi16>
  return %1 : i16
}

// CHECK-LABEL: func.func @vector_extract_dynamic_index
// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
// CHECK: return %[[EXTRACT]]