xref: /llvm-project/flang/unittests/Runtime/MatmulTranspose.cpp (revision 8ce1aed55f3dbb71406dc6feaed3f162ac183d21)
//===-- flang/unittests/Runtime/MatmulTranspose.cpp -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
#include "gtest/gtest.h"
#include "tools.h"
#include "flang/Runtime/allocatable.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/matmul-transpose.h"
#include "flang/Runtime/type-code.h"
#include <cstdint>
#include <cstring>
#include <vector>

using namespace Fortran::runtime;
using Fortran::common::TypeCategory;
19 
TEST(MatmulTranspose,Basic)20 TEST(MatmulTranspose, Basic) {
21   // X 0 1     Y 6  9     Z 6  7  8    M 0 0 1 1    V -1 -2
22   //   2 3       7 10       9 10 11      0 1 0 1
23   //   4 5       8 11
24 
25   auto x{MakeArray<TypeCategory::Integer, 4>(
26       std::vector<int>{3, 2}, std::vector<std::int32_t>{0, 2, 4, 1, 3, 5})};
27   auto y{MakeArray<TypeCategory::Integer, 2>(
28       std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
29   auto z{MakeArray<TypeCategory::Integer, 2>(
30       std::vector<int>{2, 3}, std::vector<std::int16_t>{6, 9, 7, 10, 8, 11})};
31   auto m{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{2, 4},
32       std::vector<std::int16_t>{0, 0, 0, 1, 1, 0, 1, 1})};
33   auto v{MakeArray<TypeCategory::Integer, 8>(
34       std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
35   // X2  0  1     Y2 -1 -1     Z2  6  7  8
36   //     2  3         6  9         9 10 11
37   //     4  5         7 10        -1 -1 -1
38   //    -1 -1         8 11
39   auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{4, 2},
40       std::vector<std::int32_t>{0, 2, 4, -1, 1, 3, 5, -1})};
41   auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
42       std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
43   auto z2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{3, 3},
44       std::vector<std::int16_t>{6, 9, -1, 7, 10, -1, 8, 11, -1})};
45 
46   StaticDescriptor<2, true> statDesc;
47   Descriptor &result{statDesc.descriptor()};
48 
49   RTNAME(MatmulTransposeInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
50   ASSERT_EQ(result.rank(), 2);
51   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
52   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
53   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
54   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
55   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
56   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
57   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
58   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
59   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
60 
61   std::memset(
62       result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
63   result.GetDimension(0).SetLowerBound(0);
64   result.GetDimension(1).SetLowerBound(2);
65   RTNAME(MatmulTransposeDirectInteger4Integer2)
66   (result, *x, *y, __FILE__, __LINE__);
67   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
68   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
69   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
70   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
71   result.Destroy();
72 
73   RTNAME(MatmulTransposeInteger2Integer8)(result, *z, *v, __FILE__, __LINE__);
74   ASSERT_EQ(result.rank(), 1);
75   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
76   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
77   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
78   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
79   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
80   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
81   result.Destroy();
82 
83   RTNAME(MatmulTransposeInteger2Integer2)(result, *m, *z, __FILE__, __LINE__);
84   ASSERT_EQ(result.rank(), 2);
85   ASSERT_EQ(result.GetDimension(0).LowerBound(), 1);
86   ASSERT_EQ(result.GetDimension(0).UpperBound(), 4);
87   ASSERT_EQ(result.GetDimension(1).LowerBound(), 1);
88   ASSERT_EQ(result.GetDimension(1).UpperBound(), 3);
89   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 2}));
90   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(0), 0);
91   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(1), 9);
92   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(2), 6);
93   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(3), 15);
94   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(4), 0);
95   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(5), 10);
96   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(6), 7);
97   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(7), 17);
98   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(8), 0);
99   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(9), 11);
100   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(10), 8);
101   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
102   result.Destroy();
103 
104   // Test non-contiguous sections.
105   static constexpr int sectionRank{2};
106   StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
107   Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
108   sectionX2.Establish(x2->type(), x2->ElementBytes(),
109       /*p=*/nullptr, /*rank=*/sectionRank);
110   static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{3, 2};
111   // Section of X2:
112   //   +-----+
113   //   | 0  1|
114   //   | 2  3|
115   //   | 4  5|
116   //   +-----+
117   //    -1 -1
118   const auto errorX2{CFI_section(
119       &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
120   ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;
121 
122   StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
123   Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
124   sectionY2.Establish(y2->type(), y2->ElementBytes(),
125       /*p=*/nullptr, /*rank=*/sectionRank);
126   static const SubscriptValue lowersY2[]{2, 1};
127   // Section of Y2:
128   //    -1 -1
129   //   +-----+
130   //   | 6  0|
131   //   | 7 10|
132   //   | 8 11|
133   //   +-----+
134   const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
135       /*uppers=*/nullptr, /*strides=*/nullptr)};
136   ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;
137 
138   StaticDescriptor<sectionRank> sectionStaticDescriptorZ2;
139   Descriptor &sectionZ2{sectionStaticDescriptorZ2.descriptor()};
140   sectionZ2.Establish(z2->type(), z2->ElementBytes(),
141       /*p=*/nullptr, /*rank=*/sectionRank);
142   static const SubscriptValue lowersZ2[]{1, 1}, uppersZ2[]{2, 3};
143   // Section of Z2:
144   //   +--------+
145   //   | 6  7  8|
146   //   | 9 10 11|
147   //   +--------+
148   //    -1 -1 -1
149   const auto errorZ2{CFI_section(
150       &sectionZ2.raw(), &z2->raw(), lowersZ2, uppersZ2, /*strides=*/nullptr)};
151   ASSERT_EQ(errorZ2, 0) << "CFI_section failed for Z2: " << errorZ2;
152 
153   RTNAME(MatmulTransposeInteger4Integer2)
154   (result, sectionX2, *y, __FILE__, __LINE__);
155   ASSERT_EQ(result.rank(), 2);
156   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
157   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
158   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
159   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
160   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
161   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
162   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
163   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
164   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
165   result.Destroy();
166 
167   RTNAME(MatmulTransposeInteger4Integer2)
168   (result, *x, sectionY2, __FILE__, __LINE__);
169   ASSERT_EQ(result.rank(), 2);
170   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
171   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
172   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
173   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
174   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
175   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
176   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
177   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
178   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
179   result.Destroy();
180 
181   RTNAME(MatmulTransposeInteger4Integer2)
182   (result, sectionX2, sectionY2, __FILE__, __LINE__);
183   ASSERT_EQ(result.rank(), 2);
184   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
185   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
186   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
187   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
188   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
189   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
190   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
191   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
192   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
193   result.Destroy();
194 
195   RTNAME(MatmulTransposeInteger2Integer8)
196   (result, sectionZ2, *v, __FILE__, __LINE__);
197   ASSERT_EQ(result.rank(), 1);
198   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
199   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
200   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
201   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
202   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
203   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
204   result.Destroy();
205 
206   // X F F    Y F T    V T F T
207   //   T F      F T
208   //   T T      F F
209   auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{3, 2},
210       std::vector<std::uint8_t>{false, true, true, false, false, true})};
211   auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
212       std::vector<std::uint16_t>{false, false, false, true, true, false})};
213   auto vLog{MakeArray<TypeCategory::Logical, 1>(
214       std::vector<int>{3}, std::vector<std::uint8_t>{true, false, true})};
215   RTNAME(MatmulTransposeLogical1Logical2)
216   (result, *xLog, *yLog, __FILE__, __LINE__);
217   ASSERT_EQ(result.rank(), 2);
218   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
219   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
220   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
221   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
222   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
223   EXPECT_FALSE(
224       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
225   EXPECT_FALSE(
226       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
227   EXPECT_TRUE(
228       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
229   EXPECT_FALSE(
230       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
231   result.Destroy();
232 
233   RTNAME(MatmulTransposeLogical2Logical1)
234   (result, *yLog, *vLog, __FILE__, __LINE__);
235   ASSERT_EQ(result.rank(), 1);
236   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
237   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
238   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
239   EXPECT_FALSE(
240       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
241   EXPECT_TRUE(
242       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
243   result.Destroy();
244 }
245