// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
/// matmul operations:
///   C += A * B.
/// (A, B and C are 2-d matrices). ATM three different variants are tested:
///   * plain (no mask, fixed-width vectors),
///   * masked (mask + fixed-width vectors),
///   * scalable (scalable vectors, with and without a mask).
/// In order for the "vector.contract -> vector.outerproduct" patterns to work,
/// only the non-reduction dimension can be scalable (*). For matmul operations
/// that is set to be the N dimension (i.e. rows of the output matrix), which
/// matches how matrix multiplications are normally implemented for e.g.
/// Arm SVE. However, making the M dimension scalable (i.e. columns of the
/// output matrix) should work as well.
///
/// (*) The conversion tested in this file unrolls along the reduction
/// dimension, which is not supported for scalable vectors.

// Each `#matmat_accesses_N` / `#matmat_trait_N` pair below describes one
// operand layout for the contraction; the iteration space is always
// (m, n, k) with k being the reduction dimension.

// Trait 0: A is (m, k), B is (k, n), C is (m, n).
#matmat_accesses_0 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_0 = {
  indexing_maps = #matmat_accesses_0,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Trait 1: A is (m, k), B is (n, k), C is (m, n).
#matmat_accesses_1 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_1 = {
  indexing_maps = #matmat_accesses_1,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Trait 2: A is (k, m), B is (k, n), C is (m, n).
#matmat_accesses_2 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_2 = {
  indexing_maps = #matmat_accesses_2,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Trait 3: A is (k, m), B is (n, k), C is (m, n).
#matmat_accesses_3 = [
  affine_map<(m, n, k) -> (k, m)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]
#matmat_trait_3 = {
  indexing_maps = #matmat_accesses_3,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Trait 4: A is (m, k), B is (k, n), C is (n, m).
#matmat_accesses_4 = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_4 = {
  indexing_maps = #matmat_accesses_4,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// ============================================================================
//  Matmul 0 (plain + masked + mixed types, each with a scalable variant)
// ============================================================================
// CHECK-LABEL: func @matmul
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK-SAME:   : vector<2x4xf32> to vector<4x2xf32>
//
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK-SAME:   : vector<2xf32>, vector<3xf32>
//
//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32>
//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
// CHECK-SAME:   : vector<2xf32>, vector<3xf32>
//
//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32>
//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
// CHECK-SAME:   : vector<2xf32>, vector<3xf32>
//
//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
// CHECK-SAME:   : vector<2xf32>, vector<3xf32>
//
//      CHECK: return %[[c3]] : vector<2x3xf32>
func.func @matmul(%A: vector<2x4xf32>,
                  %B: vector<4x3xf32>,
                  %C: vector<2x3xf32>) -> vector<2x3xf32> {
  %0 = vector.contract #matmat_trait_0 %A, %B, %C
    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// CHECK-LABEL: func @matmul_scalable
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK-SAME:   : vector<2x4xf32> to vector<4x2xf32>
//
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK-SAME:   : vector<2xf32>, vector<[3]xf32>
//
//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
// CHECK-SAME:   : vector<2xf32>, vector<[3]xf32>
//
//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
// CHECK-SAME:   : vector<2xf32>, vector<[3]xf32>
//
//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
// CHECK-SAME:   : vector<2xf32>, vector<[3]xf32>
//
//      CHECK: return %[[c3]] : vector<2x[3]xf32>
func.func @matmul_scalable(%A: vector<2x4xf32>,
                           %B: vector<4x[3]xf32>,
                           %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
  %0 = vector.contract #matmat_trait_0 %A, %B, %C
    : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
  return %0 : vector<2x[3]xf32>
}

// The mask is transposed so that the reduction dimension comes first, then one
// 2-d slice of it is extracted per unrolled outer product.
// CHECK-LABEL:   func.func @masked_matmul(
// CHECK-SAME:      %{{.*}}: vector<3x5xf32>,
// CHECK-SAME:      %{{.*}}: vector<5x7xf32>,
// CHECK-SAME:      %{{.*}}: vector<3x7xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
// CHECK:           %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:           %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:           %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:           %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// CHECK:           %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>

func.func @masked_matmul(%A: vector<3x5xf32>,
                         %B: vector<5x7xf32>,
                         %C: vector<3x7xf32>,
                         %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
  : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
  return %0 : vector<3x7xf32>
}

// CHECK-LABEL:   func.func @masked_matmul_scalable(
// CHECK-SAME:      %{{.*}}: vector<3x5xf32>,
// CHECK-SAME:      %{{.*}}: vector<5x[7]xf32>,
// CHECK-SAME:      %{{.*}}: vector<3x[7]xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
// CHECK:           %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK:           %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK:           %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK:           %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// CHECK:           %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK:           %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>

func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
                                  %B: vector<5x[7]xf32>,
                                  %C: vector<3x[7]xf32>,
                                  %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
  : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
  return %0 : vector<3x[7]xf32>
}

// Mixed types: f16 inputs accumulated into f32; the extracted operands are
// extended with `arith.extf` before the outer product.
// CHECK-LABEL: func @matmul_mixed
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf16> from vector<1x3xf16>
//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x3xf32>
func.func @matmul_mixed(%A: vector<2x1xf16>,
                        %B: vector<1x3xf16>,
                        %C: vector<2x3xf32>) -> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_0 %A, %B, %C
    : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// CHECK-LABEL: func @matmul_mixed_scalable
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
                                 %B: vector<1x[3]xf16>,
                                 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
  %0 = vector.contract #matmat_trait_0 %A, %B, %C
    : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
  return %0 : vector<2x[3]xf32>
}

// ============================================================================
//  Matmul 1 (plain + scalable)
// ============================================================================
// CHECK-LABEL: func @matmul_1
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x3xf32>
func.func @matmul_1(%A: vector<2x1xf32>,
                    %B: vector<3x1xf32>,
                    %C: vector<2x3xf32>) -> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_1 %A, %B, %C
    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// CHECK-LABEL: func @matmul_1_scalable
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_1_scalable(%A: vector<2x1xf32>,
                             %B: vector<[3]x1xf32>,
                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
  %0 = vector.contract #matmat_trait_1 %A, %B, %C
    : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
  return %0 : vector<2x[3]xf32>
}

// ============================================================================
//  Matmul 2 (plain + scalable)
// ============================================================================
// CHECK-LABEL: func @matmul_2
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x3xf32>
func.func @matmul_2(%A: vector<1x2xf32>,
                    %B: vector<1x3xf32>,
                    %C: vector<2x3xf32>) -> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_2 %A, %B, %C
    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// CHECK-LABEL: func @matmul_2_scalable
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_2_scalable(%A: vector<1x2xf32>,
                             %B: vector<1x[3]xf32>,
                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
  %0 = vector.contract #matmat_trait_2 %A, %B, %C
    : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
  return %0 : vector<2x[3]xf32>
}

// ============================================================================
//  Matmul 3 (plain + scalable)
// ============================================================================
// CHECK-LABEL: func @matmul_3
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x3xf32>
func.func @matmul_3(%A: vector<1x2xf32>,
                    %B: vector<3x1xf32>,
                    %C: vector<2x3xf32>) -> vector<2x3xf32>
{
  %0 = vector.contract #matmat_trait_3 %A, %B, %C
    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// CHECK-LABEL: func @matmul_3_scalable
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
//      CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_3_scalable(%A: vector<1x2xf32>,
                             %B: vector<[3]x1xf32>,
                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
  %0 = vector.contract #matmat_trait_3 %A, %B, %C
    : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
  return %0 : vector<2x[3]xf32>
}

// ============================================================================
//  Matmul 4 (plain + scalable)
// ============================================================================
// Note: with trait 4 the result is indexed (n, m), so the outer product takes
// the B slice as its first operand and the A slice as its second.
// CHECK-LABEL: func @matmul_4
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
//      CHECK: return %[[c0]] : vector<3x2xf32>
func.func @matmul_4(%A: vector<2x1xf32>,
                    %B: vector<1x3xf32>,
                    %C: vector<3x2xf32>) -> vector<3x2xf32>
{
  %0 = vector.contract #matmat_trait_4 %A, %B, %C
    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
  return %0 : vector<3x2xf32>
}

// CHECK-LABEL: func @matmul_4_scalable
// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
//      CHECK: return %[[c0]] : vector<3x[2]xf32>
func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
                             %B: vector<1x3xf32>,
                             %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
{
  %0 = vector.contract #matmat_trait_4 %A, %B, %C
    : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
  return %0 : vector<3x[2]xf32>
}

// ============================================================================
//  TD sequence
// ============================================================================
// Matches every `func.func` in the module and applies the contraction-lowering
// patterns with the "outerproduct" strategy to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
    } : !transform.any_op
    transform.yield
  }
}