//===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utils for MLIR ABI interfacing with frameworks.
//
// The templated free functions below make it possible to allocate dense
// contiguous buffers with shapes that interoperate properly with the MLIR
// codegen ABI.
//
//===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/CRunnerUtils.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 #include "llvm/Support/raw_ostream.h"
23 
24 #include <algorithm>
25 #include <array>
26 #include <cassert>
27 #include <climits>
28 #include <functional>
29 #include <initializer_list>
30 #include <memory>
31 #include <optional>
32 
33 #ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
34 #define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
35 
namespace mlir {
/// Allocation callback: takes a size in bytes and returns a pointer to the
/// allocated buffer.
using AllocFunType = llvm::function_ref<void *(size_t)>;

namespace detail {

/// Given a shape with sizes greater than 0 along all dimensions, returns the
/// distance, in number of elements, between a slice in a dimension and the
/// next slice in the same dimension.
///    e.g. shape[3, 4, 5] -> strides[20, 5, 1]
template <size_t N>
inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) {
  assert(shape.size() == N && "expect shape specification to match rank");
  std::array<int64_t, N> res;
  int64_t running = 1;
  for (int64_t idx = N - 1; idx >= 0; --idx) {
    assert(shape[idx] > 0 && "size must be positive for all shape dimensions");
    res[idx] = running;
    running *= shape[idx];
  }
  return res;
}
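
// Example (illustrative sketch): for the shape in the doc comment above, the
// innermost dimension is contiguous and each outer stride is the product of
// the inner sizes.
//
//   std::array<int64_t, 3> strides = detail::makeStrides<3>({3, 4, 5});
//   // strides == {4 * 5, 5, 1} == {20, 5, 1}.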

/// Build a `StridedMemRefType<T, N>` descriptor that matches the MLIR ABI.
/// This is an implementation detail that is kept in sync with MLIR codegen
/// conventions. Additionally takes a `shapeAlloc` array, which is used
/// instead of `shape` to allocate "more aligned" data and to compute the
/// corresponding strides.
template <int N, typename T>
typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type
makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape,
                            ArrayRef<int64_t> shapeAlloc) {
  assert(shape.size() == N);
  assert(shapeAlloc.size() == N);
  StridedMemRefType<T, N> descriptor;
  descriptor.basePtr = static_cast<T *>(ptr);
  descriptor.data = static_cast<T *>(alignedPtr);
  descriptor.offset = 0;
  std::copy(shape.begin(), shape.end(), descriptor.sizes);
  auto strides = makeStrides<N>(shapeAlloc);
  std::copy(strides.begin(), strides.end(), descriptor.strides);
  return descriptor;
}
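
// Example (illustrative sketch): a 3x4 logical shape stored in a 3x8
// allocation yields sizes from `shape` but strides from `shapeAlloc`, so each
// logical row is padded out to 8 elements.
//
//   // Assuming `base` and `aligned` come from detail::allocAligned<float>.
//   auto desc = detail::makeStridedMemRefDescriptor<2>(
//       base, aligned, /*shape=*/{3, 4}, /*shapeAlloc=*/{3, 8});
//   // desc.sizes == {3, 4}, desc.strides == {8, 1}, desc.offset == 0.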

/// Build a `StridedMemRefType<T, 0>` descriptor that matches the MLIR ABI.
/// This is an implementation detail that is kept in sync with MLIR codegen
/// conventions. A rank-0 descriptor has no sizes or strides, so `shape` and
/// `shapeAlloc` must be empty.
template <int N, typename T>
typename std::enable_if<(N == 0), StridedMemRefType<T, 0>>::type
makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape = {},
                            ArrayRef<int64_t> shapeAlloc = {}) {
  assert(shape.size() == N);
  assert(shapeAlloc.size() == N);
  StridedMemRefType<T, 0> descriptor;
  descriptor.basePtr = static_cast<T *>(ptr);
  descriptor.data = static_cast<T *>(alignedPtr);
  descriptor.offset = 0;
  return descriptor;
}
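
// Example (illustrative sketch): a rank-0 descriptor wraps a single element;
// only the pointers and the zero offset are populated.
//
//   float value = 42.0f;
//   auto scalar = detail::makeStridedMemRefDescriptor<0>(&value, &value);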

/// Allocate a buffer for `nElements` of type T with an optional `alignment`.
/// This is a portable replacement for `posix_memalign`: the buffer is
/// over-allocated by `alignment` bytes, and both the base pointer (which must
/// be passed to the deallocation function) and the first suitably aligned
/// address within the buffer are returned.
/// `alignment` must be a power of 2 and at least sizeof(T). By default the
/// alignment is the next power of 2 greater than or equal to sizeof(T).
template <typename T>
std::pair<T *, T *>
allocAligned(size_t nElements, AllocFunType allocFun = &::malloc,
             std::optional<uint64_t> alignment = std::optional<uint64_t>()) {
  assert(sizeof(T) <= UINT_MAX && "Elemental type overflows");
  auto size = nElements * sizeof(T);
  auto desiredAlignment = alignment.value_or(nextPowerOf2(sizeof(T)));
  assert((desiredAlignment & (desiredAlignment - 1)) == 0);
  assert(desiredAlignment >= sizeof(T));
  T *data = reinterpret_cast<T *>(allocFun(size + desiredAlignment));
  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
  uintptr_t rem = addr % desiredAlignment;
  T *alignedData = (rem == 0)
                       ? data
                       : reinterpret_cast<T *>(addr + (desiredAlignment - rem));
  assert(reinterpret_cast<uintptr_t>(alignedData) % desiredAlignment == 0);
  return std::make_pair(data, alignedData);
}
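
// Example (illustrative sketch): request a 64-byte-aligned buffer of 1024
// floats. `base` must be passed to the matching deallocation function;
// `aligned` is where the elements actually live.
//
//   auto [base, aligned] = detail::allocAligned<float>(
//       /*nElements=*/1024, /*allocFun=*/&::malloc, /*alignment=*/64);
//   // ... use `aligned` ...
//   ::free(base);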

} // namespace detail

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

/// Convenient callback to "visit" a memref element by element.
/// This takes a reference to an individual element as well as the coordinates.
/// It can be used in conjunction with a StridedMemrefIterator.
template <typename T>
using ElementWiseVisitor = llvm::function_ref<void(T &ptr, ArrayRef<int64_t>)>;
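
// Example (illustrative sketch): a visitor that fills a rank-2 memref with
// the row-major linear index of each element, assuming a logical width of 5.
//
//   ElementWiseVisitor<float> init = [](float &elt, ArrayRef<int64_t> idx) {
//     elt = static_cast<float>(idx[0] * 5 + idx[1]);
//   };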

/// Owning MemRef type that abstracts over the runtime type for ranked strided
/// memref.
template <typename T, unsigned Rank>
class OwningMemRef {
public:
  using DescriptorType = StridedMemRefType<T, Rank>;
  using FreeFunType = std::function<void(DescriptorType)>;

  /// Allocate a new dense strided memref with a given `shape`. An optional
  /// `shapeAlloc` array can be supplied to "pad" every dimension individually.
  /// If an ElementWiseVisitor is provided, it is used to initialize the data;
  /// otherwise the memory is zero-initialized. The alloc and free functions
  /// used to manage the data allocation can optionally be provided; they
  /// default to malloc/free.
  OwningMemRef(
      ArrayRef<int64_t> shape, ArrayRef<int64_t> shapeAlloc = {},
      ElementWiseVisitor<T> init = {},
      std::optional<uint64_t> alignment = std::optional<uint64_t>(),
      AllocFunType allocFun = &::malloc,
      std::function<void(StridedMemRefType<T, Rank>)> freeFun =
          [](StridedMemRefType<T, Rank> descriptor) {
            // Free the base pointer returned by the allocator; `data` may be
            // shifted forward from it to satisfy the alignment.
            ::free(descriptor.basePtr);
          })
      : freeFunc(freeFun) {
    if (shapeAlloc.empty())
      shapeAlloc = shape;
    assert(shape.size() == Rank);
    assert(shapeAlloc.size() == Rank);
    for (unsigned i = 0; i < Rank; ++i)
      assert(shape[i] <= shapeAlloc[i] &&
             "shapeAlloc must be greater than or equal to shape");
    int64_t nElements = 1;
    for (int64_t s : shapeAlloc)
      nElements *= s;
    auto [data, alignedData] =
        detail::allocAligned<T>(nElements, allocFun, alignment);
    descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
                                                           shape, shapeAlloc);
    if (init) {
      for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
                                          end = descriptor.end();
           it != end; ++it)
        init(*it, it.getIndices());
    } else {
      // Only zero the `nElements` elements that start at the aligned pointer;
      // zeroing the full over-allocated size from here would write past the
      // end of the allocation.
      memset(descriptor.data, 0, nElements * sizeof(T));
    }
  }
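
  // Example (illustrative sketch): allocate a 3x5 float memref whose rows are
  // padded to 8 elements each; the visitor fills every element with its
  // row-major linear index within the logical 3x5 shape.
  //
  //   OwningMemRef<float, 2> padded(
  //       /*shape=*/{3, 5}, /*shapeAlloc=*/{3, 8},
  //       [](float &elt, ArrayRef<int64_t> idx) { elt = idx[0] * 5 + idx[1]; });
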
  /// Take ownership of an existing descriptor with a custom deleter.
  OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc)
      : freeFunc(freeFunc), descriptor(descriptor) {}
  ~OwningMemRef() {
    if (freeFunc)
      freeFunc(descriptor);
  }
  OwningMemRef(const OwningMemRef &) = delete;
  OwningMemRef &operator=(const OwningMemRef &) = delete;
  OwningMemRef &operator=(OwningMemRef &&other) {
    freeFunc = other.freeFunc;
    descriptor = other.descriptor;
    other.freeFunc = nullptr;
    memset(&other.descriptor, 0, sizeof(other.descriptor));
    return *this;
  }
  OwningMemRef(OwningMemRef &&other) { *this = std::move(other); }

  DescriptorType &operator*() { return descriptor; }
  DescriptorType *operator->() { return &descriptor; }
  T &operator[](std::initializer_list<int64_t> indices) {
    return descriptor[indices];
  }

private:
  /// Custom deleter used to release the data buffer managed together with the
  /// descriptor below.
  FreeFunType freeFunc;
  /// The descriptor is an instance of StridedMemRefType<T, Rank>.
  DescriptorType descriptor;
};
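
// Example usage (illustrative sketch): allocate a zero-initialized 4x8
// memref, write an element through operator[], and expose the descriptor in
// the form JIT-compiled MLIR code expects.
//
//   OwningMemRef<float, 2> memRef({4, 8});
//   memRef[{0, 0}] = 42.0f;
//   StridedMemRefType<float, 2> &desc = *memRef; // ABI-compatible view.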

} // namespace mlir

#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_