1 //===-- flang/unittests/Runtime/MatmulTranspose.cpp -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "gtest/gtest.h"
10 #include "tools.h"
11 #include "flang/Runtime/allocatable.h"
12 #include "flang/Runtime/cpp-type.h"
13 #include "flang/Runtime/descriptor.h"
14 #include "flang/Runtime/matmul-transpose.h"
15 #include "flang/Runtime/type-code.h"
16
17 using namespace Fortran::runtime;
18 using Fortran::common::TypeCategory;
19
// Exercises the MATMUL(TRANSPOSE(A), B) runtime entry points
// (RTNAME(MatmulTranspose...)) on small INTEGER and LOGICAL operands:
// matrix*matrix and matrix*vector products with mixed argument kinds, the
// "Direct" entry point writing into an already-allocated result, and
// non-contiguous operands built with CFI_section.
//
// As the diagrams below show, the data vectors handed to MakeArray are in
// column-major (Fortran array element) order: x's data {0, 2, 4, 1, 3, 5}
// with shape {3, 2} is the matrix drawn as X.
TEST(MatmulTranspose, Basic) {
  // X   0 1     Y   6  9     Z   6  7  8     M   0 0 1 1     V   -1 -2
  //     2 3         7 10         9 10 11         0 1 0 1
  //     4 5         8 11

  auto x{MakeArray<TypeCategory::Integer, 4>(
      std::vector<int>{3, 2}, std::vector<std::int32_t>{0, 2, 4, 1, 3, 5})};
  auto y{MakeArray<TypeCategory::Integer, 2>(
      std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
  auto z{MakeArray<TypeCategory::Integer, 2>(
      std::vector<int>{2, 3}, std::vector<std::int16_t>{6, 9, 7, 10, 8, 11})};
  auto m{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{2, 4},
      std::vector<std::int16_t>{0, 0, 0, 1, 1, 0, 1, 1})};
  auto v{MakeArray<TypeCategory::Integer, 8>(
      std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};

  // X2  0  1     Y2  -1 -1     Z2  6  7  8
  //     2  3          6  9         9 10 11
  //     4  5          7 10        -1 -1 -1
  //    -1 -1          8 11
  // X2, Y2 and Z2 are X, Y and Z padded with a row of -1 entries; the
  // sections taken further below exclude the padding, so they must yield
  // the same products as X, Y and Z.
  auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{4, 2},
      std::vector<std::int32_t>{0, 2, 4, -1, 1, 3, 5, -1})};
  auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
      std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
  auto z2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{3, 3},
      std::vector<std::int16_t>{6, 9, -1, 7, 10, -1, 8, 11, -1})};

  StaticDescriptor<2, true> statDesc;
  Descriptor &result{statDesc.descriptor()};

  // TRANSPOSE(X) * Y is 2x2 INTEGER(4):
  //   0*6 + 2*7 + 4*8 = 46    0*9 + 2*10 + 4*11 = 64
  //   1*6 + 3*7 + 5*8 = 67    1*9 + 3*10 + 5*11 = 94
  RTNAME(MatmulTransposeInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);

  // Reuse the already-allocated result with the Direct entry point. Zeroing
  // the elements first means the checks below can only pass if the Direct
  // call actually wrote them. The lower bounds are moved away from 1 so the
  // call runs against a result whose bounds differ from the default.
  std::memset(
      result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
  result.GetDimension(0).SetLowerBound(0);
  result.GetDimension(1).SetLowerBound(2);
  RTNAME(MatmulTransposeDirectInteger4Integer2)
  (result, *x, *y, __FILE__, __LINE__);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
  result.Destroy();

  // TRANSPOSE(Z) * V is a rank-1 INTEGER(8) vector:
  //   6*-1 + 9*-2 = -24,  7*-1 + 10*-2 = -27,  8*-1 + 11*-2 = -30
  RTNAME(MatmulTransposeInteger2Integer8)(result, *z, *v, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 1);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
  result.Destroy();

  // TRANSPOSE(M) * Z is 4x3 INTEGER(2). The rows of TRANSPOSE(M) are
  // (0,0), (0,1), (1,0) and (1,1), so column j of the result is
  // (0, Z(2,j), Z(1,j), Z(1,j) + Z(2,j)).
  RTNAME(MatmulTransposeInteger2Integer2)(result, *m, *z, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  ASSERT_EQ(result.GetDimension(0).LowerBound(), 1);
  ASSERT_EQ(result.GetDimension(0).UpperBound(), 4);
  ASSERT_EQ(result.GetDimension(1).LowerBound(), 1);
  ASSERT_EQ(result.GetDimension(1).UpperBound(), 3);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 2}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(0), 0);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(1), 9);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(2), 6);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(3), 15);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(4), 0);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(5), 10);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(6), 7);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(7), 17);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(8), 0);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(9), 11);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(10), 8);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
  result.Destroy();

  // Test non-contiguous sections.
  static constexpr int sectionRank{2};
  StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
  Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
  sectionX2.Establish(x2->type(), x2->ElementBytes(),
      /*p=*/nullptr, /*rank=*/sectionRank);
  static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{3, 2};
  // Section of X2:
  //   +-----+
  //   | 0  1|
  //   | 2  3|
  //   | 4  5|
  //   +-----+
  //    -1 -1
  const auto errorX2{CFI_section(
      &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
  ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;

  StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
  Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
  sectionY2.Establish(y2->type(), y2->ElementBytes(),
      /*p=*/nullptr, /*rank=*/sectionRank);
  // No upper bounds are passed, so the section runs to the end of Y2.
  static const SubscriptValue lowersY2[]{2, 1};
  // Section of Y2:
  //    -1 -1
  //   +-----+
  //   | 6  9|
  //   | 7 10|
  //   | 8 11|
  //   +-----+
  const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
      /*uppers=*/nullptr, /*strides=*/nullptr)};
  ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;

  StaticDescriptor<sectionRank> sectionStaticDescriptorZ2;
  Descriptor &sectionZ2{sectionStaticDescriptorZ2.descriptor()};
  sectionZ2.Establish(z2->type(), z2->ElementBytes(),
      /*p=*/nullptr, /*rank=*/sectionRank);
  static const SubscriptValue lowersZ2[]{1, 1}, uppersZ2[]{2, 3};
  // Section of Z2:
  //   +--------+
  //   | 6  7  8|
  //   | 9 10 11|
  //   +--------+
  //    -1 -1 -1
  const auto errorZ2{CFI_section(
      &sectionZ2.raw(), &z2->raw(), lowersZ2, uppersZ2, /*strides=*/nullptr)};
  ASSERT_EQ(errorZ2, 0) << "CFI_section failed for Z2: " << errorZ2;

  // The sections hold the same values as X, Y and Z, so every product below
  // must match the corresponding contiguous result computed above.
  RTNAME(MatmulTransposeInteger4Integer2)
  (result, sectionX2, *y, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
  result.Destroy();

  RTNAME(MatmulTransposeInteger4Integer2)
  (result, *x, sectionY2, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
  result.Destroy();

  RTNAME(MatmulTransposeInteger4Integer2)
  (result, sectionX2, sectionY2, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
  result.Destroy();

  RTNAME(MatmulTransposeInteger2Integer8)
  (result, sectionZ2, *v, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 1);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
  result.Destroy();

  // X   F F     Y   F T     V   T F T
  //     T F         F T
  //     T T         F F
  auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{3, 2},
      std::vector<std::uint8_t>{false, true, true, false, false, true})};
  auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
      std::vector<std::uint16_t>{false, false, false, true, true, false})};
  auto vLog{MakeArray<TypeCategory::Logical, 1>(
      std::vector<int>{3}, std::vector<std::uint8_t>{true, false, true})};

  // TRANSPOSE(X) .AND./.OR. Y: only row 1 of TRANSPOSE(X) (F T T) meets a
  // true entry of column 2 of Y (T T F), so only element (1,2) is true.
  RTNAME(MatmulTransposeLogical1Logical2)
  (result, *xLog, *yLog, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 2);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
  EXPECT_FALSE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
  EXPECT_FALSE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
  EXPECT_TRUE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
  EXPECT_FALSE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
  result.Destroy();

  // TRANSPOSE(Y) * V: row 1 of TRANSPOSE(Y) is all false; row 2 (T T F)
  // meets V(1) = T, so the result is (F, T).
  RTNAME(MatmulTransposeLogical2Logical1)
  (result, *yLog, *vLog, __FILE__, __LINE__);
  ASSERT_EQ(result.rank(), 1);
  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
  EXPECT_FALSE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
  EXPECT_TRUE(
      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
  result.Destroy();
}
245