; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; REQUIRES: aarch64-registered-target

; RUN: opt -passes='lower-matrix-intrinsics' -matrix-print-after-transpose-opt -disable-output %s 2>&1 | FileCheck %s

; Checks the IR printed right after the transpose optimizations of
; lower-matrix-intrinsics (-matrix-print-after-transpose-opt), so the CHECK
; lines below still contain llvm.matrix.* calls. -disable-output suppresses
; the pass's final output module; only the printed IR is matched.

target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
target triple = "aarch64-apple-ios"

; k * A^T: no rewrite is expected; the CHECK lines keep the transpose of A
; followed by the multiply with the splat of k.
define void @kat(ptr %Aptr, double %k, ptr %C) {
; CHECK-LABEL: @kat(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[AT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %at, i32 3, i32 3, i32 3)
  store <9 x double> %mul, ptr %C
  ret void
}

; (k * A)^T -> A^T * k
define void @ka_t(ptr %Aptr, double %k, ptr %C) {
; CHECK-LABEL: @ka_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[A_T]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %a, i32 3, i32 3, i32 3)
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; (k * A)^T -> k * A^T with an element-wise fmul; the splat of k stays the
; first fmul operand.
define void @ka_t_fmul(ptr %Aptr, double %k, ptr %C) {
; CHECK-LABEL: @ka_t_fmul(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = fmul <9 x double> [[SPLAT]], [[A_T]]
; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %mul = fmul <9 x double> %splat, %a
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; (k * A)^T -> k * A^T with an element-wise integer mul (non-fp types); the
; splat of k stays the first mul operand.
define void @ka_t_mul(ptr %Aptr, i32 %k, ptr %C) {
; CHECK-LABEL: @ka_t_mul(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x i32>, ptr [[APTR:%.*]], align 64
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x i32> poison, i32 [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x i32> [[VECK]], <9 x i32> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[A_T:%.*]] = call <9 x i32> @llvm.matrix.transpose.v9i32(<9 x i32> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = mul <9 x i32> [[SPLAT]], [[A_T]]
; CHECK-NEXT:    store <9 x i32> [[MMUL]], ptr [[C:%.*]], align 64
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x i32>, ptr %Aptr
  %veck = insertelement <9 x i32> poison, i32 %k, i64 0
  %splat = shufflevector <9 x i32> %veck, <9 x i32> poison, <9 x i32> zeroinitializer
  %mul = mul <9 x i32> %splat, %a
  %t = call <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32> %mul, i32 3, i32 3)
  store <9 x i32> %t, ptr %C
  ret void
}

; A^T + B^T -> (A + B)^T
define void @at_plus_bt(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-LABEL: @at_plus_bt(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %fadd = fadd <9 x double> %at, %bt
  store <9 x double> %fadd, ptr %C
  ret void
}

; (A + B)^T -> A^T + B^T -> (A + B)^T
; The checked output is again a single transpose of (A + B).
define void @a_plus_b_t(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-LABEL: @a_plus_b_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %fadd = fadd <9 x double> %a, %b
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; A^T * B^T + C^T * D^T -> (B * A + D * C)^T
define void @atbt_plus_ctdt(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, ptr %E) {
; CHECK-LABEL: @atbt_plus_ctdt(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[TMP0]], [[TMP1]]
; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[E:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %c = load <9 x double>, ptr %Cptr
  %d = load <9 x double>, ptr %Dptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
  %ctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %ct, <9 x double> %dt, i32 3, i32 3, i32 3)
  %fadd = fadd <9 x double> %atbt, %ctdt
  store <9 x double> %fadd, ptr %E
  ret void
}

; -(A^T) + B^T: no rewrite is expected; the CHECK lines keep both transposes
; and the fneg exactly as written.
define void @negat_plus_bt(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-LABEL: @negat_plus_bt(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT:    [[NEGAT:%.*]] = fneg <9 x double> [[AT]]
; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[NEGAT]], [[BT]]
; CHECK-NEXT:    store <9 x double> [[FADD]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %negat = fneg <9 x double> %at
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %fadd = fadd <9 x double> %negat, %bt
  store <9 x double> %fadd, ptr %C
  ret void
}

; (A^T * B^T + k * C^T * D^T)^T -> (B * A) + (D * (C * k))
define void @atbt_plus_kctdt_t(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, double %k, ptr %E) {
; CHECK-LABEL: @atbt_plus_kctdt_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[MMUL2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[C]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[MMUL1]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
; CHECK-NEXT:    store <9 x double> [[MADD]], ptr [[E:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %c = load <9 x double>, ptr %Cptr
  %d = load <9 x double>, ptr %Dptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %kct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %ct, i32 3, i32 3, i32 3)
  %kctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %kct, <9 x double> %dt, i32 3, i32 3, i32 3)
  %fadd = fadd <9 x double> %atbt, %kctdt
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
  store <9 x double> %t, ptr %E
  ret void
}

; (A^T * (k * B^T))^T -> (B * k) * A
define void @atkbt_t(ptr %Aptr, ptr %Bptr, double %k, ptr %C) {
; CHECK-LABEL: @atkbt_t(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[MMUL1]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT:    store <9 x double> [[MMUL]], ptr [[C:%.*]], align 128
; CHECK-NEXT:    ret void
;
entry:
  %a = load <9 x double>, ptr %Aptr
  %b = load <9 x double>, ptr %Bptr
  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
  %veck = insertelement <9 x double> poison, double %k, i64 0
  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
  %kbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %bt, i32 3, i32 3, i32 3)
  %atkbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %kbt, i32 3, i32 3, i32 3)
  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %atkbt, i32 3, i32 3)
  store <9 x double> %t, ptr %C
  ret void
}

; NOTE(review): the transpose declarations below use double-suffixed names
; (.v9f64.v9f64, .v9i32.v9i32), while the CHECK lines match the printed calls
; as @llvm.matrix.transpose.v9f64 / @llvm.matrix.transpose.v9i32.
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)


; (a * b + c)^T -> (a * b)^T + c^T with integer types.
; The transpose is distributed over the integer add, but the transpose of the
; product (a * b) is kept as-is rather than rewritten further.
define noundef <4 x i32> @mul_add_transpose_int(<4 x i32> noundef %a, <4 x i32> noundef %b, <4 x i32> noundef %c) {
; CHECK-LABEL: @mul_add_transpose_int(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[TMP0:%.*]] = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], i32 2, i32 2, i32 2)
; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[TMP0]], i32 2, i32 2)
; CHECK-NEXT:    [[C_T:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[C:%.*]], i32 2, i32 2)
; CHECK-NEXT:    [[MADD:%.*]] = add <4 x i32> [[TMP1]], [[C_T]]
; CHECK-NEXT:    ret <4 x i32> [[MADD]]
;
entry:
  %mul = tail call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2)
  %add = add <4 x i32> %mul, %c
  %t = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %add, i32 2, i32 2)
  ret <4 x i32> %t
}

declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)

declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)