// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// Tests for lowering `vector.contract` with the "dot" lowering strategy
// (applied by the transform sequence at the bottom of this file to every
// func.func). Each contraction is expected to become elementwise
// arith.mulf / arith.muli products followed by `vector.reduction <add>`.
// Parallel dimensions are unrolled with vector.extract / vector.insert.
// Note: this file has no split markers, so --split-input-file processes it
// as a single chunk.
//
// SSA value names are always matched through captured FileCheck variables
// rather than spelled literally, so the patterns do not depend on the
// printer's numbering.

#dotp_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#dotp_trait = {
  indexing_maps = #dotp_accesses,
  iterator_types = ["reduction"]
}

// 1-D dot product (f32): a single multiply, then a reduction seeded with the
// scalar accumulator.

// CHECK-LABEL: func @extract_contract1
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK:      %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32
// CHECK:      return %[[R]] : f32

func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
    : vector<4xf32>, vector<4xf32> into f32
  return %0 : f32
}

// Masked 1-D dot product: the multiply is emitted unmasked, and the mask
// wraps the generated vector.reduction.

// CHECK-LABEL: func @masked_extract_contract1
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32
// CHECK-SAME: %[[M:.*]]: vector<4xi1>
// CHECK:      %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
// CHECK:      %[[R:.*]] = vector.mask %[[M]] { vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
// CHECK:      return %[[R]] : f32

func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 {
  %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32
  return %0 : f32
}

// 1-D dot product (i32): same shape as the f32 case, using arith.muli.

// CHECK-LABEL: func @extract_contract1_int
// CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
// CHECK-SAME: %[[C:.*2]]: i32
// CHECK:      %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xi32> into i32
// CHECK:      return %[[R]] : i32

func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
    : vector<4xi32>, vector<4xi32> into i32
  return %0 : i32
}

#matvec_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Matrix x vector (f32): each row of A is extracted, multiplied with B and
// reduced. The scalar is inserted into a zero vector, and the accumulator is
// added once at the end.

// CHECK-LABEL: func @extract_contract2
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32>
// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
// CHECK:      return %[[T10]] : vector<2xf32>

func.func @extract_contract2(%arg0: vector<2x3xf32>,
                             %arg1: vector<3xf32>,
                             %arg2: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
    : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// Matrix x vector (i32): integer variant of the case above.

// CHECK-LABEL: func @extract_contract2_int
// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xi32>
// CHECK:      %[[R:.*]] = arith.constant dense<0> : vector<2xi32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK:      %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32>
// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xi32> into i32
// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK:      %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xi32> into i32
// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32>
// CHECK:      return %[[T10]] : vector<2xi32>
func.func @extract_contract2_int(%arg0: vector<2x3xi32>,
                                 %arg1: vector<3xi32>,
                                 %arg2: vector<2xi32>) -> vector<2xi32> {
  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
    : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
  return %0 : vector<2xi32>
}

#vecmat_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i)>
]
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Vector x matrix: rows are taken from the matrix operand B. In the emitted
// multiply the B row comes first and the vector A second.

// CHECK-LABEL: func @extract_contract3
// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32>
// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32>
// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
// CHECK:      return %[[T10]] : vector<2xf32>

func.func @extract_contract3(%arg0: vector<3xf32>,
                             %arg1: vector<2x3xf32>,
                             %arg2: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
    : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

#matmat_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
  indexing_maps = #matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Matrix x matrix: B is transposed first so that its reduction dimension is
// innermost. Each output element (i, j) is then the reduction of row i of A
// times row j of the transposed B. Row 0 of the transposed B is extracted
// again for i = 1 rather than being reused.

// CHECK-LABEL: func @extract_contract4
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK:      %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:      %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:      %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
// CHECK:      %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
//
// CHECK:      %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:      %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
// CHECK:      %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
// CHECK:      %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
//
// CHECK:      %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:      %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:      %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
// CHECK:      %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
// CHECK:      %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
//
// CHECK:      %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:      %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
// CHECK:      %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
// CHECK:      %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
//
// CHECK:      %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
// CHECK:      return %[[T52]] : vector<2x2xf32>

func.func @extract_contract4(%arg0: vector<2x2xf32>,
                             %arg1: vector<2x2xf32>,
                             %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  return %0 : vector<2x2xf32>
}


#contraction2d_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> ()>
]
#contraction2d_trait = {
  indexing_maps = #contraction2d_accesses,
  iterator_types = ["reduction", "reduction"]
}

// Full 2-D contraction to a scalar: matching rows of A and B are multiplied
// and reduced. The running scalar is chained as the accumulator of the next
// row's reduction.

// CHECK-LABEL: func @full_contract1
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] : vector<3xf32> into f32
// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T6:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]], %[[T3]] : vector<3xf32> into f32
// CHECK:      return %[[T8]] : f32

func.func @full_contract1(%arg0: vector<2x3xf32>,
                          %arg1: vector<2x3xf32>,
                          %arg2: f32) -> f32 {
  %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
    : vector<2x3xf32>, vector<2x3xf32> into f32
  return %0 : f32
}

#contraction2d_trans_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> ()>
]
#contraction2d_trans_trait = {
  indexing_maps = #contraction2d_trans_accesses,
  iterator_types = ["reduction", "reduction"]
}

// Full 2-D contraction with a transposed B: each needed column of B is
// gathered element by element (scalar extract + insert into a zero vector)
// before the multiply and the chained reduction.

// CHECK-LABEL: func @full_contract2
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK:      %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0, 0] : f32 from vector<3x2xf32>
// CHECK:      %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
// CHECK:      %[[T4:.*]] = vector.extract %[[B]][1, 0] : f32 from vector<3x2xf32>
// CHECK:      %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
// CHECK:      %[[T7:.*]] = vector.extract %[[B]][2, 0] : f32 from vector<3x2xf32>
// CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
// CHECK:      %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
// CHECK:      %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32
//
// CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T13:.*]] = vector.extract %[[B]][0, 1] : f32 from vector<3x2xf32>
// CHECK:      %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
// CHECK:      %[[T16:.*]] = vector.extract %[[B]][1, 1] : f32 from vector<3x2xf32>
// CHECK:      %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
// CHECK:      %[[T19:.*]] = vector.extract %[[B]][2, 1] : f32 from vector<3x2xf32>
// CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
// CHECK:      %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
// CHECK:      %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32
// CHECK:      return %[[T23]] : f32

func.func @full_contract2(%arg0: vector<2x3xf32>,
                          %arg1: vector<3x2xf32>,
                          %arg2: f32) -> f32 {
  %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
    : vector<2x3xf32>, vector<3x2xf32> into f32
  return %0 : f32
}

// The LHS carries a unit-size reduction dimension (d0) that the RHS does not
// index. Row 0 of the LHS is extracted once and reused for every output
// element. The zero constant and the first two extracts are matched in any
// order.

// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<2xi32> from vector<1x2xi32>
// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2xi32> from vector<2x2xi32>
// CHECK:     %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
// CHECK:     %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
// CHECK:     %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
// CHECK:     %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2xi32> from vector<2x2xi32>
// CHECK:     %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
// CHECK:     %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
// CHECK:     %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
// CHECK:     %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
// CHECK:     return %[[S]] : vector<2xi32>

func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
  %res = vector.contract {
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,
      affine_map<(d0, d1, d2) -> (d1, d2)>,
      affine_map<(d0, d1, d2) -> (d1)>
    ],
    iterator_types = ["reduction", "parallel", "reduction"],
    kind = #vector.kind<add>
  } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
  return %res : vector<2xi32>
}

// Transform script: apply the vector contraction lowering patterns with the
// "dot" strategy to every func.func in the payload module.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "dot"
    } : !transform.any_op
    transform.yield
  }
}