// RUN: mlir-opt %s -test-vector-contraction-prepare-for-mmt-lowering | FileCheck %s

// Every matmul-like contraction below must come out in the canonical
// "mk, nk -> mn" form, i.e. the form used by @matmul_mk_nk_mn_4x4xi32.
// Affine map aliases are printed once at the top of the module, so capture
// them here, before the first function label, in global ($-prefixed)
// variables that the per-function checks reuse. Without this, a rewrite that
// moved operands around but kept non-canonical maps would go unnoticed.
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// A 1-D reduction is not a matmul and must be left untouched.
// CHECK-LABEL: func.func @not_matmul
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xf32>, [[ARG1:%.+]]: vector<4xf32>, [[ARG2:%.+]]: f32)
// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[ARG1]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @not_matmul(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
  %0 = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
                                         affine_map<(d0) -> (d0)>,
                                         affine_map<(d0) -> ()>],
                        iterator_types = ["reduction"],
                        kind = #vector.kind<add>} %arg0, %arg1, %arg2 :
                        vector<4xf32>, vector<4xf32> into f32
  return %0 : f32
}

// This contraction is already in the canonical form and must be left as is.
// CHECK-LABEL: func.func @matmul_mk_nk_mn_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[ARG0]], [[ARG1]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_mk_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[ARG0]], [[TRANS]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_mk_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// The transpose is applied to the narrow i8 value; the sign extension is
// re-emitted after it.
// CHECK-LABEL: func.func @matmul_mk_kn_mn_8x16xi8_extsi_i32
// CHECK-SAME: ([[ARG0:%.+]]: vector<8x4xi8>, [[ARG1:%.+]]: vector<4x16xi8>, [[ARG2:%.+]]: vector<8x16xi32>)
// CHECK-NEXT: [[LHS:%.+]] = arith.extsi [[ARG0]] : vector<8x4xi8> to vector<8x4xi32>
// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x16xi8> to vector<16x4xi8>
// CHECK-NEXT: [[RHS:%.+]] = arith.extsi [[TRANS]] : vector<16x4xi8> to vector<16x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[LHS]], [[RHS]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_mk_kn_mn_8x16xi8_extsi_i32(%arg0: vector<8x4xi8>, %arg1: vector<4x16xi8>, %arg2: vector<8x16xi32>) -> vector<8x16xi32> {
  %lhs = arith.extsi %arg0: vector<8x4xi8> to vector<8x4xi32>
  %rhs = arith.extsi %arg1: vector<4x16xi8> to vector<4x16xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<8x4xi32>, vector<4x16xi32> into vector<8x16xi32>
  return %res : vector<8x16xi32>
}

// Check that non-square shapes are also handled.
// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x16xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x16xi32>, [[ARG1:%.+]]: vector<16x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<16x4xi32> to vector<4x16xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[ARG0]], [[TRANS]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_mk_kn_mn_4x16xi32(%arg0: vector<4x16xi32>, %arg1: vector<16x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// Same as the extsi case, but the zero extension must be preserved.
// CHECK-LABEL: func.func @matmul_mk_kn_mn_8x16xi8_extui_i32
// CHECK-SAME: ([[ARG0:%.+]]: vector<8x4xi8>, [[ARG1:%.+]]: vector<4x16xi8>, [[ARG2:%.+]]: vector<8x16xi32>)
// CHECK-NEXT: [[LHS:%.+]] = arith.extui [[ARG0]] : vector<8x4xi8> to vector<8x4xi32>
// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x16xi8> to vector<16x4xi8>
// CHECK-NEXT: [[RHS:%.+]] = arith.extui [[TRANS]] : vector<16x4xi8> to vector<16x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[LHS]], [[RHS]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_mk_kn_mn_8x16xi8_extui_i32(%arg0: vector<8x4xi8>, %arg1: vector<4x16xi8>, %arg2: vector<8x16xi32>) -> vector<8x16xi32> {
  %lhs = arith.extui %arg0: vector<8x4xi8> to vector<8x4xi32>
  %rhs = arith.extui %arg1: vector<4x16xi8> to vector<4x16xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<8x4xi32>, vector<4x16xi32> into vector<8x16xi32>
  return %res : vector<8x16xi32>
}

// CHECK-LABEL: func.func @matmul_km_nk_mn_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[TRANS]], [[ARG1]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_km_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// Both operands need a transpose; their relative order is not pinned.
// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[LHS]], [[RHS]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_km_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// Each operand keeps its own extension kind (extsi vs. extui) after the
// transpose.
// CHECK-LABEL: func.func @matmul_km_kn_mn_8x16xi8_mixed_ext_i32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x8xi8>, [[ARG1:%.+]]: vector<4x16xi8>, [[ARG2:%.+]]: vector<8x16xi32>)
// CHECK-DAG: [[LHST:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x8xi8> to vector<8x4xi8>
// CHECK-DAG: [[LHS:%.+]] = arith.extsi [[LHST]] : vector<8x4xi8> to vector<8x4xi32>
// CHECK-DAG: [[RHST:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x16xi8> to vector<16x4xi8>
// CHECK-DAG: [[RHS:%.+]] = arith.extui [[RHST]] : vector<16x4xi8> to vector<16x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[LHS]], [[RHS]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_km_kn_mn_8x16xi8_mixed_ext_i32(%arg0: vector<4x8xi8>, %arg1: vector<4x16xi8>, %arg2: vector<8x16xi32>) -> vector<8x16xi32> {
  %lhs = arith.extsi %arg0 : vector<4x8xi8> to vector<4x8xi32>
  %rhs = arith.extui %arg1 : vector<4x16xi8> to vector<4x16xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x8xi32>, vector<4x16xi32> into vector<8x16xi32>
  return %res : vector<8x16xi32>
}

// The accumulator is indexed (n, m) in the remaining tests. The expected
// output swaps the two operands, and the accumulator itself is never
// transposed.
// CHECK-LABEL: func.func @matmul_mk_nk_nm_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[ARG1]], [[ARG0]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_mk_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// CHECK-LABEL: func.func @matmul_km_kn_nm_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-DAG: [[TRANS0:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-DAG: [[TRANS1:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[TRANS1]], [[TRANS0]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_km_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// CHECK-LABEL: func.func @matmul_mk_kn_nm_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[TRANS]], [[ARG0]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_mk_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}

// CHECK-LABEL: func.func @matmul_km_nk_nm_4x4xi32
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.contract
// CHECK-SAME:   indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
// CHECK-SAME:   [[ARG1]], [[TRANS]], [[ARG2]]
// CHECK-NEXT: return [[RES]]
func.func @matmul_km_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
  return %res : vector<4x4xi32>
}