xref: /llvm-project/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h (revision 412d784188257f6b8a3748ac9a800002db861181)
1 //===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares basic classes and functions to manipulate structured MLIR
10 // types at runtime. Entities in this file must be compliant with C++11 and be
11 // retargetable, including on targets without a C++ runtime.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
16 #define MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
17 
// Export/visibility configuration. MLIR_CRUNNERUTILS_EXPORT is the decoration
// applied to the extern "C" declarations in this header.
// MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS is only defined here, never tested in
// this header; NOTE(review): presumably it gates definitions in the
// implementation file -- confirm.
18 #ifdef _WIN32
// A definition of MLIR_CRUNNERUTILS_EXPORT supplied from outside wins.
19 #ifndef MLIR_CRUNNERUTILS_EXPORT
20 #ifdef mlir_c_runner_utils_EXPORTS
21 // We are building this library
22 #define MLIR_CRUNNERUTILS_EXPORT __declspec(dllexport)
23 #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
24 #else
25 // We are using this library
26 #define MLIR_CRUNNERUTILS_EXPORT __declspec(dllimport)
27 #endif // mlir_c_runner_utils_EXPORTS
28 #endif // MLIR_CRUNNERUTILS_EXPORT
29 #else  // _WIN32
30 // Non-windows: use visibility attributes.
31 #define MLIR_CRUNNERUTILS_EXPORT __attribute__((visibility("default")))
32 #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
33 #endif // _WIN32
34 
35 #include <array>
36 #include <cassert>
37 #include <cstdint>
38 #include <initializer_list>
39 #include <vector>
40 
41 //===----------------------------------------------------------------------===//
42 // Codegen-compatible structures for Vector type.
43 //===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {

/// Returns true iff `n` is a positive power of two.
///
/// The `n > 0` guard matters: the bit trick `n & (n - 1)` on its own reports
/// 0 as a power of two, and `n - 1` overflows for the most negative int.
constexpr bool isPowerOf2(int n) { return n > 0 && !(n & (n - 1)); }

/// Returns the smallest power of two that is >= `n`. Any `n <= 1` yields 1.
/// Written as a single return expression so it stays a valid C++11 constexpr
/// function.
constexpr unsigned nextPowerOf2(int n) {
  return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2)));
}

/// 1-D vector storage. `IsPowerOf2` selects between the unpadded and the
/// padded layout, depending on whether sizeof(T[Dim]) is a power of two.
template <typename T, int Dim, bool IsPowerOf2>
struct Vector1D;

/// 1-D vector whose byte size is already a power of two: no padding needed.
template <typename T, int Dim>
struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
  Vector1D() {
    // Sanity-check that this specialization was selected for the right size.
    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]),
                  "size error");
  }
  /// Unchecked element access.
  inline T &operator[](unsigned i) { return vector[i]; }
  inline const T &operator[](unsigned i) const { return vector[i]; }

private:
  T vector[Dim];
};

// 1-D vector, padded to the next power of 2 allocation.
// Specialization occurs to avoid zero size arrays (which fail in -Werror).
template <typename T, int Dim>
struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
  Vector1D() {
    // The padded size must be strictly larger than the payload, but less than
    // twice the payload.
    static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
    static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
                  "size error");
  }
  /// Unchecked element access; the padding bytes are never exposed.
  inline T &operator[](unsigned i) { return vector[i]; }
  inline const T &operator[](unsigned i) const { return vector[i]; }

private:
  T vector[Dim];
  // Fills the gap between the payload and the next power-of-two size.
  char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
};
} // namespace detail
} // namespace mlir
87 
88 // N-D vectors recurse down to 1-D.
89 template <typename T, int Dim, int... Dims>
90 struct Vector {
91   inline Vector<T, Dims...> &operator[](unsigned i) { return vector[i]; }
92   inline const Vector<T, Dims...> &operator[](unsigned i) const {
93     return vector[i];
94   }
95 
96 private:
97   Vector<T, Dims...> vector[Dim];
98 };
99 
100 // 1-D vectors in LLVM are automatically padded to the next power of 2.
101 // We insert explicit padding in to account for this.
102 template <typename T, int Dim>
103 struct Vector<T, Dim>
104     : public mlir::detail::Vector1D<T, Dim,
105                                     mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
106 };
107 
108 template <int D1, typename T>
109 using Vector1D = Vector<T, D1>;
110 template <int D1, int D2, typename T>
111 using Vector2D = Vector<T, D1, D2>;
112 template <int D1, int D2, int D3, typename T>
113 using Vector3D = Vector<T, D1, D2, D3>;
114 template <int D1, int D2, int D3, int D4, typename T>
115 using Vector4D = Vector<T, D1, D2, D3, D4>;
116 
/// Copies `arr[1] .. arr[N-1]` into `res[0] .. res[N-2]`, i.e. everything
/// except the first element. `res` must have room for N - 1 entries; nothing
/// is written when N <= 1.
template <int N>
void dropFront(int64_t arr[N], int64_t *res) {
  for (int src = 1; src < N; ++src)
    res[src - 1] = arr[src];
}
122 
123 //===----------------------------------------------------------------------===//
124 // Codegen-compatible structures for StridedMemRef type.
125 //===----------------------------------------------------------------------===//
126 template <typename T, int Rank>
127 class StridedMemrefIterator;
128 
129 /// StridedMemRef descriptor type with static rank.
130 template <typename T, int N>
131 struct StridedMemRefType {
132   T *basePtr;
133   T *data;
134   int64_t offset;
135   int64_t sizes[N];
136   int64_t strides[N];
137 
138   template <typename Range,
139             typename sfinae = decltype(std::declval<Range>().begin())>
140   T &operator[](Range &&indices) {
141     assert(indices.size() == N &&
142            "indices should match rank in memref subscript");
143     int64_t curOffset = offset;
144     for (int dim = N - 1; dim >= 0; --dim) {
145       int64_t currentIndex = *(indices.begin() + dim);
146       assert(currentIndex < sizes[dim] && "Index overflow");
147       curOffset += currentIndex * strides[dim];
148     }
149     return data[curOffset];
150   }
151 
152   StridedMemrefIterator<T, N> begin() { return {*this, offset}; }
153   StridedMemrefIterator<T, N> end() { return {*this, -1}; }
154 
155   // This operator[] is extremely slow and only for sugaring purposes.
156   StridedMemRefType<T, N - 1> operator[](int64_t idx) {
157     StridedMemRefType<T, N - 1> res;
158     res.basePtr = basePtr;
159     res.data = data;
160     res.offset = offset + idx * strides[0];
161     dropFront<N>(sizes, res.sizes);
162     dropFront<N>(strides, res.strides);
163     return res;
164   }
165 };
166 
167 /// StridedMemRef descriptor type specialized for rank 1.
168 template <typename T>
169 struct StridedMemRefType<T, 1> {
170   T *basePtr;
171   T *data;
172   int64_t offset;
173   int64_t sizes[1];
174   int64_t strides[1];
175 
176   template <typename Range,
177             typename sfinae = decltype(std::declval<Range>().begin())>
178   T &operator[](Range indices) {
179     assert(indices.size() == 1 &&
180            "indices should match rank in memref subscript");
181     return (*this)[*indices.begin()];
182   }
183 
184   StridedMemrefIterator<T, 1> begin() { return {*this, offset}; }
185   StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
186 
187   T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
188 };
189 
190 /// StridedMemRef descriptor type specialized for rank 0.
191 template <typename T>
192 struct StridedMemRefType<T, 0> {
193   T *basePtr;
194   T *data;
195   int64_t offset;
196 
197   template <typename Range,
198             typename sfinae = decltype(std::declval<Range>().begin())>
199   T &operator[](Range indices) {
200     assert((indices.size() == 0) &&
201            "Expect empty indices for 0-rank memref subscript");
202     return data[offset];
203   }
204 
205   StridedMemrefIterator<T, 0> begin() { return {*this, offset}; }
206   StridedMemrefIterator<T, 0> end() { return {*this, offset + 1}; }
207 };
208 
209 /// Iterate over all elements in a strided memref.
210 template <typename T, int Rank>
211 class StridedMemrefIterator {
212 public:
213   using iterator_category = std::forward_iterator_tag;
214   using value_type = T;
215   using difference_type = std::ptrdiff_t;
216   using pointer = T *;
217   using reference = T &;
218 
219   StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
220                         int64_t offset = 0)
221       : offset(offset), descriptor(&descriptor) {}
222   StridedMemrefIterator<T, Rank> &operator++() {
223     int dim = Rank - 1;
224     while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
225       offset -= indices[dim] * descriptor->strides[dim];
226       indices[dim] = 0;
227       --dim;
228     }
229     if (dim < 0) {
230       offset = -1;
231       return *this;
232     }
233     ++indices[dim];
234     offset += descriptor->strides[dim];
235     return *this;
236   }
237 
238   reference operator*() { return descriptor->data[offset]; }
239   pointer operator->() { return &descriptor->data[offset]; }
240 
241   const std::array<int64_t, Rank> &getIndices() { return indices; }
242 
243   bool operator==(const StridedMemrefIterator &other) const {
244     return other.offset == offset && other.descriptor == descriptor;
245   }
246 
247   bool operator!=(const StridedMemrefIterator &other) const {
248     return !(*this == other);
249   }
250 
251 private:
252   /// Offset in the buffer. This can be derived from the indices and the
253   /// descriptor.
254   int64_t offset = 0;
255 
256   /// Array of indices in the multi-dimensional memref.
257   std::array<int64_t, Rank> indices = {};
258 
259   /// Descriptor for the strided memref.
260   StridedMemRefType<T, Rank> *descriptor;
261 };
262 
263 /// Iterate over all elements in a 0-ranked strided memref.
264 template <typename T>
265 class StridedMemrefIterator<T, 0> {
266 public:
267   using iterator_category = std::forward_iterator_tag;
268   using value_type = T;
269   using difference_type = std::ptrdiff_t;
270   using pointer = T *;
271   using reference = T &;
272 
273   StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
274       : elt(descriptor.data + offset) {}
275 
276   StridedMemrefIterator<T, 0> &operator++() {
277     ++elt;
278     return *this;
279   }
280 
281   reference operator*() { return *elt; }
282   pointer operator->() { return elt; }
283 
284   // There are no indices for a 0-ranked memref, but this API is provided for
285   // consistency with the general case.
286   const std::array<int64_t, 0> &getIndices() {
287     // Since this is a 0-array of indices we can keep a single global const
288     // copy.
289     static const std::array<int64_t, 0> indices = {};
290     return indices;
291   }
292 
293   bool operator==(const StridedMemrefIterator &other) const {
294     return other.elt == elt;
295   }
296 
297   bool operator!=(const StridedMemrefIterator &other) const {
298     return !(*this == other);
299   }
300 
301 private:
302   /// Pointer to the single element in the zero-ranked memref.
303   T *elt;
304 };
305 
306 //===----------------------------------------------------------------------===//
307 // Codegen-compatible structure for UnrankedMemRef type.
308 //===----------------------------------------------------------------------===//
309 // Unranked MemRef
/// Rank-erased memref descriptor: a runtime `rank` plus an untyped pointer to
/// the underlying descriptor. `T` does not appear in the members; it only
/// tags the element type. The DynamicMemRefType constructor in this header
/// reads `descriptor` as basePtr, data, offset, followed by `rank` sizes and
/// then `rank` strides.
310 template <typename T>
311 struct UnrankedMemRefType {
312   int64_t rank;
313   void *descriptor;
314 };
315 
316 //===----------------------------------------------------------------------===//
317 // DynamicMemRefType type.
318 //===----------------------------------------------------------------------===//
319 template <typename T>
320 class DynamicMemRefIterator;
321 
322 // A reference to one of the StridedMemRef types.
323 template <typename T>
324 class DynamicMemRefType {
325 public:
326   int64_t rank;
327   T *basePtr;
328   T *data;
329   int64_t offset;
330   const int64_t *sizes;
331   const int64_t *strides;
332 
333   explicit DynamicMemRefType(const StridedMemRefType<T, 0> &memRef)
334       : rank(0), basePtr(memRef.basePtr), data(memRef.data),
335         offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
336   template <int N>
337   explicit DynamicMemRefType(const StridedMemRefType<T, N> &memRef)
338       : rank(N), basePtr(memRef.basePtr), data(memRef.data),
339         offset(memRef.offset), sizes(memRef.sizes), strides(memRef.strides) {}
340   explicit DynamicMemRefType(const ::UnrankedMemRefType<T> &memRef)
341       : rank(memRef.rank) {
342     auto *desc = static_cast<StridedMemRefType<T, 1> *>(memRef.descriptor);
343     basePtr = desc->basePtr;
344     data = desc->data;
345     offset = desc->offset;
346     sizes = rank == 0 ? nullptr : desc->sizes;
347     strides = sizes + rank;
348   }
349 
350   template <typename Range,
351             typename sfinae = decltype(std::declval<Range>().begin())>
352   T &operator[](Range &&indices) {
353     assert(indices.size() == rank &&
354            "indices should match rank in memref subscript");
355     if (rank == 0)
356       return data[offset];
357 
358     int64_t curOffset = offset;
359     for (int dim = rank - 1; dim >= 0; --dim) {
360       int64_t currentIndex = *(indices.begin() + dim);
361       assert(currentIndex < sizes[dim] && "Index overflow");
362       curOffset += currentIndex * strides[dim];
363     }
364     return data[curOffset];
365   }
366 
367   DynamicMemRefIterator<T> begin() { return {*this, offset}; }
368   DynamicMemRefIterator<T> end() { return {*this, -1}; }
369 
370   // This operator[] is extremely slow and only for sugaring purposes.
371   DynamicMemRefType<T> operator[](int64_t idx) {
372     assert(rank > 0 && "can't make a subscript of a zero ranked array");
373 
374     DynamicMemRefType<T> res(*this);
375     --res.rank;
376     res.offset += idx * res.strides[0];
377     ++res.sizes;
378     ++res.strides;
379     return res;
380   }
381 
382   // This operator* can be used in conjunction with the previous operator[] in
383   // order to access the underlying value in case of zero-ranked memref.
384   T &operator*() {
385     assert(rank == 0 && "not a zero-ranked memRef");
386     return data[offset];
387   }
388 };
389 
390 /// Iterate over all elements in a dynamic memref.
391 template <typename T>
392 class DynamicMemRefIterator {
393 public:
394   using iterator_category = std::forward_iterator_tag;
395   using value_type = T;
396   using difference_type = std::ptrdiff_t;
397   using pointer = T *;
398   using reference = T &;
399 
400   DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
401       : offset(offset), descriptor(&descriptor) {
402     indices.resize(descriptor.rank, 0);
403   }
404 
405   DynamicMemRefIterator<T> &operator++() {
406     if (descriptor->rank == 0) {
407       offset = -1;
408       return *this;
409     }
410 
411     int dim = descriptor->rank - 1;
412 
413     while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
414       offset -= indices[dim] * descriptor->strides[dim];
415       indices[dim] = 0;
416       --dim;
417     }
418 
419     if (dim < 0) {
420       offset = -1;
421       return *this;
422     }
423 
424     ++indices[dim];
425     offset += descriptor->strides[dim];
426     return *this;
427   }
428 
429   reference operator*() { return descriptor->data[offset]; }
430   pointer operator->() { return &descriptor->data[offset]; }
431 
432   const std::vector<int64_t> &getIndices() { return indices; }
433 
434   bool operator==(const DynamicMemRefIterator &other) const {
435     return other.offset == offset && other.descriptor == descriptor;
436   }
437 
438   bool operator!=(const DynamicMemRefIterator &other) const {
439     return !(*this == other);
440   }
441 
442 private:
443   /// Offset in the buffer. This can be derived from the indices and the
444   /// descriptor.
445   int64_t offset = 0;
446 
447   /// Array of indices in the multi-dimensional memref.
448   std::vector<int64_t> indices = {};
449 
450   /// Descriptor for the dynamic memref.
451   DynamicMemRefType<T> *descriptor;
452 };
453 
454 //===----------------------------------------------------------------------===//
455 // Small runtime support library for memref.copy lowering during codegen.
456 //===----------------------------------------------------------------------===//
/// C-linkage entry point taking an element size and pointers to a source and
/// a destination unranked memref descriptor. Only the declaration lives in
/// this header.
/// NOTE(review): `elemSize` is presumably in bytes and `src` is presumably
/// copied into `dst` -- confirm against the implementation.
457 extern "C" MLIR_CRUNNERUTILS_EXPORT void
458 memrefCopy(int64_t elemSize, ::UnrankedMemRefType<char> *src,
459            ::UnrankedMemRefType<char> *dst);
460 
461 //===----------------------------------------------------------------------===//
462 // Small runtime support library for vector.print lowering during codegen.
463 //===----------------------------------------------------------------------===//
// One helper per scalar type plus helpers for the surrounding punctuation.
// All have C linkage and are exported; the definitions are not in this header.
// NOTE(review): the output destination (presumably stdout) and the exact text
// emitted by printOpen/printClose/printComma -- confirm in the implementation.
464 extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
465 extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
466 extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
467 extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
468 extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s);
469 extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
470 extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
471 extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
472 extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
473 
474 //===----------------------------------------------------------------------===//
475 // Small runtime support library for timing execution and printing GFLOPS
476 //===----------------------------------------------------------------------===//
// Declarations only; the definitions are not in this header.
// NOTE(review): rtclock presumably returns a timestamp in seconds and
// printFlops presumably prints its argument scaled to GFLOPS -- confirm.
477 extern "C" MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops);
478 extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock();
479 
480 //===----------------------------------------------------------------------===//
481 // Runtime support library for random number generation.
482 //===----------------------------------------------------------------------===//
483 // Uses a seed to initialize a random generator and returns the generator.
// The generator is handed around as an opaque void* handle: rtsrand returns
// it, the other entry points below take it as `g`.
484 extern "C" MLIR_CRUNNERUTILS_EXPORT void *rtsrand(uint64_t s);
485 // Uses a random number generator g and returns a random number
486 // in the range of [0, m).
487 extern "C" MLIR_CRUNNERUTILS_EXPORT uint64_t rtrand(void *g, uint64_t m);
488 // Deletes the random number generator.
489 extern "C" MLIR_CRUNNERUTILS_EXPORT void rtdrand(void *g);
490 // Uses a random number generator g and std::shuffle to modify mref
491 // in place. Memref mref will be a permutation of all numbers
492 // in the range of [0, size of mref).
493 extern "C" MLIR_CRUNNERUTILS_EXPORT void
494 _mlir_ciface_shuffle(StridedMemRefType<uint64_t, 1> *mref, void *g);
495 
496 //===----------------------------------------------------------------------===//
497 // Runtime support library to allow the use of std::sort in MLIR program.
498 //===----------------------------------------------------------------------===//
// One entry point per element type (int64_t, double, float), each taking a
// count `n` and a rank-1 memref descriptor. Declarations only.
// NOTE(review): presumably the first `n` elements of `vref` are sorted in
// place -- confirm against the implementation.
499 extern "C" MLIR_CRUNNERUTILS_EXPORT void
500 _mlir_ciface_stdSortI64(uint64_t n, StridedMemRefType<int64_t, 1> *vref);
501 extern "C" MLIR_CRUNNERUTILS_EXPORT void
502 _mlir_ciface_stdSortF64(uint64_t n, StridedMemRefType<double, 1> *vref);
503 extern "C" MLIR_CRUNNERUTILS_EXPORT void
504 _mlir_ciface_stdSortF32(uint64_t n, StridedMemRefType<float, 1> *vref);
505 #endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
506