xref: /llvm-project/clang-tools-extra/clang-tidy/mpi/TypeMismatchCheck.cpp (revision 15aa965363df5cf3a021b3841bcafbced3756ea2)
1 //===--- TypeMismatchCheck.cpp - clang-tidy--------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TypeMismatchCheck.h"
10 #include "clang/Lex/Lexer.h"
11 #include "clang/Tooling/FixIt.h"
12 #include "llvm/ADT/StringSet.h"
13 #include <map>
14 
15 using namespace clang::ast_matchers;
16 
17 namespace clang::tidy::mpi {
18 
19 /// Check if a BuiltinType::Kind matches the MPI datatype.
20 ///
21 /// \param MultiMap datatype group
22 /// \param Kind buffer type kind
23 /// \param MPIDatatype name of the MPI datatype
24 ///
25 /// \returns true if the pair matches
26 static bool
isMPITypeMatching(const std::multimap<BuiltinType::Kind,StringRef> & MultiMap,const BuiltinType::Kind Kind,StringRef MPIDatatype)27 isMPITypeMatching(const std::multimap<BuiltinType::Kind, StringRef> &MultiMap,
28                   const BuiltinType::Kind Kind, StringRef MPIDatatype) {
29   auto ItPair = MultiMap.equal_range(Kind);
30   while (ItPair.first != ItPair.second) {
31     if (ItPair.first->second == MPIDatatype)
32       return true;
33     ++ItPair.first;
34   }
35   return false;
36 }
37 
38 /// Check if the MPI datatype is a standard type.
39 ///
40 /// \param MPIDatatype name of the MPI datatype
41 ///
42 /// \returns true if the type is a standard type
isStandardMPIDatatype(StringRef MPIDatatype)43 static bool isStandardMPIDatatype(StringRef MPIDatatype) {
44   static llvm::StringSet<> AllTypes = {"MPI_C_BOOL",
45                                        "MPI_CHAR",
46                                        "MPI_SIGNED_CHAR",
47                                        "MPI_UNSIGNED_CHAR",
48                                        "MPI_WCHAR",
49                                        "MPI_INT",
50                                        "MPI_LONG",
51                                        "MPI_SHORT",
52                                        "MPI_LONG_LONG",
53                                        "MPI_LONG_LONG_INT",
54                                        "MPI_UNSIGNED",
55                                        "MPI_UNSIGNED_SHORT",
56                                        "MPI_UNSIGNED_LONG",
57                                        "MPI_UNSIGNED_LONG_LONG",
58                                        "MPI_FLOAT",
59                                        "MPI_DOUBLE",
60                                        "MPI_LONG_DOUBLE",
61                                        "MPI_C_COMPLEX",
62                                        "MPI_C_FLOAT_COMPLEX",
63                                        "MPI_C_DOUBLE_COMPLEX",
64                                        "MPI_C_LONG_DOUBLE_COMPLEX",
65                                        "MPI_INT8_T",
66                                        "MPI_INT16_T",
67                                        "MPI_INT32_T",
68                                        "MPI_INT64_T",
69                                        "MPI_UINT8_T",
70                                        "MPI_UINT16_T",
71                                        "MPI_UINT32_T",
72                                        "MPI_UINT64_T",
73                                        "MPI_CXX_BOOL",
74                                        "MPI_CXX_FLOAT_COMPLEX",
75                                        "MPI_CXX_DOUBLE_COMPLEX",
76                                        "MPI_CXX_LONG_DOUBLE_COMPLEX"};
77 
78   return AllTypes.contains(MPIDatatype);
79 }
80 
81 /// Check if a BuiltinType matches the MPI datatype.
82 ///
83 /// \param Builtin the builtin type
84 /// \param BufferTypeName buffer type name, gets assigned
85 /// \param MPIDatatype name of the MPI datatype
86 /// \param LO language options
87 ///
88 /// \returns true if the type matches
isBuiltinTypeMatching(const BuiltinType * Builtin,std::string & BufferTypeName,StringRef MPIDatatype,const LangOptions & LO)89 static bool isBuiltinTypeMatching(const BuiltinType *Builtin,
90                                   std::string &BufferTypeName,
91                                   StringRef MPIDatatype,
92                                   const LangOptions &LO) {
93   static std::multimap<BuiltinType::Kind, StringRef> BuiltinMatches = {
94       // On some systems like PPC or ARM, 'char' is unsigned by default which is
95       // why distinct signedness for the buffer and MPI type is tolerated.
96       {BuiltinType::SChar, "MPI_CHAR"},
97       {BuiltinType::SChar, "MPI_SIGNED_CHAR"},
98       {BuiltinType::SChar, "MPI_UNSIGNED_CHAR"},
99       {BuiltinType::Char_S, "MPI_CHAR"},
100       {BuiltinType::Char_S, "MPI_SIGNED_CHAR"},
101       {BuiltinType::Char_S, "MPI_UNSIGNED_CHAR"},
102       {BuiltinType::UChar, "MPI_CHAR"},
103       {BuiltinType::UChar, "MPI_SIGNED_CHAR"},
104       {BuiltinType::UChar, "MPI_UNSIGNED_CHAR"},
105       {BuiltinType::Char_U, "MPI_CHAR"},
106       {BuiltinType::Char_U, "MPI_SIGNED_CHAR"},
107       {BuiltinType::Char_U, "MPI_UNSIGNED_CHAR"},
108       {BuiltinType::WChar_S, "MPI_WCHAR"},
109       {BuiltinType::WChar_U, "MPI_WCHAR"},
110       {BuiltinType::Bool, "MPI_C_BOOL"},
111       {BuiltinType::Bool, "MPI_CXX_BOOL"},
112       {BuiltinType::Short, "MPI_SHORT"},
113       {BuiltinType::Int, "MPI_INT"},
114       {BuiltinType::Long, "MPI_LONG"},
115       {BuiltinType::LongLong, "MPI_LONG_LONG"},
116       {BuiltinType::LongLong, "MPI_LONG_LONG_INT"},
117       {BuiltinType::UShort, "MPI_UNSIGNED_SHORT"},
118       {BuiltinType::UInt, "MPI_UNSIGNED"},
119       {BuiltinType::ULong, "MPI_UNSIGNED_LONG"},
120       {BuiltinType::ULongLong, "MPI_UNSIGNED_LONG_LONG"},
121       {BuiltinType::Float, "MPI_FLOAT"},
122       {BuiltinType::Double, "MPI_DOUBLE"},
123       {BuiltinType::LongDouble, "MPI_LONG_DOUBLE"}};
124 
125   if (!isMPITypeMatching(BuiltinMatches, Builtin->getKind(), MPIDatatype)) {
126     BufferTypeName = std::string(Builtin->getName(LO));
127     return false;
128   }
129 
130   return true;
131 }
132 
133 /// Check if a complex float/double/long double buffer type matches
134 /// the MPI datatype.
135 ///
136 /// \param Complex buffer type
137 /// \param BufferTypeName buffer type name, gets assigned
138 /// \param MPIDatatype name of the MPI datatype
139 /// \param LO language options
140 ///
141 /// \returns true if the type matches or the buffer type is unknown
isCComplexTypeMatching(const ComplexType * const Complex,std::string & BufferTypeName,StringRef MPIDatatype,const LangOptions & LO)142 static bool isCComplexTypeMatching(const ComplexType *const Complex,
143                                    std::string &BufferTypeName,
144                                    StringRef MPIDatatype,
145                                    const LangOptions &LO) {
146   static std::multimap<BuiltinType::Kind, StringRef> ComplexCMatches = {
147       {BuiltinType::Float, "MPI_C_COMPLEX"},
148       {BuiltinType::Float, "MPI_C_FLOAT_COMPLEX"},
149       {BuiltinType::Double, "MPI_C_DOUBLE_COMPLEX"},
150       {BuiltinType::LongDouble, "MPI_C_LONG_DOUBLE_COMPLEX"}};
151 
152   const auto *Builtin =
153       Complex->getElementType().getTypePtr()->getAs<BuiltinType>();
154 
155   if (Builtin &&
156       !isMPITypeMatching(ComplexCMatches, Builtin->getKind(), MPIDatatype)) {
157     BufferTypeName = (llvm::Twine(Builtin->getName(LO)) + " _Complex").str();
158     return false;
159   }
160   return true;
161 }
162 
163 /// Check if a complex<float/double/long double> templated buffer type matches
164 /// the MPI datatype.
165 ///
166 /// \param Template buffer type
167 /// \param BufferTypeName buffer type name, gets assigned
168 /// \param MPIDatatype name of the MPI datatype
169 /// \param LO language options
170 ///
171 /// \returns true if the type matches or the buffer type is unknown
172 static bool
isCXXComplexTypeMatching(const TemplateSpecializationType * const Template,std::string & BufferTypeName,StringRef MPIDatatype,const LangOptions & LO)173 isCXXComplexTypeMatching(const TemplateSpecializationType *const Template,
174                          std::string &BufferTypeName, StringRef MPIDatatype,
175                          const LangOptions &LO) {
176   static std::multimap<BuiltinType::Kind, StringRef> ComplexCXXMatches = {
177       {BuiltinType::Float, "MPI_CXX_FLOAT_COMPLEX"},
178       {BuiltinType::Double, "MPI_CXX_DOUBLE_COMPLEX"},
179       {BuiltinType::LongDouble, "MPI_CXX_LONG_DOUBLE_COMPLEX"}};
180 
181   if (Template->getAsCXXRecordDecl()->getName() != "complex")
182     return true;
183 
184   const auto *Builtin = Template->template_arguments()[0]
185                             .getAsType()
186                             .getTypePtr()
187                             ->getAs<BuiltinType>();
188 
189   if (Builtin &&
190       !isMPITypeMatching(ComplexCXXMatches, Builtin->getKind(), MPIDatatype)) {
191     BufferTypeName =
192         (llvm::Twine("complex<") + Builtin->getName(LO) + ">").str();
193     return false;
194   }
195 
196   return true;
197 }
198 
199 /// Check if a fixed size width buffer type matches the MPI datatype.
200 ///
201 /// \param Typedef buffer type
202 /// \param BufferTypeName buffer type name, gets assigned
203 /// \param MPIDatatype name of the MPI datatype
204 ///
205 /// \returns true if the type matches or the buffer type is unknown
isTypedefTypeMatching(const TypedefType * const Typedef,std::string & BufferTypeName,StringRef MPIDatatype)206 static bool isTypedefTypeMatching(const TypedefType *const Typedef,
207                                   std::string &BufferTypeName,
208                                   StringRef MPIDatatype) {
209   static llvm::StringMap<StringRef> FixedWidthMatches = {
210       {"int8_t", "MPI_INT8_T"},     {"int16_t", "MPI_INT16_T"},
211       {"int32_t", "MPI_INT32_T"},   {"int64_t", "MPI_INT64_T"},
212       {"uint8_t", "MPI_UINT8_T"},   {"uint16_t", "MPI_UINT16_T"},
213       {"uint32_t", "MPI_UINT32_T"}, {"uint64_t", "MPI_UINT64_T"}};
214 
215   const auto It = FixedWidthMatches.find(Typedef->getDecl()->getName());
216   // Check if the typedef is known and not matching the MPI datatype.
217   if (It != FixedWidthMatches.end() && It->getValue() != MPIDatatype) {
218     BufferTypeName = std::string(Typedef->getDecl()->getName());
219     return false;
220   }
221   return true;
222 }
223 
224 /// Get the unqualified, dereferenced type of an argument.
225 ///
226 /// \param CE call expression
227 /// \param Idx argument index
228 ///
229 /// \returns type of the argument
argumentType(const CallExpr * const CE,const size_t Idx)230 static const Type *argumentType(const CallExpr *const CE, const size_t Idx) {
231   const QualType QT = CE->getArg(Idx)->IgnoreImpCasts()->getType();
232   return QT.getTypePtr()->getPointeeOrArrayElementType();
233 }
234 
registerMatchers(MatchFinder * Finder)235 void TypeMismatchCheck::registerMatchers(MatchFinder *Finder) {
236   Finder->addMatcher(callExpr().bind("CE"), this);
237 }
238 
check(const MatchFinder::MatchResult & Result)239 void TypeMismatchCheck::check(const MatchFinder::MatchResult &Result) {
240   const auto *const CE = Result.Nodes.getNodeAs<CallExpr>("CE");
241   if (!CE->getDirectCallee())
242     return;
243 
244   if (!FuncClassifier)
245     FuncClassifier.emplace(*Result.Context);
246 
247   const IdentifierInfo *Identifier = CE->getDirectCallee()->getIdentifier();
248   if (!Identifier || !FuncClassifier->isMPIType(Identifier))
249     return;
250 
251   // These containers are used, to capture buffer, MPI datatype pairs.
252   SmallVector<const Type *, 1> BufferTypes;
253   SmallVector<const Expr *, 1> BufferExprs;
254   SmallVector<StringRef, 1> MPIDatatypes;
255 
256   // Adds a buffer, MPI datatype pair of an MPI call expression to the
257   // containers. For buffers, the type and expression is captured.
258   auto AddPair = [&CE, &Result, &BufferTypes, &BufferExprs, &MPIDatatypes](
259                      const size_t BufferIdx, const size_t DatatypeIdx) {
260     // Skip null pointer constants and in place 'operators'.
261     if (CE->getArg(BufferIdx)->isNullPointerConstant(
262             *Result.Context, Expr::NPC_ValueDependentIsNull) ||
263         tooling::fixit::getText(*CE->getArg(BufferIdx), *Result.Context) ==
264             "MPI_IN_PLACE")
265       return;
266 
267     StringRef MPIDatatype =
268         tooling::fixit::getText(*CE->getArg(DatatypeIdx), *Result.Context);
269 
270     const Type *ArgType = argumentType(CE, BufferIdx);
271     // Skip unknown MPI datatypes and void pointers.
272     if (!isStandardMPIDatatype(MPIDatatype) || ArgType->isVoidType())
273       return;
274 
275     BufferTypes.push_back(ArgType);
276     BufferExprs.push_back(CE->getArg(BufferIdx));
277     MPIDatatypes.push_back(MPIDatatype);
278   };
279 
280   // Collect all buffer, MPI datatype pairs for the inspected call expression.
281   if (FuncClassifier->isPointToPointType(Identifier)) {
282     AddPair(0, 2);
283   } else if (FuncClassifier->isCollectiveType(Identifier)) {
284     if (FuncClassifier->isReduceType(Identifier)) {
285       AddPair(0, 3);
286       AddPair(1, 3);
287     } else if (FuncClassifier->isScatterType(Identifier) ||
288                FuncClassifier->isGatherType(Identifier) ||
289                FuncClassifier->isAlltoallType(Identifier)) {
290       AddPair(0, 2);
291       AddPair(3, 5);
292     } else if (FuncClassifier->isBcastType(Identifier)) {
293       AddPair(0, 2);
294     }
295   }
296   checkArguments(BufferTypes, BufferExprs, MPIDatatypes, getLangOpts());
297 }
298 
checkArguments(ArrayRef<const Type * > BufferTypes,ArrayRef<const Expr * > BufferExprs,ArrayRef<StringRef> MPIDatatypes,const LangOptions & LO)299 void TypeMismatchCheck::checkArguments(ArrayRef<const Type *> BufferTypes,
300                                        ArrayRef<const Expr *> BufferExprs,
301                                        ArrayRef<StringRef> MPIDatatypes,
302                                        const LangOptions &LO) {
303   std::string BufferTypeName;
304 
305   for (size_t I = 0; I < MPIDatatypes.size(); ++I) {
306     const Type *const BT = BufferTypes[I];
307     bool Error = false;
308 
309     if (const auto *Typedef = BT->getAs<TypedefType>()) {
310       Error = !isTypedefTypeMatching(Typedef, BufferTypeName, MPIDatatypes[I]);
311     } else if (const auto *Complex = BT->getAs<ComplexType>()) {
312       Error =
313           !isCComplexTypeMatching(Complex, BufferTypeName, MPIDatatypes[I], LO);
314     } else if (const auto *Template = BT->getAs<TemplateSpecializationType>()) {
315       Error = !isCXXComplexTypeMatching(Template, BufferTypeName,
316                                         MPIDatatypes[I], LO);
317     } else if (const auto *Builtin = BT->getAs<BuiltinType>()) {
318       Error =
319           !isBuiltinTypeMatching(Builtin, BufferTypeName, MPIDatatypes[I], LO);
320     }
321 
322     if (Error) {
323       const auto Loc = BufferExprs[I]->getSourceRange().getBegin();
324       diag(Loc, "buffer type '%0' does not match the MPI datatype '%1'")
325           << BufferTypeName << MPIDatatypes[I];
326     }
327   }
328 }
329 
onEndOfTranslationUnit()330 void TypeMismatchCheck::onEndOfTranslationUnit() { FuncClassifier.reset(); }
331 } // namespace clang::tidy::mpi
332