xref: /llvm-project/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll (revision f10153fe91508966aef062ba062271631f2c0f88)
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; REQUIRES: aarch64-registered-target

; RUN: opt -passes='lower-matrix-intrinsics' -matrix-print-after-transpose-opt -disable-output %s 2>&1 | FileCheck %s

; The assertions in this file match the IR printed by
; -matrix-print-after-transpose-opt, which the command above merges into
; FileCheck's input via 2>&1; -disable-output suppresses the regular module output.
; NOTE(review): this datalayout contains f80:128 and n8:16:32:64, which look like
; an x86-64 layout rather than one for the aarch64-apple-ios triple; confirm the
; mismatch is intentional.
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
target triple = "aarch64-apple-ios"

; k * A^T
; No rewrite is expected: the output keeps the same instruction sequence.
define void @kat(ptr %Aptr, double %k, ptr %C) {
; CHECK-LABEL: @kat(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[AT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %at, i32 3, i32 3, i32 3)
  store <9 x double> %mul, ptr %C
  ret void
}

; (k * A)^T -> A^T * k
; The transpose is moved onto A and the multiply operands are swapped.
define void @ka_t(ptr %Aptr, double %k, ptr %C) {
; CHECK-LABEL: @ka_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[A_T]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %a, i32 3, i32 3, i32 3)
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; (k * A)^T -> k * A^T with fmul
; The transpose is moved onto A; the element-wise fmul keeps the splat as its
; first operand.
define void @ka_t_fmul(ptr %Aptr, double %k, ptr %C) {
; CHECK-LABEL: @ka_t_fmul(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = fmul <9 x double> [[SPLAT]], [[A_T]]
; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %mul = fmul <9 x double> %splat, %a
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; (k * A)^T -> k * A^T with mul (non-fp types)
; The transpose is moved onto A; the integer element-wise mul keeps the splat as
; its first operand.
define void @ka_t_mul(ptr %Aptr, i32 %k, ptr %C) {
; CHECK-LABEL: @ka_t_mul(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x i32>, ptr [[APTR:%.*]], align 64
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x i32> poison, i32 [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x i32> [[VECK]], <9 x i32> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[A_T:%.*]] = call <9 x i32> @llvm.matrix.transpose.v9i32(<9 x i32> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = mul <9 x i32> [[SPLAT]], [[A_T]]
; CHECK-NEXT:    store <9 x i32> [[MMUL]], ptr [[C:%.*]], align 64
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x i32>, ptr %Aptr
  %veck = insertelement <9 x i32> poison, i32 %k, i64 0
  %splat = shufflevector <9 x i32> %veck, <9 x i32> poison, <9 x i32> zeroinitializer
  %mul = mul <9 x i32> %splat, %a
  %t = call <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32> %mul, i32 3, i32 3)
  store <9 x i32> %t, ptr %C
  ret void
}

; A^T + B^T -> (A + B)^T
; The two operand transposes are replaced by a single transpose of the sum.
define void @at_plus_bt(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-LABEL: @at_plus_bt(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %fadd = fadd <9 x double> %at, %bt
  store <9 x double> %fadd, ptr %C
  ret void
}

; (A + B)^T -> A^T + B^T -> (A + B)^T
; The expected output has the same shape as the input: one fadd followed by a
; single transpose of the sum.
define void @a_plus_b_t(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-LABEL: @a_plus_b_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %fadd = fadd <9 x double> %a, %b
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; A^T * B^T + C^T * D^T -> (B * A + D * C)^T
; All four operand transposes are removed; a single transpose of the sum remains.
define void @atbt_plus_ctdt(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, ptr %E) {
; CHECK-LABEL: @atbt_plus_ctdt(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[TMP0]], [[TMP1]]
; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[E:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %c = load <9 x double>, ptr %Cptr
  %d = load <9 x double>, ptr %Dptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
  %ctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %ct, <9 x double> %dt, i32 3, i32 3, i32 3)
  %fadd = fadd <9 x double> %atbt, %ctdt
  store <9 x double> %fadd, ptr %E
  ret void
}

; -(A^T) + B^T
; No rewrite is expected: the output keeps both operand transposes and the fneg
; of A^T in their original order.
define void @negat_plus_bt(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-LABEL: @negat_plus_bt(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[NEGAT:%.*]] = fneg <9 x double> [[AT]]
; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[NEGAT]], [[BT]]
; CHECK-NEXT:    store <9 x double> [[FADD]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %negat = fneg <9 x double> %at
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %fadd = fadd <9 x double> %negat, %bt
  store <9 x double> %fadd, ptr %C
  ret void
}

; (A^T * B^T + k * C^T * D^T)^T -> (B * A) + (D * (C * k))
; Every transpose, including the outer one, is eliminated.
define void @atbt_plus_kctdt_t(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, double %k, ptr %E) {
; CHECK-LABEL: @atbt_plus_kctdt_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[MMUL2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[C]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[MMUL1]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
; CHECK-NEXT:    store <9 x double> [[MADD]], ptr [[E:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %c = load <9 x double>, ptr %Cptr
  %d = load <9 x double>, ptr %Dptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %kct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %ct, i32 3, i32 3, i32 3)
  %kctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %kct, <9 x double> %dt, i32 3, i32 3, i32 3)
  %fadd = fadd <9 x double> %atbt, %kctdt
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
  store <9 x double> %t, ptr %E
  ret void
}

; (A^T * (k * B^T))^T -> (B * k) * A
; Every transpose is eliminated; the multiply chain is rebuilt with its operands
; swapped.
define void @atkbt_t(ptr %Aptr, ptr %Bptr, double %k, ptr %C) {
; CHECK-LABEL: @atkbt_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[MMUL1]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %kbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %bt, i32 3, i32 3, i32 3)
  %atkbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %kbt, i32 3, i32 3, i32 3)
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %atkbt, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; Intrinsic declarations for the 3x3 (<9 x ...>) tests above.
; NOTE(review): llvm.matrix.transpose is overloaded on a single vector type, so
; its canonical names are @llvm.matrix.transpose.v9f64 and .v9i32 (the spelling
; used in the expected output, and the style of the <4 x i32> declaration at the
; end of this file). The doubled suffixes below appear to be remangled when the
; IR is read; consider renaming them together with every call site.
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)


; (a * b + c)^T -> (a * b)^T + c^T with integer types.
; The transpose is distributed over the integer add: the product is transposed
; as a whole and c gets its own transpose.
define noundef <4 x i32> @mul_add_transpose_int(<4 x i32> noundef %a, <4 x i32> noundef %b, <4 x i32> noundef %c) {
; CHECK-LABEL: @mul_add_transpose_int(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[TMP0:%.*]] = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], i32 2, i32 2, i32 2)
; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[TMP0]], i32 2, i32 2)
; CHECK-NEXT:    [[C_T:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[C:%.*]], i32 2, i32 2)
; CHECK-NEXT:    [[MADD:%.*]] = add <4 x i32> [[TMP1]], [[C_T]]
; CHECK-NEXT:    ret <4 x i32> [[MADD]]
;
entry:
  %mul = tail call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2)
  %add = add <4 x i32> %mul, %c
  %t = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %add, i32 2, i32 2)
  ret <4 x i32> %t
}

; Intrinsic declarations for the 2x2 (<4 x i32>) test above.
declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)

declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)
