// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
/// Matvec operations:
///   b += A * x.
/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants
/// are tested:
///   * plain (no mask, fixed-width vectors),
///   * masked (mask + fixed-width vectors),
///   * scalable (mask + scalable vectors).
///
/// Test names encode the indexing maps of the contraction operands as
/// `<lhs>_<rhs>_<acc>`, e.g. `matvec_mk_k_m` uses the maps (m, k), (k) and (m).
/// The `t` prefix (`tmatvec_*`) marks the variants whose iteration space is
/// (k, m) rather than (m, k).
///
/// TODO: These tests were extracted from 2 different files. If you find the
/// formatting inconsistent, please update accordingly.

#matvec_accesses_1 = [
  affine_map<(m, k) -> (m, k)>,
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait_1 = {
  indexing_maps = #matvec_accesses_1,
  iterator_types = ["parallel", "reduction"]
}

#matvecmax_trait = {
  indexing_maps = #matvec_accesses_1,
  iterator_types = ["parallel", "reduction"],
  kind = #vector.kind<maxnumf>
}

#matvec_accesses_2 = [
  affine_map<(m, k) -> (k, m)>,
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait_2 = {
  indexing_maps = #matvec_accesses_2,
  iterator_types = ["parallel", "reduction"]
}

#matvec_accesses_3 = [
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m, k)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait_3 = {
  indexing_maps = #matvec_accesses_3,
  iterator_types = ["parallel", "reduction"]
}

#matvec_accesses_4 = [
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (k, m)>,
  affine_map<(m, k) -> (m)>
]
#matvec_trait_4 = {
  indexing_maps = #matvec_accesses_4,
  iterator_types = ["parallel", "reduction"]
}

#matvec_accesses_5 = [
  affine_map<(k, m) -> (m, k)>,
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_5 = {
  indexing_maps = #matvec_accesses_5,
  iterator_types = ["reduction", "parallel"]
}

#matvec_accesses_6 = [
  affine_map<(k, m) -> (k, m)>,
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_6 = {
  indexing_maps = #matvec_accesses_6,
  iterator_types = ["reduction", "parallel"]
}

#matvec_accesses_7 = [
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (m, k)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_7 = {
  indexing_maps = #matvec_accesses_7,
  iterator_types = ["reduction", "parallel"]
}

#matvec_accesses_8 = [
  affine_map<(k, m) -> (k)>,
  affine_map<(k, m) -> (k, m)>,
  affine_map<(k, m) -> (m)>
]
#matvec_trait_8 = {
  indexing_maps = #matvec_accesses_8,
  iterator_types = ["reduction", "parallel"]
}

// ============================================================================
//  Matvec 1 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_mk_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
// CHECK-SAME: %{{.*}}: vector<3xf32>,
// CHECK-SAME: %{{.*}}: vector<2xf32>,
// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
                                %x: vector<3xf32>,
                                %b: vector<2xf32>,
                                %m: vector<2x3xi1>) -> vector<2xf32> {
  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
// CHECK-SAME: %{{.*}}: vector<3xf32>,
// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
                                                      %x: vector<3xf32>,
                                                      %b: vector<[2]xf32>,
                                                      %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
  return %0 : vector<[2]xf32>
}

// ============================================================================
//  Matvec 1 - max (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_mk_k_m_max
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32
func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
                             %x: vector<2xf32>,
                             %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func.func @masked_matvec_mk_k_m_max(
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
// CHECK-SAME: %{{.*}}: vector<3xf32>,
// CHECK-SAME: %{{.*}}: vector<2xf32>,
// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
func.func @masked_matvec_mk_k_m_max(%A: vector<2x3xf32>,
                                    %x: vector<3xf32>,
                                    %b: vector<2xf32>,
                                    %m: vector<2x3xi1>) -> vector<2xf32> {
  %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
// CHECK-SAME: %{{.*}}: vector<3xf32>,
// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
                                                          %x: vector<3xf32>,
                                                          %b: vector<[2]xf32>,
                                                          %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
  %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
  return %0 : vector<[2]xf32>
}

// ============================================================================
//  Matvec 2 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_km_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: @masked_matvec_km_k_m
// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK: vector.transpose %[[MASK]]
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_2 %A, %x, %b
      : vector<2x4xf32>, vector<2xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // CHECK: vector.transpose %[[MASK]]
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_2 %A, %x, %b
      : vector<2x[4]xf32>, vector<2xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 3 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_k_mk_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: @masked_matvec_k_mk_m
// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK: vector.transpose %[[A]]
  // CHECK: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_3 %x, %A, %b
      : vector<2xf32>, vector<4x2xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // CHECK: vector.transpose %[[A]]
  // CHECK: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_3 %x, %A, %b
      : vector<2xf32>, vector<[4]x2xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 4 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_k_km_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_k_km_m(%A: vector<2x2xf32>,
                         %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // CHECK: vector.transpose %[[MASK]]
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_4 %x, %A, %b
      : vector<2xf32>, vector<2x[4]xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// CHECK-LABEL: @masked_matvec_k_km_m
// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // CHECK: vector.transpose %[[MASK]]
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_4 %x, %A, %b
      : vector<2xf32>, vector<2x4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// ============================================================================
//  Matvec 5 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func.func @tmatvec_mk_k_m(
// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME: %[[X:.*]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @tmatvec_mk_k_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_5 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: @masked_tmatvec_mk_k_m
// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_mk_k_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_5 %A, %x, %b
      : vector<4x2xf32>, vector<2xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // CHECK: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_5 %A, %x, %b
      : vector<[4]x2xf32>, vector<2xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 6 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func.func @tmatvec_km_k_m(
// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME: %[[X:.*]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[VAL_3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VAL_4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[VAL_5:.*]] = vector.outerproduct %[[VAL_3]], %[[VAL_4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[VAL_6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VAL_7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[VAL_8:.*]] = vector.outerproduct %[[VAL_6]], %[[VAL_7]], %[[VAL_5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @tmatvec_km_k_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_6 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: @masked_tmatvec_km_k_m
// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_km_k_m(%A: vector<2x4xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_6 %A, %x, %b
      : vector<2x4xf32>, vector<2xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_6 %A, %x, %b
      : vector<2x[4]xf32>, vector<2xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 7 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func.func @tmatvec_k_mk_m(
// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME: %[[X:.*]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @tmatvec_k_mk_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_7 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: @masked_tmatvec_k_mk_m
// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_k_mk_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_7 %x, %A, %b
      : vector<2xf32>, vector<4x2xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // CHECK: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_7 %x, %A, %b
      : vector<2xf32>, vector<[4]x2xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 8 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @tmatvec_k_km_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @tmatvec_k_km_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL: @masked_tmatvec_k_km_m
// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_k_km_m(%A: vector<2x4xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_8 %x, %A, %b
      : vector<2xf32>, vector<2x4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // CHECK-NOT: vector.transpose %[[A]]
  // CHECK-NOT: vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_8 %x, %A, %b
      : vector<2xf32>, vector<2x[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Negative test (scalable reduction dim)
// ============================================================================
// Unrolling a scalable reduction dim is not supported - the pattern bails out
// and the `vector.contract` is left in place.
// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
// CHECK: vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
                                                           %arg1: vector<[3]xf32>,
                                                           %arg2: vector<[2]xf32>,
                                                           %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
  return %0 : vector<[2]xf32>
}

// ============================================================================
//  TD sequence
// ============================================================================
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
    } : !transform.op<"func.func">
    transform.yield
  }
}