// Source: llvm-project/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir (revision e4384149b58f7c3d19c5d38bc46038c660b77ca9)
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// Lowering of `vector.contract` with the "parallelarith" strategy, applied by
// the transform script at the end of this file. In every case below all
// reduction dimensions have size 1, so each contraction is expected to become
// an extract of the unit reduction dims followed by a multiply-accumulate:
// `vector.fma` for vector results, `arith.mulf` + `arith.addf` for scalars.
//
// NOTE: the payload functions and the transform script must stay in the same
// `--split-input-file` chunk; inserting `// -----` separators between them
// would hide the functions from the transform script.

// The reduction dims d1 and d2 are already the two leading dims of both
// operands (map `(d1, d2, d0)`), so no broadcast or transpose is needed: both
// operands are extracted directly and fed to vector.fma with the accumulator.
// CHECK-LABEL: func @parallel_contract_lowering(
//  CHECK-SAME:     %[[LHS:[a-zA-Z0-9_]+]]: vector<1x1x4xf32>,
//  CHECK-SAME:     %[[RHS:[a-zA-Z0-9_]+]]: vector<1x1x4xf32>,
//  CHECK-SAME:     %[[ACC:[a-zA-Z0-9_]+]]: vector<4xf32>)
//       CHECK:   %[[E0:.*]] = vector.extract %[[LHS]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
//       CHECK:   %[[E1:.*]] = vector.extract %[[RHS]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
//       CHECK:   %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %[[ACC]] : vector<4xf32>
//       CHECK:   return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                     affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                     affine_map<(d0, d1, d2) -> (d0)>],
    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>}
  %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// The lhs lacks the parallel dim d0 (map `(d1, d2)`), so it is broadcast to
// 4x1x1 (new leading d0) and transposed with [1, 2, 0] to move d0 innermost,
// matching the rhs layout. The rhs is already in `(d1, d2, d0)` order and is
// extracted directly.
// CHECK-LABEL: func @parallel_contract_lowering_broadcast(
//  CHECK-SAME:     %[[LHS:[a-zA-Z0-9_]+]]: vector<1x1xf32>,
//  CHECK-SAME:     %[[RHS:[a-zA-Z0-9_]+]]: vector<1x1x4xf32>,
//  CHECK-SAME:     %[[ACC:[a-zA-Z0-9_]+]]: vector<4xf32>)
//       CHECK:   %[[B:.*]] = vector.broadcast %[[LHS]] : vector<1x1xf32> to vector<4x1x1xf32>
//       CHECK:   %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
//       CHECK:   %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
//       CHECK:   %[[E1:.*]] = vector.extract %[[RHS]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
//       CHECK:   %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %[[ACC]] : vector<4xf32>
//       CHECK:   return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                     affine_map<(d0, d1, d2) -> (d0)>],
    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>}
  %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// Both operands need rewriting. The lhs lacks d0, so it is broadcast to 4x1x1
// and transposed with [1, 2, 0]. The rhs has d0 in the middle (map
// `(d1, d0, d2)`), so it is transposed with [0, 2, 1] to bring both reduction
// dims first. Both results are then extracted and fed to vector.fma.
// CHECK-LABEL: func @parallel_contract_lowering_transpose(
//  CHECK-SAME:     %[[LHS:[a-zA-Z0-9_]+]]: vector<1x1xf32>,
//  CHECK-SAME:     %[[RHS:[a-zA-Z0-9_]+]]: vector<1x4x1xf32>,
//  CHECK-SAME:     %[[ACC:[a-zA-Z0-9_]+]]: vector<4xf32>)
//       CHECK:   %[[B:.*]] = vector.broadcast %[[LHS]] : vector<1x1xf32> to vector<4x1x1xf32>
//       CHECK:   %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
//       CHECK:   %[[T1:.*]] = vector.transpose %[[RHS]], [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
//       CHECK:   %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
//       CHECK:   %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
//       CHECK:   %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %[[ACC]] : vector<4xf32>
//       CHECK:   return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.contract {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                     affine_map<(d0, d1, d2) -> (d0)>],
    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>}
  %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
  return %0 : vector<4xf32>
}

// Every dim is a reduction, so the result is a scalar f32. The single element
// of each operand is extracted, and the multiply-accumulate is emitted as
// arith.mulf + arith.addf instead of vector.fma.
// CHECK-LABEL: func @parallel_contract_lowering_scalar(
//  CHECK-SAME:     %[[LHS:[a-zA-Z0-9_]+]]: vector<1x1xf32>,
//  CHECK-SAME:     %[[RHS:[a-zA-Z0-9_]+]]: vector<1x1xf32>,
//  CHECK-SAME:     %[[ACC:[a-zA-Z0-9_]+]]: f32)
//       CHECK:   %[[E0:.*]] = vector.extract %[[LHS]][0, 0] : f32 from vector<1x1xf32>
//       CHECK:   %[[E1:.*]] = vector.extract %[[RHS]][0, 0] : f32 from vector<1x1xf32>
//       CHECK:   %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32
//       CHECK:   %[[A:.*]] = arith.addf %[[M]], %[[ACC]] : f32
//       CHECK:   return %[[A]] : f32
func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
  %0 = vector.contract {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> ()>],
    iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
  %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
  return %0 : f32
}

// Transform script run by --transform-interpreter: matches every func.func
// nested in the root op and applies the vector contraction-lowering patterns
// with the "parallelarith" strategy to each of them.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
    } : !transform.any_op
    transform.yield
  }
}
