xref: /llvm-project/flang/unittests/Runtime/Matmul.cpp (revision 8ce1aed55f3dbb71406dc6feaed3f162ac183d21)
1 //===-- flang/unittests/Runtime/Matmul.cpp--------- -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Runtime/matmul.h"
10 #include "gtest/gtest.h"
11 #include "tools.h"
12 #include "flang/Runtime/allocatable.h"
13 #include "flang/Runtime/cpp-type.h"
14 #include "flang/Runtime/descriptor.h"
15 #include "flang/Runtime/type-code.h"
16 
17 using namespace Fortran::runtime;
18 using Fortran::common::TypeCategory;
19 
TEST(Matmul,Basic)20 TEST(Matmul, Basic) {
21   // X 0 2 4   Y 6  9   V -1 -2
22   //   1 3 5     7 10
23   //             8 11
24   auto x{MakeArray<TypeCategory::Integer, 4>(
25       std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
26   auto y{MakeArray<TypeCategory::Integer, 2>(
27       std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
28   auto v{MakeArray<TypeCategory::Integer, 8>(
29       std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
30 
31   // X2  0  2  4  Y2 -1 -1
32   //     1  3  5      6  9
33   //    -1 -1 -1      7 10
34   //                  8 11
35   auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{3, 3},
36       std::vector<std::int32_t>{0, 1, -1, 2, 3, -1, 4, 5})};
37   auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
38       std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
39 
40   StaticDescriptor<2, true> statDesc;
41   Descriptor &result{statDesc.descriptor()};
42 
43   RTNAME(MatmulInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
44   ASSERT_EQ(result.rank(), 2);
45   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
46   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
47   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
48   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
49   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
50   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
51   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
52   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
53   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
54 
55   std::memset(
56       result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
57   result.GetDimension(0).SetLowerBound(0);
58   result.GetDimension(1).SetLowerBound(2);
59   RTNAME(MatmulDirectInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
60   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
61   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
62   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
63   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
64   result.Destroy();
65 
66   RTNAME(MatmulInteger8Integer4)(result, *v, *x, __FILE__, __LINE__);
67   ASSERT_EQ(result.rank(), 1);
68   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
69   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
70   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
71   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
72   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
73   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
74   result.Destroy();
75 
76   RTNAME(MatmulInteger2Integer8)(result, *y, *v, __FILE__, __LINE__);
77   ASSERT_EQ(result.rank(), 1);
78   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
79   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
80   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
81   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
82   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
83   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
84   result.Destroy();
85 
86   // Test non-contiguous sections.
87   static constexpr int sectionRank{2};
88   StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
89   Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
90   sectionX2.Establish(x2->type(), x2->ElementBytes(),
91       /*p=*/nullptr, /*rank=*/sectionRank);
92   static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{2, 3};
93   // Section of X2:
94   //   +--------+
95   //   | 0  2  4|
96   //   | 1  3  5|
97   //   +--------+
98   //    -1 -1 -1
99   const auto errorX2{CFI_section(
100       &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
101   ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;
102 
103   StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
104   Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
105   sectionY2.Establish(y2->type(), y2->ElementBytes(),
106       /*p=*/nullptr, /*rank=*/sectionRank);
107   static const SubscriptValue lowersY2[]{2, 1};
108   // Section of Y2:
109   //    -1 -1
110   //   +-----+
111   //   | 6  9|
112   //   | 7 10|
113   //   | 8 11|
114   //   +-----+
115   const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
116       /*uppers=*/nullptr, /*strides=*/nullptr)};
117   ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;
118 
119   RTNAME(MatmulInteger4Integer2)(result, sectionX2, *y, __FILE__, __LINE__);
120   ASSERT_EQ(result.rank(), 2);
121   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
122   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
123   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
124   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
125   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
126   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
127   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
128   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
129   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
130   result.Destroy();
131 
132   RTNAME(MatmulInteger4Integer2)(result, *x, sectionY2, __FILE__, __LINE__);
133   ASSERT_EQ(result.rank(), 2);
134   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
135   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
136   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
137   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
138   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
139   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
140   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
141   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
142   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
143   result.Destroy();
144 
145   RTNAME(MatmulInteger4Integer2)
146   (result, sectionX2, sectionY2, __FILE__, __LINE__);
147   ASSERT_EQ(result.rank(), 2);
148   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
149   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
150   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
151   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
152   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
153   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
154   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
155   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
156   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
157   result.Destroy();
158 
159   RTNAME(MatmulInteger8Integer4)(result, *v, sectionX2, __FILE__, __LINE__);
160   ASSERT_EQ(result.rank(), 1);
161   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
162   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
163   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
164   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
165   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
166   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
167   result.Destroy();
168 
169   RTNAME(MatmulInteger2Integer8)(result, sectionY2, *v, __FILE__, __LINE__);
170   ASSERT_EQ(result.rank(), 1);
171   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
172   EXPECT_EQ(result.GetDimension(0).Extent(), 3);
173   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
174   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
175   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
176   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
177   result.Destroy();
178 
179   // X F F T  Y F T
180   //   F T T    F T
181   //            F F
182   auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3},
183       std::vector<std::uint8_t>{false, false, false, true, true, false})};
184   auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
185       std::vector<std::uint16_t>{false, false, false, true, true, false})};
186   RTNAME(MatmulLogical1Logical2)(result, *xLog, *yLog, __FILE__, __LINE__);
187   ASSERT_EQ(result.rank(), 2);
188   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
189   EXPECT_EQ(result.GetDimension(0).Extent(), 2);
190   EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
191   EXPECT_EQ(result.GetDimension(1).Extent(), 2);
192   ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
193   EXPECT_FALSE(
194       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
195   EXPECT_FALSE(
196       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
197   EXPECT_FALSE(
198       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
199   EXPECT_TRUE(
200       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
201   result.Destroy();
202 }
203