1 //===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// This file provides a collection of function (or more generally, callable) 10 /// type erasure utilities supplementing those provided by the standard library 11 /// in `<function>`. 12 /// 13 /// It provides `unique_function`, which works like `std::function` but supports 14 /// move-only callable objects and const-qualification. 15 /// 16 /// Future plans: 17 /// - Add a `function` that provides ref-qualified support, which doesn't work 18 /// with `std::function`. 19 /// - Provide support for specifying multiple signatures to type erase callable 20 /// objects with an overload set, such as those produced by generic lambdas. 21 /// - Expand to include a copyable utility that directly replaces std::function 22 /// but brings the above improvements. 23 /// 24 /// Note that LLVM's utilities are greatly simplified by not supporting 25 /// allocators. 26 /// 27 /// If the standard library ever begins to provide comparable facilities we can 28 /// consider switching to those. 29 /// 30 //===----------------------------------------------------------------------===// 31 32 #ifndef LLVM_ADT_FUNCTIONEXTRAS_H 33 #define LLVM_ADT_FUNCTIONEXTRAS_H 34 35 #include "llvm/ADT/PointerIntPair.h" 36 #include "llvm/ADT/PointerUnion.h" 37 #include "llvm/ADT/STLForwardCompat.h" 38 #include "llvm/Support/MemAlloc.h" 39 #include "llvm/Support/type_traits.h" 40 #include <memory> 41 #include <type_traits> 42 43 namespace llvm { 44 45 /// unique_function is a type-erasing functor similar to std::function. 46 /// 47 /// It can hold move-only function objects, like lambdas capturing unique_ptrs. 48 /// Accordingly, it is movable but not copyable. 49 /// 50 /// It supports const-qualification: 51 /// - unique_function<int() const> has a const operator(). 52 /// It can only hold functions which themselves have a const operator(). 53 /// - unique_function<int()> has a non-const operator(). 54 /// It can hold functions with a non-const operator(), like mutable lambdas. 55 template <typename FunctionT> class unique_function; 56 57 namespace detail { 58 59 template <typename T> 60 using EnableIfTrivial = 61 std::enable_if_t<llvm::is_trivially_move_constructible<T>::value && 62 std::is_trivially_destructible<T>::value>; 63 template <typename CallableT, typename ThisT> 64 using EnableUnlessSameType = 65 std::enable_if_t<!std::is_same<remove_cvref_t<CallableT>, ThisT>::value>; 66 template <typename CallableT, typename Ret, typename... Params> 67 using EnableIfCallable = 68 std::enable_if_t<std::is_void<Ret>::value || 69 std::is_convertible<decltype(std::declval<CallableT>()( 70 std::declval<Params>()...)), 71 Ret>::value>; 72 73 template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase { 74 protected: 75 static constexpr size_t InlineStorageSize = sizeof(void *) * 3; 76 77 template <typename T, class = void> 78 struct IsSizeLessThanThresholdT : std::false_type {}; 79 80 template <typename T> 81 struct IsSizeLessThanThresholdT< 82 T, std::enable_if_t<sizeof(T) <= 2 * sizeof(void *)>> : std::true_type {}; 83 84 // Provide a type function to map parameters that won't observe extra copies 85 // or moves and which are small enough to likely pass in register to values 86 // and all other types to l-value reference types. We use this to compute the 87 // types used in our erased call utility to minimize copies and moves unless 88 // doing so would force things unnecessarily into memory. 89 // 90 // The heuristic used is related to common ABI register passing conventions. 91 // It doesn't have to be exact though, and in one way it is more strict 92 // because we want to still be able to observe either moves *or* copies. 93 template <typename T> struct AdjustedParamTBase { 94 static_assert(!std::is_reference<T>::value, 95 "references should be handled by template specialization"); 96 using type = typename std::conditional< 97 llvm::is_trivially_copy_constructible<T>::value && 98 llvm::is_trivially_move_constructible<T>::value && 99 IsSizeLessThanThresholdT<T>::value, 100 T, T &>::type; 101 }; 102 103 // This specialization ensures that 'AdjustedParam<V<T>&>' or 104 // 'AdjustedParam<V<T>&&>' does not trigger a compile-time error when 'T' is 105 // an incomplete type and V a templated type. 106 template <typename T> struct AdjustedParamTBase<T &> { using type = T &; }; 107 template <typename T> struct AdjustedParamTBase<T &&> { using type = T &; }; 108 109 template <typename T> 110 using AdjustedParamT = typename AdjustedParamTBase<T>::type; 111 112 // The type of the erased function pointer we use as a callback to dispatch to 113 // the stored callable when it is trivial to move and destroy. 114 using CallPtrT = ReturnT (*)(void *CallableAddr, 115 AdjustedParamT<ParamTs>... Params); 116 using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr); 117 using DestroyPtrT = void (*)(void *CallableAddr); 118 119 /// A struct to hold a single trivial callback with sufficient alignment for 120 /// our bitpacking. 121 struct alignas(8) TrivialCallback { 122 CallPtrT CallPtr; 123 }; 124 125 /// A struct we use to aggregate three callbacks when we need full set of 126 /// operations. 127 struct alignas(8) NonTrivialCallbacks { 128 CallPtrT CallPtr; 129 MovePtrT MovePtr; 130 DestroyPtrT DestroyPtr; 131 }; 132 133 // Create a pointer union between either a pointer to a static trivial call 134 // pointer in a struct or a pointer to a static struct of the call, move, and 135 // destroy pointers. 136 using CallbackPointerUnionT = 137 PointerUnion<TrivialCallback *, NonTrivialCallbacks *>; 138 139 // The main storage buffer. This will either have a pointer to out-of-line 140 // storage or an inline buffer storing the callable. 141 union StorageUnionT { 142 // For out-of-line storage we keep a pointer to the underlying storage and 143 // the size. This is enough to deallocate the memory. 144 struct OutOfLineStorageT { 145 void *StoragePtr; 146 size_t Size; 147 size_t Alignment; 148 } OutOfLineStorage; 149 static_assert( 150 sizeof(OutOfLineStorageT) <= InlineStorageSize, 151 "Should always use all of the out-of-line storage for inline storage!"); 152 153 // For in-line storage, we just provide an aligned character buffer. We 154 // provide three pointers worth of storage here. 155 // This is mutable as an inlined `const unique_function<void() const>` may 156 // still modify its own mutable members. 157 mutable 158 typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type 159 InlineStorage; 160 } StorageUnion; 161 162 // A compressed pointer to either our dispatching callback or our table of 163 // dispatching callbacks and the flag for whether the callable itself is 164 // stored inline or not. 165 PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag; 166 167 bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); } 168 169 bool isTrivialCallback() const { 170 return CallbackAndInlineFlag.getPointer().template is<TrivialCallback *>(); 171 } 172 173 CallPtrT getTrivialCallback() const { 174 return CallbackAndInlineFlag.getPointer().template get<TrivialCallback *>()->CallPtr; 175 } 176 177 NonTrivialCallbacks *getNonTrivialCallbacks() const { 178 return CallbackAndInlineFlag.getPointer() 179 .template get<NonTrivialCallbacks *>(); 180 } 181 182 CallPtrT getCallPtr() const { 183 return isTrivialCallback() ? getTrivialCallback() 184 : getNonTrivialCallbacks()->CallPtr; 185 } 186 187 // These three functions are only const in the narrow sense. They return 188 // mutable pointers to function state. 189 // This allows unique_function<T const>::operator() to be const, even if the 190 // underlying functor may be internally mutable. 191 // 192 // const callers must ensure they're only used in const-correct ways. 193 void *getCalleePtr() const { 194 return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage(); 195 } 196 void *getInlineStorage() const { return &StorageUnion.InlineStorage; } 197 void *getOutOfLineStorage() const { 198 return StorageUnion.OutOfLineStorage.StoragePtr; 199 } 200 201 size_t getOutOfLineStorageSize() const { 202 return StorageUnion.OutOfLineStorage.Size; 203 } 204 size_t getOutOfLineStorageAlignment() const { 205 return StorageUnion.OutOfLineStorage.Alignment; 206 } 207 208 void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) { 209 StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment}; 210 } 211 212 template <typename CalledAsT> 213 static ReturnT CallImpl(void *CallableAddr, 214 AdjustedParamT<ParamTs>... Params) { 215 auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr); 216 return Func(std::forward<ParamTs>(Params)...); 217 } 218 219 template <typename CallableT> 220 static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept { 221 new (LHSCallableAddr) 222 CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr))); 223 } 224 225 template <typename CallableT> 226 static void DestroyImpl(void *CallableAddr) noexcept { 227 reinterpret_cast<CallableT *>(CallableAddr)->~CallableT(); 228 } 229 230 // The pointers to call/move/destroy functions are determined for each 231 // callable type (and called-as type, which determines the overload chosen). 232 // (definitions are out-of-line). 233 234 // By default, we need an object that contains all the different 235 // type erased behaviors needed. Create a static instance of the struct type 236 // here and each instance will contain a pointer to it. 237 // Wrap in a struct to avoid https://gcc.gnu.org/PR71954 238 template <typename CallableT, typename CalledAs, typename Enable = void> 239 struct CallbacksHolder { 240 static NonTrivialCallbacks Callbacks; 241 }; 242 // See if we can create a trivial callback. We need the callable to be 243 // trivially moved and trivially destroyed so that we don't have to store 244 // type erased callbacks for those operations. 245 template <typename CallableT, typename CalledAs> 246 struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> { 247 static TrivialCallback Callbacks; 248 }; 249 250 // A simple tag type so the call-as type to be passed to the constructor. 251 template <typename T> struct CalledAs {}; 252 253 // Essentially the "main" unique_function constructor, but subclasses 254 // provide the qualified type to be used for the call. 255 // (We always store a T, even if the call will use a pointer to const T). 256 template <typename CallableT, typename CalledAsT> 257 UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) { 258 bool IsInlineStorage = true; 259 void *CallableAddr = getInlineStorage(); 260 if (sizeof(CallableT) > InlineStorageSize || 261 alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) { 262 IsInlineStorage = false; 263 // Allocate out-of-line storage. FIXME: Use an explicit alignment 264 // parameter in C++17 mode. 265 auto Size = sizeof(CallableT); 266 auto Alignment = alignof(CallableT); 267 CallableAddr = allocate_buffer(Size, Alignment); 268 setOutOfLineStorage(CallableAddr, Size, Alignment); 269 } 270 271 // Now move into the storage. 272 new (CallableAddr) CallableT(std::move(Callable)); 273 CallbackAndInlineFlag.setPointerAndInt( 274 &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage); 275 } 276 277 ~UniqueFunctionBase() { 278 if (!CallbackAndInlineFlag.getPointer()) 279 return; 280 281 // Cache this value so we don't re-check it after type-erased operations. 282 bool IsInlineStorage = isInlineStorage(); 283 284 if (!isTrivialCallback()) 285 getNonTrivialCallbacks()->DestroyPtr( 286 IsInlineStorage ? getInlineStorage() : getOutOfLineStorage()); 287 288 if (!IsInlineStorage) 289 deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(), 290 getOutOfLineStorageAlignment()); 291 } 292 293 UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept { 294 // Copy the callback and inline flag. 295 CallbackAndInlineFlag = RHS.CallbackAndInlineFlag; 296 297 // If the RHS is empty, just copying the above is sufficient. 298 if (!RHS) 299 return; 300 301 if (!isInlineStorage()) { 302 // The out-of-line case is easiest to move. 303 StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage; 304 } else if (isTrivialCallback()) { 305 // Move is trivial, just memcpy the bytes across. 306 memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize); 307 } else { 308 // Non-trivial move, so dispatch to a type-erased implementation. 309 getNonTrivialCallbacks()->MovePtr(getInlineStorage(), 310 RHS.getInlineStorage()); 311 } 312 313 // Clear the old callback and inline flag to get back to as-if-null. 314 RHS.CallbackAndInlineFlag = {}; 315 316 #ifndef NDEBUG 317 // In debug builds, we also scribble across the rest of the storage. 318 memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize); 319 #endif 320 } 321 322 UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept { 323 if (this == &RHS) 324 return *this; 325 326 // Because we don't try to provide any exception safety guarantees we can 327 // implement move assignment very simply by first destroying the current 328 // object and then move-constructing over top of it. 329 this->~UniqueFunctionBase(); 330 new (this) UniqueFunctionBase(std::move(RHS)); 331 return *this; 332 } 333 334 UniqueFunctionBase() = default; 335 336 public: 337 explicit operator bool() const { 338 return (bool)CallbackAndInlineFlag.getPointer(); 339 } 340 }; 341 342 template <typename R, typename... P> 343 template <typename CallableT, typename CalledAsT, typename Enable> 344 typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase< 345 R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = { 346 &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>}; 347 348 template <typename R, typename... P> 349 template <typename CallableT, typename CalledAsT> 350 typename UniqueFunctionBase<R, P...>::TrivialCallback 351 UniqueFunctionBase<R, P...>::CallbacksHolder< 352 CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{ 353 &CallImpl<CalledAsT>}; 354 355 } // namespace detail 356 357 template <typename R, typename... P> 358 class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> { 359 using Base = detail::UniqueFunctionBase<R, P...>; 360 361 public: 362 unique_function() = default; 363 unique_function(std::nullptr_t) {} 364 unique_function(unique_function &&) = default; 365 unique_function(const unique_function &) = delete; 366 unique_function &operator=(unique_function &&) = default; 367 unique_function &operator=(const unique_function &) = delete; 368 369 template <typename CallableT> 370 unique_function( 371 CallableT Callable, 372 detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr, 373 detail::EnableIfCallable<CallableT, R, P...> * = nullptr) 374 : Base(std::forward<CallableT>(Callable), 375 typename Base::template CalledAs<CallableT>{}) {} 376 377 R operator()(P... Params) { 378 return this->getCallPtr()(this->getCalleePtr(), Params...); 379 } 380 }; 381 382 template <typename R, typename... P> 383 class unique_function<R(P...) const> 384 : public detail::UniqueFunctionBase<R, P...> { 385 using Base = detail::UniqueFunctionBase<R, P...>; 386 387 public: 388 unique_function() = default; 389 unique_function(std::nullptr_t) {} 390 unique_function(unique_function &&) = default; 391 unique_function(const unique_function &) = delete; 392 unique_function &operator=(unique_function &&) = default; 393 unique_function &operator=(const unique_function &) = delete; 394 395 template <typename CallableT> 396 unique_function( 397 CallableT Callable, 398 detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr, 399 detail::EnableIfCallable<const CallableT, R, P...> * = nullptr) 400 : Base(std::forward<CallableT>(Callable), 401 typename Base::template CalledAs<const CallableT>{}) {} 402 403 R operator()(P... Params) const { 404 return this->getCallPtr()(this->getCalleePtr(), Params...); 405 } 406 }; 407 408 } // end namespace llvm 409 410 #endif // LLVM_ADT_FUNCTIONEXTRAS_H 411