// xref: /llvm-project/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir (revision e4384149b58f7c3d19c5d38bc46038c660b77ca9)
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// Floating-point outer product without an accumulator: the lowering unrolls
// over the rows of the vector<2x3xf32> result. For each row it extracts one LHS
// element, splats it to the RHS width, multiplies it with the RHS via
// arith.mulf, and inserts the product into a zero-initialized result.
// CHECK-LABEL: func @outerproduct_noacc
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK:      return %[[T7]] : vector<2x3xf32>

func.func @outerproduct_noacc(%arg0: vector<2xf32>,
                              %arg1: vector<3xf32>) -> vector<2x3xf32> {
  %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
  return %0: vector<2x3xf32>
}

// Floating-point outer product with an accumulator: each result row is computed
// with a single vector.fma. Its operands are the splatted LHS element, the RHS,
// and the matching row extracted from the accumulator. The fused result is then
// inserted into a zero-initialized vector<2x3xf32>.
// CHECK-LABEL: func @outerproduct_acc
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
// CHECK:      %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
// CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK:      return %[[T9]] : vector<2x3xf32>

func.func @outerproduct_acc(%arg0: vector<2xf32>,
                            %arg1: vector<3xf32>,
                            %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
  return %0: vector<2x3xf32>
}

// Integer outer product without an accumulator: this is the same row-by-row
// unrolling as the floating-point case, but each row uses arith.muli and the
// result starts from an all-zero vector<2x3xi32>.
// CHECK-LABEL: func @outerproduct_noacc_int
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
// CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK:      return %[[T7]] : vector<2x3xi32>
func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
                                  %arg1: vector<3xi32>) -> vector<2x3xi32> {
  %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32>
  return %0: vector<2x3xi32>
}

// Integer outer product with an accumulator: unlike the floating-point case,
// no vector.fma is formed. Each row is lowered to arith.muli of the splatted
// LHS element and the RHS, followed by arith.addi with the matching
// accumulator row.
// CHECK-LABEL: func @outerproduct_acc_int
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
// CHECK:      %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
// CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK:      return %[[T11]] : vector<2x3xi32>
func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
                                %arg1: vector<3xi32>,
                                %arg2: vector<2x3xi32>) -> vector<2x3xi32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32>
  return %0: vector<2x3xi32>
}

// AXPY form (scalar RHS) without an accumulator: no unrolling happens. The
// scalar is splatted to vector<16xf32> and multiplied elementwise with the LHS.
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
  %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
  return %0: vector<16xf32>
}

// Floating-point AXPY form with an accumulator: the splatted scalar, the LHS
// and the accumulator are combined by a single vector.fma.
// CHECK-LABEL: func @axpy_fp_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
  return %0: vector<16xf32>
}

// Integer AXPY form without an accumulator: the scalar is splatted to
// vector<16xi32> and multiplied elementwise with the LHS via arith.muli.
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
  %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
  return %0: vector<16xi32>
}

// Integer AXPY form with an accumulator: this lowers to arith.muli of the LHS
// and the splatted scalar, followed by arith.addi with the accumulator. It is
// not fused into a vector.fma.
// CHECK-LABEL: func @axpy_int_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
func.func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
  %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
  return %0: vector<16xi32>
}

// Transform script executed by `--transform-interpreter` (see the RUN line). It
// matches every func.func in the payload module and applies two pattern sets
// to them, one after the other.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    // Handle to all test functions in this file.
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    // Phase 1: rewrite vector.outerproduct into per-row vector/arith ops.
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_outerproduct
    } : !transform.any_op

    // Phase 2: lower vector.broadcast. The test functions expect vector.splat,
    // presumably because phase 1 emits scalar broadcasts that these patterns
    // turn into splats.
    // NOTE(review): confirm against the upstream pattern implementations.
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_broadcast
    } : !transform.any_op
    transform.yield
  }
}
