xref: /llvm-project/clang/test/CodeGenCXX/matrix-type-builtins.cpp (revision 38fffa630ee80163dc65e759392ad29798905679)
1 // RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++17 | FileCheck %s
2 
3 // Tests for the matrix type builtins.
4 
// Alias template for the Clang matrix extension: the element type EltTy
// carrying the matrix_type(Rows, Columns) attribute.
5 template <typename EltTy, unsigned Rows, unsigned Columns>
6 using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns)));
7 
// Aggregate wrapping a single matrix_t member named `value`, parameterized on
// element type and dimensions.
8 template <typename EltTy, unsigned Rows, unsigned Columns>
9 struct MyMatrix {
10   matrix_t<EltTy, Rows, Columns> value;
11 };
12 
// Returns a MyMatrix with the dimensions swapped (C x R) whose `value` is the
// result of __builtin_matrix_transpose applied to M.value.
13 template <typename T, unsigned R, unsigned C>
14 MyMatrix<T, C, R> transpose(const MyMatrix<T, R, C> &M) {
15   MyMatrix<T, C, R> Res;
16   Res.value = __builtin_matrix_transpose(M.value);
17   return Res;
18 }
19 
// Instantiates transpose() for a 4x10 int matrix. The FileCheck directives
// expect the call to that instantiation and, inside it, a <40 x i32> load
// feeding llvm.matrix.transpose with dimensions 4 and 10.
20 void test_transpose_template1() {
21   // CHECK-LABEL: define{{.*}} void @_Z24test_transpose_template1v()
22   // CHECK:         call void @_Z9transposeIiLj4ELj10EE8MyMatrixIT_XT1_EXT0_EERKS0_IS1_XT0_EXT1_EE(ptr dead_on_unwind writable sret(%struct.MyMatrix.0) align 4 %M1_t, ptr noundef nonnull align 4 dereferenceable(160) %M1)
23 
24   // CHECK-LABEL: define linkonce_odr void @_Z9transposeIiLj4ELj10EE8MyMatrixIT_XT1_EXT0_EERKS0_IS1_XT0_EXT1_EE(
25   // CHECK:         [[M:%.*]] = load <40 x i32>, ptr {{.*}}, align 4
26   // CHECK-NEXT:    [[M_T:%.*]] = call <40 x i32> @llvm.matrix.transpose.v40i32(<40 x i32> [[M]], i32 4, i32 10)
27 
28   MyMatrix<int, 4, 10> M1;
29   MyMatrix<int, 10, 4> M1_t = transpose(M1);
30 }
31 
// Chains three transpose() calls starting from a 7x6 double matrix, producing
// a 6x7 result. The FileCheck directives expect the three calls back to back
// (7x6 instantiation, then 6x7, then 7x6 again) and, for each of the two
// instantiations, the load, the llvm.matrix.transpose call with the matching
// dimensions, and the store into the sret result.
32 void test_transpose_template2(MyMatrix<double, 7, 6> &M) {
33   // CHECK-LABEL: define{{.*}} void @_Z24test_transpose_template2R8MyMatrixIdLj7ELj6EE(
34   // CHECK:         call void @_Z9transposeIdLj7ELj6EE8MyMatrixIT_XT1_EXT0_EERKS0_IS1_XT0_EXT1_EE(ptr dead_on_unwind writable sret(%struct.MyMatrix.1) align 8 %ref.tmp1, ptr noundef nonnull align 8 dereferenceable(336) %0)
35   // CHECK-NEXT:    call void @_Z9transposeIdLj6ELj7EE8MyMatrixIT_XT1_EXT0_EERKS0_IS1_XT0_EXT1_EE(ptr dead_on_unwind writable sret(%struct.MyMatrix.2) align 8 %ref.tmp, ptr noundef nonnull align 8 dereferenceable(336) %ref.tmp1)
36   // CHECK-NEXT:    call void @_Z9transposeIdLj7ELj6EE8MyMatrixIT_XT1_EXT0_EERKS0_IS1_XT0_EXT1_EE(ptr dead_on_unwind writable sret(%struct.MyMatrix.1) align 8 %M2_t, ptr noundef nonnull align 8 dereferenceable(336) %ref.tmp)
37 
38   // CHECK-LABEL: define linkonce_odr void @_Z9transposeIdLj7ELj6EE8MyMatrixIT_XT1_EXT0_EERKS0_IS1_XT0_EXT1_EE(
39   // CHECK:         [[M:%.*]] = load <42 x double>, ptr {{.*}}, align 8
40   // CHECK-NEXT:    [[M_T:%.*]] = call <42 x double> @llvm.matrix.transpose.v42f64(<42 x double> [[M]], i32 7, i32 6)
41   // CHECK-NEXT:    [[RES_ADDR:%.*]] = getelementptr inbounds nuw %struct.MyMatrix.1, ptr %agg.result, i32 0, i32 0
42   // CHECK-NEXT:    store <42 x double> [[M_T]], ptr [[RES_ADDR]], align 8
43 
44   // CHECK-LABEL: define linkonce_odr void @_Z9transposeIdLj6ELj7EE8MyMatrixIT_XT1_EXT0_EERKS0_IS1_XT0_EXT1_EE(
45   // CHECK:         [[M:%.*]] = load <42 x double>, ptr {{.*}}, align 8
46   // CHECK-NEXT:    [[M_T:%.*]] = call <42 x double> @llvm.matrix.transpose.v42f64(<42 x double> [[M]], i32 6, i32 7)
47   // CHECK-NEXT:    [[RES_ADDR:%.*]] = getelementptr inbounds nuw %struct.MyMatrix.2, ptr %agg.result, i32 0, i32 0
48   // CHECK-NEXT:    store <42 x double> [[M_T]], ptr [[RES_ADDR]], align 8
49 
50   MyMatrix<double, 6, 7> M2_t = transpose(transpose(transpose(M)));
51 }
52 
// Declaration of a function returning a 3x3 float matrix by value; it serves
// as the source of an rvalue operand in test_transpose_rvalue.
53 matrix_t<float, 3, 3> get_matrix();
54 
// Transposes the rvalue expression get_matrix() + 2.0. The FileCheck
// directives expect, in order: the alloca for the local, the call to
// get_matrix, an fadd with a splat of 2.0, llvm.matrix.transpose with
// dimensions 3 and 3, and the store of the result into the local.
55 void test_transpose_rvalue() {
56   // CHECK-LABEL: define{{.*}} void @_Z21test_transpose_rvaluev()
57   // CHECK-NEXT:  entry:
58   // CHECK-NEXT:    [[M_T_ADDR:%.*]] = alloca [9 x float], align 4
59   // CHECK-NEXT:    [[CALL_RES:%.*]] = call noundef <9 x float> @_Z10get_matrixv()
60   // CHECK-NEXT:    [[ADD:%.*]] = fadd <9 x float> [[CALL_RES]], splat (float 2.000000e+00)
61   // CHECK-NEXT:    [[M_T:%.*]] = call <9 x float> @llvm.matrix.transpose.v9f32(<9 x float> [[ADD]], i32 3, i32 3)
62   // CHECK-NEXT:    store <9 x float> [[M_T]], ptr [[M_T_ADDR]], align 4
63   matrix_t<float, 3, 3> m_t = __builtin_matrix_transpose(get_matrix() + 2.0);
64 }
65 
// Transposes a matrix received through a const reference parameter. The
// FileCheck directives expect a <9 x float> load, llvm.matrix.transpose with
// dimensions 3 and 3, and a store to the local %m_t.
66 void test_transpose_const(const matrix_t<float, 3, 3> &m) {
67   // CHECK-LABEL:  define{{.*}} void @_Z20test_transpose_constRKu11matrix_typeILm3ELm3EfE(
68   // CHECK:         [[MATRIX:%.*]] = load <9 x float>, ptr {{.*}}, align 4
69   // CHECK-NEXT:    [[M_T:%.*]] = call <9 x float> @llvm.matrix.transpose.v9f32(<9 x float> [[MATRIX]], i32 3, i32 3)
70   // CHECK-NEXT:    store <9 x float> [[M_T]], ptr %m_t, align 4
71   matrix_t<float, 3, 3> m_t = __builtin_matrix_transpose(m);
72 }
73 
74 // TODO: Enable once initialization support is defined and implemented for
75 //       matrix types.
76 // void test_lvalue_conversion() {
77 //  constexpr double4x4 m = {};
78 //  [] { return __builtin_matrix_transpose(m); }
79 //}
80 
// Forwards to __builtin_matrix_column_major_load with the row count R, column
// count C and stride S all supplied as template arguments, and returns the
// loaded matrix_t<T, R, C>.
81 template <typename T, unsigned R, unsigned C, unsigned S>
82 matrix_t<T, R, C> column_major_load_with_stride(T *Ptr) {
83   return __builtin_matrix_column_major_load(Ptr, R, C, S);
84 }
85 
// Instantiates column_major_load_with_stride<double, 10, 4, 15>. The FileCheck
// directives expect the instantiation to call llvm.matrix.column.major.load
// with pointer alignment 8, an i64 stride of 15 and dimensions 10 and 4.
86 void test_column_major_load_with_stride_template_double(double *Ptr) {
87   // CHECK-LABEL: define{{.*}} void @_Z50test_column_major_load_with_stride_template_doublePd(ptr noundef %Ptr)
88   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
89   // CHECK-NEXT:    call noundef <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef [[PTR]])
90 
91   // CHECK-LABEL:  define linkonce_odr noundef <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef %Ptr)
92   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
93   // CHECK-NEXT:    call <40 x double> @llvm.matrix.column.major.load.v40f64.i64(ptr align 8 [[PTR]], i64 15, i1 false, i32 10, i32 4)
94 
95   matrix_t<double, 10, 4> M1 = column_major_load_with_stride<double, 10, 4, 15>(Ptr);
96 }
97 
// Instantiates column_major_load_with_stride<int, 3, 2, 12>. The FileCheck
// directives expect the instantiation to call llvm.matrix.column.major.load
// with pointer alignment 4, an i64 stride of 12 and dimensions 3 and 2.
// The label for this function stops at the parameter list, like the other
// labels in this file, so it does not depend on an attribute-group number.
98 void test_column_major_load_with_stride_template_int(int *Ptr) {
99   // CHECK-LABEL: define{{.*}} void @_Z47test_column_major_load_with_stride_template_intPi(ptr noundef %Ptr)
100   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
101   // CHECK-NEXT:    call noundef <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef [[PTR]])
102 
103   // CHECK-LABEL: define linkonce_odr noundef <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef %Ptr)
104   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
105   // CHECK-NEXT:    call <6 x i32> @llvm.matrix.column.major.load.v6i32.i64(ptr align 4 [[PTR]], i64 12, i1 false, i32 3, i32 2)
106 
107   matrix_t<int, 3, 2> M1 = column_major_load_with_stride<int, 3, 2, 12>(Ptr);
108 }
109 
// Class type holding a char and exposing an implicit conversion to unsigned
// that returns it; the stride-wrapper tests pass it where a stride is expected.
110 struct UnsignedWrapper {
111   char x;
112   operator unsigned() {
113     return x;
114   }
115 };
116 
// Passes an UnsignedWrapper as the stride of a column-major load. The
// FileCheck directives expect the conversion operator to be called, its i32
// result zero-extended to i64, and that value used as the intrinsic's stride.
117 void test_column_major_load_stride_wrapper(int *Ptr, UnsignedWrapper &W) {
118   // CHECK-LABEL:  define{{.*}} void @_Z37test_column_major_load_stride_wrapperPiR15UnsignedWrapper(ptr noundef %Ptr, ptr noundef nonnull align 1 dereferenceable(1) %W)
119   // CHECK:         [[W:%.*]] = load ptr, ptr %W.addr, align 8
120   // CHECK-NEXT:    [[STRIDE:%.*]] = call noundef i32 @_ZN15UnsignedWrappercvjEv(ptr {{[^,]*}} [[W]])
121   // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = zext i32 [[STRIDE]] to i64
122   // CHECK-NEXT:    [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
123   // CHECK-NEXT:    call <4 x i32> @llvm.matrix.column.major.load.v4i32.i64(ptr align 4 [[PTR]], i64 [[STRIDE_EXT]], i1 false, i32 2, i32 2)
124   matrix_t<int, 2, 2> M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, W);
125 }
126 
// constexpr helper returning 3; the tests use it both as a row count and as a
// stride argument.
127 constexpr int constexpr3() { return 3; }
128 
// Supplies the row count through a constexpr function call. The FileCheck
// directives expect the intrinsic to receive the constant i32 3 as its row
// dimension (with a constant i64 stride of 3 and 2 columns).
129 void test_column_major_load_constexpr_num_rows(int *Ptr) {
130   // CHECK-LABEL: define{{.*}} void @_Z41test_column_major_load_constexpr_num_rowsPi(ptr noundef %Ptr)
131   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
132   // CHECK-NEXT:    call <6 x i32> @llvm.matrix.column.major.load.v6i32.i64(ptr align 4 [[PTR]], i64 3, i1 false, i32 3, i32 2)
133 
134   matrix_t<int, 3, 2> M1 = __builtin_matrix_column_major_load(Ptr, constexpr3(), 2, 3);
135 }
136 
// constexpr helper returning 1; used as a column count below.
137 constexpr int constexpr1() { return 1; }
138 
// Supplies the column count through a constexpr function call. The FileCheck
// directives expect the intrinsic to receive the constant i32 1 as its column
// dimension, yielding a <2 x i32> result.
139 void test_column_major_load_constexpr_num_columns(int *Ptr) {
140   // CHECK-LABEL: define{{.*}} void @_Z44test_column_major_load_constexpr_num_columnsPi(ptr noundef %Ptr)
141   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
142   // CHECK-NEXT:    call <2 x i32> @llvm.matrix.column.major.load.v2i32.i64(ptr align 4 [[PTR]], i64 3, i1 false, i32 2, i32 1)
143   matrix_t<int, 2, 1> M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr1(), 3);
144 }
145 
// constexpr function template returning its non-type argument N plus one, as
// an int.
146 template <unsigned N>
147 constexpr int constexpr_plus1() { return N + 1; }
148 
// Supplies the column count through a constexpr function template,
// constexpr_plus1<4>(). The FileCheck directives expect the intrinsic to
// receive the constant i32 5 as its column dimension (<10 x i32> result).
149 void test_column_major_load_constexpr_num_columns_temp(int *Ptr) {
150   // CHECK-LABEL:  define{{.*}} void @_Z49test_column_major_load_constexpr_num_columns_tempPi(ptr noundef %Ptr)
151   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
152   // CHECK-NEXT:    call <10 x i32> @llvm.matrix.column.major.load.v10i32.i64(ptr align 4 [[PTR]], i64 3, i1 false, i32 2, i32 5)
153   matrix_t<int, 2, 5> M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr_plus1<4>(), 3);
154 }
155 
// Supplies the stride through a constexpr function call. Unlike the dimension
// arguments in the tests above, the FileCheck directives here expect an actual
// call to constexpr3 in the IR, with its i32 result sign-extended to i64
// before being passed as the stride.
156 void test_column_major_load_constexpr_stride_constexpr(int *Ptr) {
157   // CHECK-LABEL: define{{.*}} void @_Z49test_column_major_load_constexpr_stride_constexprPi(ptr noundef %Ptr)
158   // CHECK:         [[STRIDE:%.*]] = call noundef i32 @_Z10constexpr3v()
159   // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
160   // CHECK-NEXT:    [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
161   // CHECK-NEXT:    call <4 x i32> @llvm.matrix.column.major.load.v4i32.i64(ptr align 4 [[PTR]], i64 [[STRIDE_EXT]], i1 false, i32 2, i32 2)
162 
163   matrix_t<int, 2, 2> M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, constexpr3());
164 }
165 
// Primary template of a local remove_pointer trait: `type` is T unchanged.
166 template <typename T>
167 struct remove_pointer {
168   typedef T type;
169 };
170 
// Partial specialization for pointer types: strips one pointer level and
// recurses through remove_pointer<T>, so every level of pointer is removed.
171 template <typename T>
172 struct remove_pointer<T *> {
173   typedef typename remove_pointer<T>::type type;
174 };
175 
176 // Same as column_major_load_with_stride, but with the PtrT argument itself being a pointer type.
// Takes the whole pointer type as PtrT and derives the matrix element type
// through remove_pointer<PtrT>::type; R, C and S are passed to
// __builtin_matrix_column_major_load as rows, columns and stride.
177 template <typename PtrT, unsigned R, unsigned C, unsigned S>
178 matrix_t<typename remove_pointer<PtrT>::type, R, C> column_major_load_with_stride2(PtrT Ptr) {
179   return __builtin_matrix_column_major_load(Ptr, R, C, S);
180 }
181 
// Instantiates column_major_load_with_stride2 with PtrT = float * and 2x2
// dimensions with stride 2. No IR expectations are attached to this function.
182 void call_column_major_load_with_stride2(float *Ptr) {
183   matrix_t<float, 2, 2> m = column_major_load_with_stride2<float *, 2, 2, 2>(Ptr);
184 }
185 
// Forwards to __builtin_matrix_column_major_store, writing matrix m through
// Ptr with the stride S supplied as a template argument.
186 template <typename T, unsigned R, unsigned C, unsigned S>
187 void column_major_store_with_stride(matrix_t<T, R, C> &m, T *Ptr) {
188   __builtin_matrix_column_major_store(m, Ptr, S);
189 }
190 
// Instantiates column_major_store_with_stride<double, 10, 4, 15>. The
// FileCheck directives expect the instantiation to load the <40 x double>
// matrix and call llvm.matrix.column.major.store with pointer alignment 8, an
// i64 stride of 15 and dimensions 10 and 4.
191 void test_column_major_store_with_stride_template_double(double *Ptr) {
192   // CHECK-LABEL: define{{.*}} void @_Z51test_column_major_store_with_stride_template_doublePd(ptr noundef %Ptr)
193   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
194   // CHECK-NEXT:    call void @_Z30column_major_store_with_strideIdLj10ELj4ELj15EEvRu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef nonnull align 8 dereferenceable(320) %M1, ptr noundef [[PTR]])
195 
196   // CHECK-LABEL:  define linkonce_odr void @_Z30column_major_store_with_strideIdLj10ELj4ELj15EEvRu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef nonnull align 8 dereferenceable(320) %m, ptr noundef %Ptr)
197   // CHECK:         [[M:%.*]] = load <40 x double>, ptr {{.*}}, align 8
198   // CHECK-NEXT:    [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
199   // CHECK-NEXT:    call void @llvm.matrix.column.major.store.v40f64.i64(<40 x double> [[M]], ptr align 8 [[PTR]], i64 15, i1 false, i32 10, i32 4)
200 
201   matrix_t<double, 10, 4> M1;
202   column_major_store_with_stride<double, 10, 4, 15>(M1, Ptr);
203 }
204 
// Instantiates column_major_store_with_stride<int, 3, 2, 3>. The FileCheck
// directives expect the instantiation to load the <6 x i32> matrix and call
// llvm.matrix.column.major.store with pointer alignment 4, an i64 stride of 3
// and dimensions 3 and 2.
205 void test_column_major_store_with_stride_template_int(int *Ptr) {
206   // CHECK-LABEL: define{{.*}} void @_Z48test_column_major_store_with_stride_template_intPi(ptr noundef %Ptr)
207   // CHECK:         [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
208   // CHECK-NEXT:    call void @_Z30column_major_store_with_strideIiLj3ELj2ELj3EEvRu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef nonnull align 4 dereferenceable(24) %M1, ptr noundef [[PTR]])
209 
210   // CHECK-LABEL:  define linkonce_odr void @_Z30column_major_store_with_strideIiLj3ELj2ELj3EEvRu11matrix_typeIXT0_EXT1_ET_EPS0_(ptr noundef nonnull align 4 dereferenceable(24) %m, ptr noundef %Ptr)
211   // CHECK:         [[M:%.*]] = load <6 x i32>, ptr {{.*}}, align 4
212   // CHECK-NEXT:    [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
213   // CHECK-NEXT:    call void @llvm.matrix.column.major.store.v6i32.i64(<6 x i32> [[M]], ptr align 4 [[PTR]], i64 3, i1 false, i32 3, i32 2)
214 
215   matrix_t<int, 3, 2> M1;
216   column_major_store_with_stride<int, 3, 2, 3>(M1, Ptr);
217 }
218 
// Passes an UnsignedWrapper as the stride of a column-major store. The
// FileCheck directives expect the matrix and pointer to be loaded first, then
// the conversion operator to be called, its i32 result zero-extended to i64,
// and that value used as the intrinsic's stride.
219 void test_column_major_store_stride_wrapper(int *Ptr, UnsignedWrapper &W) {
220   // CHECK-LABEL: define{{.*}} void @_Z38test_column_major_store_stride_wrapperPiR15UnsignedWrapper(ptr noundef %Ptr, ptr noundef nonnull align 1 dereferenceable(1) %W)
221   // CHECK:         [[M:%.*]] = load <4 x i32>, ptr {{.*}}, align 4
222   // CHECK-NEXT:    [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
223   // CHECK-NEXT:    [[W:%.*]] = load ptr, ptr %W.addr, align 8
224   // CHECK-NEXT:    [[IDX:%.*]] = call noundef i32 @_ZN15UnsignedWrappercvjEv(ptr {{[^,]*}} [[W]])
225   // CHECK-NEXT:    [[IDX_EXT:%.*]] = zext i32 [[IDX]] to i64
226   // CHECK-NEXT:    call void @llvm.matrix.column.major.store.v4i32.i64(<4 x i32> [[M]], ptr align 4 [[PTR]], i64 [[IDX_EXT]], i1 false, i32 2, i32 2)
227 
228   matrix_t<int, 2, 2> M1;
229   __builtin_matrix_column_major_store(M1, Ptr, W);
230 }
231 
// Supplies the stride of a column-major store through a constexpr function
// call. The FileCheck directives expect the matrix and pointer loads, then a
// call to constexpr3 in the IR whose i32 result is sign-extended to i64 and
// passed as the stride.
232 void test_column_major_store_constexpr_stride_constexpr(int *Ptr) {
233   // CHECK-LABEL: define{{.*}} void @_Z50test_column_major_store_constexpr_stride_constexprPi(ptr noundef %Ptr)
234   // CHECK:         [[M:%.*]] = load <4 x i32>, ptr %M, align 4
235   // CHECK-NEXT:    [[PTR:%.*]] = load ptr, ptr %Ptr.addr, align 8
236   // CHECK-NEXT:    [[IDX:%.*]] = call noundef i32 @_Z10constexpr3v()
237   // CHECK-NEXT:    [[IDX_EXT:%.*]] = sext i32 [[IDX]] to i64
238   // CHECK-NEXT:    call void @llvm.matrix.column.major.store.v4i32.i64(<4 x i32> [[M]], ptr align 4 [[PTR]], i64 [[IDX_EXT]], i1 false, i32 2, i32 2)
239 
240   matrix_t<int, 2, 2> M;
241   __builtin_matrix_column_major_store(M, Ptr, constexpr3());
242 }
243