1 //===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// This file provides a collection of function (or more generally, callable) 10 /// type erasure utilities supplementing those provided by the standard library 11 /// in `<function>`. 12 /// 13 /// It provides `unique_function`, which works like `std::function` but supports 14 /// move-only callable objects and const-qualification. 15 /// 16 /// Future plans: 17 /// - Add a `function` that provides ref-qualified support, which doesn't work 18 /// with `std::function`. 19 /// - Provide support for specifying multiple signatures to type erase callable 20 /// objects with an overload set, such as those produced by generic lambdas. 21 /// - Expand to include a copyable utility that directly replaces std::function 22 /// but brings the above improvements. 23 /// 24 /// Note that LLVM's utilities are greatly simplified by not supporting 25 /// allocators. 26 /// 27 /// If the standard library ever begins to provide comparable facilities we can 28 /// consider switching to those. 29 /// 30 //===----------------------------------------------------------------------===// 31 32 #ifndef LLVM_ADT_FUNCTIONEXTRAS_H 33 #define LLVM_ADT_FUNCTIONEXTRAS_H 34 35 #include "llvm/ADT/PointerIntPair.h" 36 #include "llvm/ADT/PointerUnion.h" 37 #include "llvm/ADT/STLForwardCompat.h" 38 #include "llvm/Support/Compiler.h" 39 #include "llvm/Support/MemAlloc.h" 40 #include "llvm/Support/type_traits.h" 41 #include <cstring> 42 #include <memory> 43 #include <type_traits> 44 45 namespace llvm { 46 47 /// unique_function is a type-erasing functor similar to std::function. 48 /// 49 /// It can hold move-only function objects, like lambdas capturing unique_ptrs. 50 /// Accordingly, it is movable but not copyable. 51 /// 52 /// It supports const-qualification: 53 /// - unique_function<int() const> has a const operator(). 54 /// It can only hold functions which themselves have a const operator(). 55 /// - unique_function<int()> has a non-const operator(). 56 /// It can hold functions with a non-const operator(), like mutable lambdas. 57 template <typename FunctionT> class unique_function; 58 59 namespace detail { 60 61 template <typename T> 62 using EnableIfTrivial = 63 std::enable_if_t<std::is_trivially_move_constructible<T>::value && 64 std::is_trivially_destructible<T>::value>; 65 template <typename CallableT, typename ThisT> 66 using EnableUnlessSameType = 67 std::enable_if_t<!std::is_same<remove_cvref_t<CallableT>, ThisT>::value>; 68 template <typename CallableT, typename Ret, typename... Params> 69 using EnableIfCallable = std::enable_if_t<std::disjunction< 70 std::is_void<Ret>, 71 std::is_same<decltype(std::declval<CallableT>()(std::declval<Params>()...)), 72 Ret>, 73 std::is_same<const decltype(std::declval<CallableT>()( 74 std::declval<Params>()...)), 75 Ret>, 76 std::is_convertible<decltype(std::declval<CallableT>()( 77 std::declval<Params>()...)), 78 Ret>>::value>; 79 80 template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase { 81 protected: 82 static constexpr size_t InlineStorageSize = sizeof(void *) * 3; 83 static constexpr size_t InlineStorageAlign = alignof(void *); 84 85 template <typename T, class = void> 86 struct IsSizeLessThanThresholdT : std::false_type {}; 87 88 template <typename T> 89 struct IsSizeLessThanThresholdT< 90 T, std::enable_if_t<sizeof(T) <= 2 * sizeof(void *)>> : std::true_type {}; 91 92 // Provide a type function to map parameters that won't observe extra copies 93 // or moves and which are small enough to likely pass in register to values 94 // and all other types to l-value reference types. We use this to compute the 95 // types used in our erased call utility to minimize copies and moves unless 96 // doing so would force things unnecessarily into memory. 97 // 98 // The heuristic used is related to common ABI register passing conventions. 99 // It doesn't have to be exact though, and in one way it is more strict 100 // because we want to still be able to observe either moves *or* copies. 101 template <typename T> struct AdjustedParamTBase { 102 static_assert(!std::is_reference<T>::value, 103 "references should be handled by template specialization"); 104 using type = 105 std::conditional_t<std::is_trivially_copy_constructible<T>::value && 106 std::is_trivially_move_constructible<T>::value && 107 IsSizeLessThanThresholdT<T>::value, 108 T, T &>; 109 }; 110 111 // This specialization ensures that 'AdjustedParam<V<T>&>' or 112 // 'AdjustedParam<V<T>&&>' does not trigger a compile-time error when 'T' is 113 // an incomplete type and V a templated type. 114 template <typename T> struct AdjustedParamTBase<T &> { using type = T &; }; 115 template <typename T> struct AdjustedParamTBase<T &&> { using type = T &; }; 116 117 template <typename T> 118 using AdjustedParamT = typename AdjustedParamTBase<T>::type; 119 120 // The type of the erased function pointer we use as a callback to dispatch to 121 // the stored callable when it is trivial to move and destroy. 122 using CallPtrT = ReturnT (*)(void *CallableAddr, 123 AdjustedParamT<ParamTs>... Params); 124 using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr); 125 using DestroyPtrT = void (*)(void *CallableAddr); 126 127 /// A struct to hold a single trivial callback with sufficient alignment for 128 /// our bitpacking. 129 struct alignas(8) TrivialCallback { 130 CallPtrT CallPtr; 131 }; 132 133 /// A struct we use to aggregate three callbacks when we need full set of 134 /// operations. 135 struct alignas(8) NonTrivialCallbacks { 136 CallPtrT CallPtr; 137 MovePtrT MovePtr; 138 DestroyPtrT DestroyPtr; 139 }; 140 141 // Create a pointer union between either a pointer to a static trivial call 142 // pointer in a struct or a pointer to a static struct of the call, move, and 143 // destroy pointers. 144 using CallbackPointerUnionT = 145 PointerUnion<TrivialCallback *, NonTrivialCallbacks *>; 146 147 // The main storage buffer. This will either have a pointer to out-of-line 148 // storage or an inline buffer storing the callable. 149 union StorageUnionT { 150 // For out-of-line storage we keep a pointer to the underlying storage and 151 // the size. This is enough to deallocate the memory. 152 struct OutOfLineStorageT { 153 void *StoragePtr; 154 size_t Size; 155 size_t Alignment; 156 } OutOfLineStorage; 157 static_assert( 158 sizeof(OutOfLineStorageT) <= InlineStorageSize, 159 "Should always use all of the out-of-line storage for inline storage!"); 160 161 // For in-line storage, we just provide an aligned character buffer. We 162 // provide three pointers worth of storage here. 163 // This is mutable as an inlined `const unique_function<void() const>` may 164 // still modify its own mutable members. 165 alignas(InlineStorageAlign) mutable std::byte 166 InlineStorage[InlineStorageSize]; 167 } StorageUnion; 168 169 // A compressed pointer to either our dispatching callback or our table of 170 // dispatching callbacks and the flag for whether the callable itself is 171 // stored inline or not. 172 PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag; 173 174 bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); } 175 176 bool isTrivialCallback() const { 177 return isa<TrivialCallback *>(CallbackAndInlineFlag.getPointer()); 178 } 179 180 CallPtrT getTrivialCallback() const { 181 return cast<TrivialCallback *>(CallbackAndInlineFlag.getPointer())->CallPtr; 182 } 183 184 NonTrivialCallbacks *getNonTrivialCallbacks() const { 185 return cast<NonTrivialCallbacks *>(CallbackAndInlineFlag.getPointer()); 186 } 187 188 CallPtrT getCallPtr() const { 189 return isTrivialCallback() ? getTrivialCallback() 190 : getNonTrivialCallbacks()->CallPtr; 191 } 192 193 // These three functions are only const in the narrow sense. They return 194 // mutable pointers to function state. 195 // This allows unique_function<T const>::operator() to be const, even if the 196 // underlying functor may be internally mutable. 197 // 198 // const callers must ensure they're only used in const-correct ways. 199 void *getCalleePtr() const { 200 return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage(); 201 } 202 void *getInlineStorage() const { return &StorageUnion.InlineStorage; } 203 void *getOutOfLineStorage() const { 204 return StorageUnion.OutOfLineStorage.StoragePtr; 205 } 206 207 size_t getOutOfLineStorageSize() const { 208 return StorageUnion.OutOfLineStorage.Size; 209 } 210 size_t getOutOfLineStorageAlignment() const { 211 return StorageUnion.OutOfLineStorage.Alignment; 212 } 213 214 void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) { 215 StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment}; 216 } 217 218 template <typename CalledAsT> 219 static ReturnT CallImpl(void *CallableAddr, 220 AdjustedParamT<ParamTs>... Params) { 221 auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr); 222 return Func(std::forward<ParamTs>(Params)...); 223 } 224 225 template <typename CallableT> 226 static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept { 227 new (LHSCallableAddr) 228 CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr))); 229 } 230 231 template <typename CallableT> 232 static void DestroyImpl(void *CallableAddr) noexcept { 233 reinterpret_cast<CallableT *>(CallableAddr)->~CallableT(); 234 } 235 236 // The pointers to call/move/destroy functions are determined for each 237 // callable type (and called-as type, which determines the overload chosen). 238 // (definitions are out-of-line). 239 240 // By default, we need an object that contains all the different 241 // type erased behaviors needed. Create a static instance of the struct type 242 // here and each instance will contain a pointer to it. 243 // Wrap in a struct to avoid https://gcc.gnu.org/PR71954 244 template <typename CallableT, typename CalledAs, typename Enable = void> 245 struct CallbacksHolder { 246 static NonTrivialCallbacks Callbacks; 247 }; 248 // See if we can create a trivial callback. We need the callable to be 249 // trivially moved and trivially destroyed so that we don't have to store 250 // type erased callbacks for those operations. 251 template <typename CallableT, typename CalledAs> 252 struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> { 253 static TrivialCallback Callbacks; 254 }; 255 256 // A simple tag type so the call-as type to be passed to the constructor. 257 template <typename T> struct CalledAs {}; 258 259 // Essentially the "main" unique_function constructor, but subclasses 260 // provide the qualified type to be used for the call. 261 // (We always store a T, even if the call will use a pointer to const T). 262 template <typename CallableT, typename CalledAsT> 263 UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) { 264 bool IsInlineStorage = true; 265 void *CallableAddr = getInlineStorage(); 266 if (sizeof(CallableT) > InlineStorageSize || 267 alignof(CallableT) > InlineStorageAlign) { 268 IsInlineStorage = false; 269 // Allocate out-of-line storage. FIXME: Use an explicit alignment 270 // parameter in C++17 mode. 271 auto Size = sizeof(CallableT); 272 auto Alignment = alignof(CallableT); 273 CallableAddr = allocate_buffer(Size, Alignment); 274 setOutOfLineStorage(CallableAddr, Size, Alignment); 275 } 276 277 // Now move into the storage. 278 new (CallableAddr) CallableT(std::move(Callable)); 279 CallbackAndInlineFlag.setPointerAndInt( 280 &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage); 281 } 282 283 ~UniqueFunctionBase() { 284 if (!CallbackAndInlineFlag.getPointer()) 285 return; 286 287 // Cache this value so we don't re-check it after type-erased operations. 288 bool IsInlineStorage = isInlineStorage(); 289 290 if (!isTrivialCallback()) 291 getNonTrivialCallbacks()->DestroyPtr( 292 IsInlineStorage ? getInlineStorage() : getOutOfLineStorage()); 293 294 if (!IsInlineStorage) 295 deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(), 296 getOutOfLineStorageAlignment()); 297 } 298 299 UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept { 300 // Copy the callback and inline flag. 301 CallbackAndInlineFlag = RHS.CallbackAndInlineFlag; 302 303 // If the RHS is empty, just copying the above is sufficient. 304 if (!RHS) 305 return; 306 307 if (!isInlineStorage()) { 308 // The out-of-line case is easiest to move. 309 StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage; 310 } else if (isTrivialCallback()) { 311 // Move is trivial, just memcpy the bytes across. 312 memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize); 313 } else { 314 // Non-trivial move, so dispatch to a type-erased implementation. 315 getNonTrivialCallbacks()->MovePtr(getInlineStorage(), 316 RHS.getInlineStorage()); 317 getNonTrivialCallbacks()->DestroyPtr(RHS.getInlineStorage()); 318 } 319 320 // Clear the old callback and inline flag to get back to as-if-null. 321 RHS.CallbackAndInlineFlag = {}; 322 323 #if !defined(NDEBUG) && !LLVM_ADDRESS_SANITIZER_BUILD 324 // In debug builds without ASan, we also scribble across the rest of the 325 // storage. Scribbling under AddressSanitizer (ASan) is disabled to prevent 326 // overwriting poisoned objects (e.g., annotated short strings). 327 memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize); 328 #endif 329 } 330 331 UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept { 332 if (this == &RHS) 333 return *this; 334 335 // Because we don't try to provide any exception safety guarantees we can 336 // implement move assignment very simply by first destroying the current 337 // object and then move-constructing over top of it. 338 this->~UniqueFunctionBase(); 339 new (this) UniqueFunctionBase(std::move(RHS)); 340 return *this; 341 } 342 343 UniqueFunctionBase() = default; 344 345 public: 346 explicit operator bool() const { 347 return (bool)CallbackAndInlineFlag.getPointer(); 348 } 349 }; 350 351 template <typename R, typename... P> 352 template <typename CallableT, typename CalledAsT, typename Enable> 353 typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase< 354 R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = { 355 &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>}; 356 357 template <typename R, typename... P> 358 template <typename CallableT, typename CalledAsT> 359 typename UniqueFunctionBase<R, P...>::TrivialCallback 360 UniqueFunctionBase<R, P...>::CallbacksHolder< 361 CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{ 362 &CallImpl<CalledAsT>}; 363 364 } // namespace detail 365 366 template <typename R, typename... P> 367 class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> { 368 using Base = detail::UniqueFunctionBase<R, P...>; 369 370 public: 371 unique_function() = default; 372 unique_function(std::nullptr_t) {} 373 unique_function(unique_function &&) = default; 374 unique_function(const unique_function &) = delete; 375 unique_function &operator=(unique_function &&) = default; 376 unique_function &operator=(const unique_function &) = delete; 377 378 template <typename CallableT> 379 unique_function( 380 CallableT Callable, 381 detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr, 382 detail::EnableIfCallable<CallableT, R, P...> * = nullptr) 383 : Base(std::forward<CallableT>(Callable), 384 typename Base::template CalledAs<CallableT>{}) {} 385 386 R operator()(P... Params) { 387 return this->getCallPtr()(this->getCalleePtr(), Params...); 388 } 389 }; 390 391 template <typename R, typename... P> 392 class unique_function<R(P...) const> 393 : public detail::UniqueFunctionBase<R, P...> { 394 using Base = detail::UniqueFunctionBase<R, P...>; 395 396 public: 397 unique_function() = default; 398 unique_function(std::nullptr_t) {} 399 unique_function(unique_function &&) = default; 400 unique_function(const unique_function &) = delete; 401 unique_function &operator=(unique_function &&) = default; 402 unique_function &operator=(const unique_function &) = delete; 403 404 template <typename CallableT> 405 unique_function( 406 CallableT Callable, 407 detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr, 408 detail::EnableIfCallable<const CallableT, R, P...> * = nullptr) 409 : Base(std::forward<CallableT>(Callable), 410 typename Base::template CalledAs<const CallableT>{}) {} 411 412 R operator()(P... Params) const { 413 return this->getCallPtr()(this->getCalleePtr(), Params...); 414 } 415 }; 416 417 } // end namespace llvm 418 419 #endif // LLVM_ADT_FUNCTIONEXTRAS_H 420