xref: /llvm-project/flang/include/flang/Common/visit.h (revision 208544fc70d2cfd5b2c13232a267048108da1978)
1 //===-- include/flang/Common/visit.h ----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // common::visit() is a drop-in replacement for std::visit() that reduces both
10 // compiler build time and compiler execution time modestly, and reduces
11 // compiler build memory requirements significantly (overall & maximum).
12 // It does not require redefinition of std::variant<>.
13 //
14 // The C++ standard mandates that std::visit be O(1), but most variants are
15 // small and O(logN) is faster in practice to compile and execute, avoiding
16 // the need to build a dispatch table.
17 //
18 // Define FLANG_USE_STD_VISIT to avoid this code and make common::visit() an
19 // alias for ::std::visit().
20 
21 #ifndef FORTRAN_COMMON_VISIT_H_
22 #define FORTRAN_COMMON_VISIT_H_
23 
24 #include "variant.h"
25 #include "flang/Common/api-attrs.h"
26 #include <type_traits>
27 
28 namespace Fortran::common {
namespace log2visit {

// Invokes visitor on alternative number `which` of every variant in u,
// considering only alternative indices in the closed range [LOW, HIGH].
//
// LOW, HIGH, and RESULT must be supplied explicitly; RESULT is the return
// type.  A range of at most eight alternatives is dispatched with a switch;
// a larger range is split at its midpoint and the matching half is handled
// recursively, so selection needs a logarithmic number of comparisons and
// no dispatch table.
//
// In the switch, case N covers alternative LOW + N.  When LOW + N exceeds
// HIGH the body of that case is discarded by `if constexpr` and control
// falls through.  Such cases, which == LOW, and any value of `which` that
// matches no label all end at the std::get<LOW> access after the switch.
template <std::size_t LOW, std::size_t HIGH, typename RESULT, typename VISITOR,
    typename... VARIANT>
RT_DEVICE_NOINLINE_HOST_INLINE RT_API_ATTRS RESULT Log2VisitHelper(
    VISITOR &&visitor, std::size_t which, VARIANT &&...u) {
  if constexpr (LOW + 7 >= HIGH) {
    // Leaf: no more than eight candidate alternatives remain.
    switch (which - LOW) {
#define VISIT_CASE_N(N) \
  case N: \
    if constexpr (LOW + N <= HIGH) { \
      return visitor(std::get<(LOW + N)>(std::forward<VARIANT>(u))...); \
    }
      VISIT_CASE_N(1)
      [[fallthrough]];
      VISIT_CASE_N(2)
      [[fallthrough]];
      VISIT_CASE_N(3)
      [[fallthrough]];
      VISIT_CASE_N(4)
      [[fallthrough]];
      VISIT_CASE_N(5)
      [[fallthrough]];
      VISIT_CASE_N(6)
      [[fallthrough]];
      VISIT_CASE_N(7)
#undef VISIT_CASE_N
    }
    return visitor(std::get<LOW>(std::forward<VARIANT>(u))...);
  } else {
    // More than eight candidates: halve the range and descend into the half
    // that contains `which`.
    static constexpr std::size_t mid{(HIGH + LOW) / 2};
    if (which <= mid) {
      return Log2VisitHelper<LOW, mid, RESULT>(
          std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u)...);
    } else {
      return Log2VisitHelper<(mid + 1), HIGH, RESULT>(
          std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u)...);
    }
  }
}

// Replacement for std::visit().  The return type is whatever the visitor
// returns when applied to alternative 0 of each argument.
//
// With exactly one variant argument, dispatch goes through Log2VisitHelper
// over the index range [0, variant_size - 1], keyed on u.index().  Any other
// number of variant arguments is forwarded unchanged to ::std::visit.
template <typename VISITOR, typename... VARIANT>
RT_DEVICE_NOINLINE_HOST_INLINE RT_API_ATTRS auto
visit(VISITOR &&visitor, VARIANT &&...u) -> decltype(visitor(std::get<0>(
                                             std::forward<VARIANT>(u))...)) {
  using Result = decltype(visitor(std::get<0>(std::forward<VARIANT>(u))...));
  if constexpr (sizeof...(u) == 1) {
    // The fold over a one-element pack is just that variant's size; `high`
    // is the index of its last alternative.
    static constexpr std::size_t high{
        (std::variant_size_v<std::decay_t<decltype(u)>> * ...) - 1};
    return Log2VisitHelper<0, high, Result>(std::forward<VISITOR>(visitor),
        u.index()..., std::forward<VARIANT>(u)...);
  } else {
    // TODO: figure out how to do multiple variant arguments
    return ::std::visit(
        std::forward<VISITOR>(visitor), std::forward<VARIANT>(u)...);
  }
}

} // namespace log2visit
88 
// Select the implementation behind common::visit().
//
// The log2visit implementation is used with Clang and with GCC 9 or newer.
// GCC older than 9 and compilers that are neither GCC nor Clang (e.g. MSVC)
// may work with it but are not well tested, so they fall back to std::visit.
// NOTE(review): an earlier version of this comment said that some versions
// of clang hang when compiling these templates and that log2visit should be
// enabled only for GCC 9 and better; the condition below nevertheless keeps
// it enabled under Clang -- confirm which is intended.
//
// A build that already defines FLANG_USE_STD_VISIT (possibly with a value,
// as -DFLANG_USE_STD_VISIT does) is left alone so that the macro is not
// redefined with a different replacement list.
#if !defined(FLANG_USE_STD_VISIT) && !defined(__clang__) && \
    (!defined(__GNUC__) || __GNUC__ < 9)
#define FLANG_USE_STD_VISIT
#endif

#ifdef FLANG_USE_STD_VISIT
using ::std::visit;
#else
using Fortran::common::log2visit::visit;
#endif
101 
102 } // namespace Fortran::common
103 #endif // FORTRAN_COMMON_VISIT_H_
104