xref: /llvm-project/mlir/include/mlir/ExecutionEngine/RunnerUtils.h (revision 753dc0a01ccc3cbe87d5ee0fe0ec7f8db340966f)
1 //===- RunnerUtils.h - Utils for debugging MLIR execution -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares basic classes and functions to debug structured MLIR
10 // types at runtime. Entities in this file may not be compatible with targets
11 // without a C++ runtime. These may be progressively migrated to CRunnerUtils.h
12 // over time.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_EXECUTIONENGINE_RUNNERUTILS_H
17 #define MLIR_EXECUTIONENGINE_RUNNERUTILS_H
18 
19 #ifdef _WIN32
20 #ifndef MLIR_RUNNERUTILS_EXPORT
21 #ifdef mlir_runner_utils_EXPORTS
22 // We are building this library
23 #define MLIR_RUNNERUTILS_EXPORT __declspec(dllexport)
24 #else
25 // We are using this library
26 #define MLIR_RUNNERUTILS_EXPORT __declspec(dllimport)
27 #endif // mlir_runner_utils_EXPORTS
28 #endif // MLIR_RUNNERUTILS_EXPORT
29 #else
30 // Non-windows: use visibility attributes.
31 #define MLIR_RUNNERUTILS_EXPORT __attribute__((visibility("default")))
32 #endif // _WIN32
33 
34 #include <assert.h>
35 #include <cmath>
36 #include <complex>
37 #include <iomanip>
38 #include <iostream>
39 
40 #include "mlir/ExecutionEngine/CRunnerUtils.h"
41 #include "mlir/ExecutionEngine/Float16bits.h"
42 
43 template <typename T, typename StreamType>
printMemRefMetaData(StreamType & os,const DynamicMemRefType<T> & v)44 void printMemRefMetaData(StreamType &os, const DynamicMemRefType<T> &v) {
45   // Make the printed pointer format platform independent by casting it to an
46   // integer and manually formatting it to a hex with prefix as tests expect.
47   os << "base@ = " << std::hex << std::showbase
48      << reinterpret_cast<std::intptr_t>(v.data) << std::dec << std::noshowbase
49      << " rank = " << v.rank << " offset = " << v.offset;
50   auto print = [&](const int64_t *ptr) {
51     if (v.rank == 0)
52       return;
53     os << ptr[0];
54     for (int64_t i = 1; i < v.rank; ++i)
55       os << ", " << ptr[i];
56   };
57   os << " sizes = [";
58   print(v.sizes);
59   os << "] strides = [";
60   print(v.strides);
61   os << "]";
62 }
63 
64 template <typename StreamType, typename T, int N>
printMemRefMetaData(StreamType & os,StridedMemRefType<T,N> & v)65 void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &v) {
66   static_assert(N >= 0, "Expected N > 0");
67   os << "MemRef ";
68   printMemRefMetaData(os, DynamicMemRefType<T>(v));
69 }
70 
71 template <typename StreamType, typename T>
printUnrankedMemRefMetaData(StreamType & os,UnrankedMemRefType<T> & v)72 void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &v) {
73   os << "Unranked MemRef ";
74   printMemRefMetaData(os, DynamicMemRefType<T>(v));
75 }
76 
77 ////////////////////////////////////////////////////////////////////////////////
78 // Templated instantiation follows.
79 ////////////////////////////////////////////////////////////////////////////////
80 namespace impl {
// Unsigned 64-bit element type; used by the `...Ind` entry points declared
// later in this header.
using index_type = uint64_t;
// Complex number with double components; used by the `...C64` entry points.
using complex64 = std::complex<double>;
// Complex number with float components; used by the `...C32` entry points.
using complex32 = std::complex<float>;

// Forward declaration of the stream operator for `Vector`; the definition
// follows VectorDataPrinter below.
// NOTE(review): presumably declared early so that VectorDataPrinter::print can
// stream elements that are themselves Vectors — confirm against the Vector
// definition in CRunnerUtils.h.
template <typename T, int M, int... Dims>
std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v);
87 
/// Compile-time product of a pack of integer extents. This primary template
/// handles the empty pack, whose product is 1.
template <int... Extents>
struct StaticSizeMult {
  static constexpr int value = 1;
};

/// Recursive case: the product of the remaining extents multiplied by the
/// leading one.
template <int Head, int... Tail>
struct StaticSizeMult<Head, Tail...> {
  static constexpr int value = StaticSizeMult<Tail...>::value * Head;
};
97 
/// Writes `count` space characters to `os`; a zero or negative count writes
/// nothing.
static inline void printSpace(std::ostream &os, int count) {
  int remaining = count;
  while (remaining > 0) {
    os << ' ';
    --remaining;
  }
}
103 
104 template <typename T, int M, int... Dims>
105 struct VectorDataPrinter {
106   static void print(std::ostream &os, const Vector<T, M, Dims...> &val);
107 };
108 
109 template <typename T, int M, int... Dims>
110 void VectorDataPrinter<T, M, Dims...>::print(std::ostream &os,
111                                              const Vector<T, M, Dims...> &val) {
112   static_assert(M > 0, "0 dimensioned tensor");
113   static_assert(sizeof(val) == M * StaticSizeMult<Dims...>::value * sizeof(T),
114                 "Incorrect vector size!");
115   // First
116   os << "(" << val[0];
117   if (M > 1)
118     os << ", ";
119   if (sizeof...(Dims) > 1)
120     os << "\n";
121   // Kernel
122   for (unsigned i = 1; i + 1 < M; ++i) {
123     printSpace(os, 2 * sizeof...(Dims));
124     os << val[i] << ", ";
125     if (sizeof...(Dims) > 1)
126       os << "\n";
127   }
128   // Last
129   if (M > 1) {
130     printSpace(os, sizeof...(Dims));
131     os << val[M - 1];
132   }
133   os << ")";
134 }
135 
136 template <typename T, int M, int... Dims>
137 std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
138   VectorDataPrinter<T, M, Dims...>::print(os, v);
139   return os;
140 }
141 
/// Recursive printer for the elements of a strided memref.
///
/// The three functions share the same parameters:
///  - `os`: destination stream.
///  - `base`: pointer that element offsets are applied to.
///  - `dim`: number of dimensions still to traverse; with 0, `print` writes
///    the single element `base[offset]`.
///  - `rank`: rank of the memref being printed; `rank - dim + 1` spaces are
///    used as indentation in front of non-leading entries.
///  - `offset`: linear element offset accumulated so far.
///  - `sizes` / `strides`: extent and stride arrays, advanced by one entry per
///    recursion level so that index 0 always describes the current dimension.
template <typename T>
struct MemRefDataPrinter {
  /// Prints the sub-array rooted at `offset`: first entry, interior entries,
  /// then the last entry, recursing into the next dimension for each.
  static void print(std::ostream &os, T *base, int64_t dim, int64_t rank,
                    int64_t offset, const int64_t *sizes,
                    const int64_t *strides);
  /// Prints "[" and the first entry of the current dimension, followed by
  /// "]" when the extent is at most 1 or by a separator otherwise.
  static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank,
                         int64_t offset, const int64_t *sizes,
                         const int64_t *strides);
  /// Prints the indented last entry of the current dimension followed by "]".
  static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank,
                        int64_t offset, const int64_t *sizes,
                        const int64_t *strides);
};
154 
155 template <typename T>
156 void MemRefDataPrinter<T>::printFirst(std::ostream &os, T *base, int64_t dim,
157                                       int64_t rank, int64_t offset,
158                                       const int64_t *sizes,
159                                       const int64_t *strides) {
160   os << "[";
161   print(os, base, dim - 1, rank, offset, sizes + 1, strides + 1);
162   // If single element, close square bracket and return early.
163   if (sizes[0] <= 1) {
164     os << "]";
165     return;
166   }
167   os << ", ";
168   if (dim > 1)
169     os << "\n";
170 }
171 
172 template <typename T>
173 void MemRefDataPrinter<T>::print(std::ostream &os, T *base, int64_t dim,
174                                  int64_t rank, int64_t offset,
175                                  const int64_t *sizes, const int64_t *strides) {
176   if (dim == 0) {
177     os << base[offset];
178     return;
179   }
180   printFirst(os, base, dim, rank, offset, sizes, strides);
181   for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
182     printSpace(os, rank - dim + 1);
183     print(os, base, dim - 1, rank, offset + i * strides[0], sizes + 1,
184           strides + 1);
185     os << ", ";
186     if (dim > 1)
187       os << "\n";
188   }
189   if (sizes[0] <= 1)
190     return;
191   printLast(os, base, dim, rank, offset, sizes, strides);
192 }
193 
194 template <typename T>
195 void MemRefDataPrinter<T>::printLast(std::ostream &os, T *base, int64_t dim,
196                                      int64_t rank, int64_t offset,
197                                      const int64_t *sizes,
198                                      const int64_t *strides) {
199   printSpace(os, rank - dim + 1);
200   print(os, base, dim - 1, rank, offset + (sizes[0] - 1) * (*strides),
201         sizes + 1, strides + 1);
202   os << "]";
203 }
204 
205 template <typename T, int N>
206 void printMemRefShape(StridedMemRefType<T, N> &m) {
207   std::cout << "Memref ";
208   printMemRefMetaData(std::cout, DynamicMemRefType<T>(m));
209 }
210 
211 template <typename T>
212 void printMemRefShape(UnrankedMemRefType<T> &m) {
213   std::cout << "Unranked Memref ";
214   printMemRefMetaData(std::cout, DynamicMemRefType<T>(m));
215 }
216 
217 template <typename T>
218 void printMemRef(const DynamicMemRefType<T> &m) {
219   printMemRefMetaData(std::cout, m);
220   std::cout << " data = \n";
221   if (m.rank == 0)
222     std::cout << "[";
223   MemRefDataPrinter<T>::print(std::cout, m.data, m.rank, m.rank, m.offset,
224                               m.sizes, m.strides);
225   if (m.rank == 0)
226     std::cout << "]";
227   std::cout << '\n' << std::flush;
228 }
229 
230 template <typename T, int N>
231 void printMemRef(StridedMemRefType<T, N> &m) {
232   std::cout << "Memref ";
233   printMemRef(DynamicMemRefType<T>(m));
234 }
235 
236 template <typename T>
237 void printMemRef(UnrankedMemRefType<T> &m) {
238   std::cout << "Unranked Memref ";
239   printMemRef(DynamicMemRefType<T>(m));
240 }
241 
/// Verify that the results of two computations are equivalent up to a small
/// numerical error and return the number of errors.
template <typename T>
struct MemRefDataVerifier {
  /// Maximum number of errors printed by the verifier. Mismatches beyond this
  /// limit are still counted, just not printed.
  static constexpr int printLimit = 10;

  /// Verify the relative difference of the values does not exceed epsilon,
  /// i.e. |actual - expected| <= epsilon * |expected|. Returns false when
  /// either value is infinite or NaN.
  static bool verifyRelErrorSmallerThan(T actual, T expected, T epsilon);

  /// Verify the values are equal. The float and double specializations
  /// instead accept values that are close in relative terms.
  static bool verifyElem(T actual, T expected);

  /// Verify the data element-by-element and return the number of errors.
  /// `dim` is the number of dimensions still to traverse, `offset` the linear
  /// element offset reached so far, and `sizes` / `strides` describe the
  /// current and remaining dimensions. Both base pointers are indexed with
  /// the same offsets. Up to `printLimit` mismatches are written to `os`;
  /// `printCounter` tracks how many have been written across the recursion.
  static int64_t verify(std::ostream &os, T *actualBasePtr, T *expectedBasePtr,
                        int64_t dim, int64_t offset, const int64_t *sizes,
                        const int64_t *strides, int64_t &printCounter);
};
260 
261 template <typename T>
262 bool MemRefDataVerifier<T>::verifyRelErrorSmallerThan(T actual, T expected,
263                                                       T epsilon) {
264   // Return an error if one of the values is infinite or NaN.
265   if (!std::isfinite(actual) || !std::isfinite(expected))
266     return false;
267   // Return true if the relative error is smaller than epsilon.
268   T delta = std::abs(actual - expected);
269   return (delta <= epsilon * std::abs(expected));
270 }
271 
272 template <typename T>
273 bool MemRefDataVerifier<T>::verifyElem(T actual, T expected) {
274   return actual == expected;
275 }
276 
277 template <>
278 inline bool MemRefDataVerifier<double>::verifyElem(double actual,
279                                                    double expected) {
280   return verifyRelErrorSmallerThan(actual, expected, 1e-12);
281 }
282 
283 template <>
284 inline bool MemRefDataVerifier<float>::verifyElem(float actual,
285                                                   float expected) {
286   return verifyRelErrorSmallerThan(actual, expected, 1e-6f);
287 }
288 
289 template <typename T>
290 int64_t MemRefDataVerifier<T>::verify(std::ostream &os, T *actualBasePtr,
291                                       T *expectedBasePtr, int64_t dim,
292                                       int64_t offset, const int64_t *sizes,
293                                       const int64_t *strides,
294                                       int64_t &printCounter) {
295   int64_t errors = 0;
296   // Verify the elements at the current offset.
297   if (dim == 0) {
298     if (!verifyElem(actualBasePtr[offset], expectedBasePtr[offset])) {
299       if (printCounter < printLimit) {
300         os << actualBasePtr[offset] << " != " << expectedBasePtr[offset]
301            << " offset = " << offset << "\n";
302         printCounter++;
303       }
304       errors++;
305     }
306   } else {
307     // Iterate the current dimension and verify recursively.
308     for (int64_t i = 0; i < sizes[0]; ++i) {
309       errors +=
310           verify(os, actualBasePtr, expectedBasePtr, dim - 1,
311                  offset + i * strides[0], sizes + 1, strides + 1, printCounter);
312     }
313   }
314   return errors;
315 }
316 
317 /// Verify the equivalence of two dynamic memrefs and return the number of
318 /// errors or -1 if the shape of the memrefs do not match.
319 template <typename T>
320 int64_t verifyMemRef(const DynamicMemRefType<T> &actual,
321                      const DynamicMemRefType<T> &expected) {
322   // Check if the memref shapes match.
323   for (int64_t i = 0; i < actual.rank; ++i) {
324     if (expected.rank != actual.rank || actual.offset != expected.offset ||
325         actual.sizes[i] != expected.sizes[i] ||
326         actual.strides[i] != expected.strides[i]) {
327       printMemRefMetaData(std::cerr, actual);
328       printMemRefMetaData(std::cerr, expected);
329       return -1;
330     }
331   }
332   // Return the number of errors.
333   int64_t printCounter = 0;
334   return MemRefDataVerifier<T>::verify(std::cerr, actual.data, expected.data,
335                                        actual.rank, actual.offset, actual.sizes,
336                                        actual.strides, printCounter);
337 }
338 
339 /// Verify the equivalence of two unranked memrefs and return the number of
340 /// errors or -1 if the shape of the memrefs do not match.
341 template <typename T>
342 int64_t verifyMemRef(UnrankedMemRefType<T> &actual,
343                      UnrankedMemRefType<T> &expected) {
344   return verifyMemRef(DynamicMemRefType<T>(actual),
345                       DynamicMemRefType<T>(expected));
346 }
347 
348 } // namespace impl
349 
350 ////////////////////////////////////////////////////////////////////////////////
351 // Currently exposed C API.
352 ////////////////////////////////////////////////////////////////////////////////
// Exported C-linkage entry points named `printMemrefShape<Elt>`, one per
// element type. Each takes a pointer to an unranked memref descriptor; only
// the declarations are visible in this header.
// NOTE(review): presumably these print the descriptor metadata without the
// element data (cf. impl::printMemRefShape above) — confirm in the
// implementation file.
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeI8(UnrankedMemRefType<int8_t> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeI32(UnrankedMemRefType<int32_t> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeI64(UnrankedMemRefType<int64_t> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeF32(UnrankedMemRefType<float> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeF64(UnrankedMemRefType<double> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeInd(UnrankedMemRefType<impl::index_type> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeC32(UnrankedMemRefType<impl::complex32> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefShapeC64(UnrankedMemRefType<impl::complex64> *m);
369 
// Exported C-linkage entry points named `printMemref<Elt>`, one per element
// type, each taking a pointer to an unranked memref descriptor.
// NOTE(review): presumably these print metadata and element data (cf.
// impl::printMemRef above) — confirm in the implementation file.
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefI8(UnrankedMemRefType<int8_t> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefI16(UnrankedMemRefType<int16_t> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefI32(UnrankedMemRefType<int32_t> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefI64(UnrankedMemRefType<int64_t> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefF16(UnrankedMemRefType<f16> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefF32(UnrankedMemRefType<float> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefF64(UnrankedMemRefType<double> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefInd(UnrankedMemRefType<impl::index_type> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefC32(UnrankedMemRefType<impl::complex32> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefC64(UnrankedMemRefType<impl::complex64> *m);
392 
// Exported C-linkage function taking no arguments and returning a 64-bit
// integer.
// NOTE(review): judging by the name this is a timestamp in nanoseconds; the
// clock source and epoch are not visible here — confirm in the implementation.
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_nanoTime();
394 
// Exported C-linkage `printMemref<Elt>` entry points that take a rank and an
// untyped pointer instead of a descriptor struct.
// NOTE(review): presumably `(rank, ptr)` are the two fields of an unranked
// memref descriptor passed separately — confirm against UnrankedMemRefType.
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefI32(int64_t rank, void *ptr);
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefI64(int64_t rank, void *ptr);
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefF32(int64_t rank, void *ptr);
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefF64(int64_t rank, void *ptr);
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefInd(int64_t rank, void *ptr);
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefC32(int64_t rank, void *ptr);
extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefC64(int64_t rank, void *ptr);
402 
// Exported C-linkage print entry points for ranked float memrefs, one per
// static rank from 0 to 4; each takes a pointer to the matching
// StridedMemRefType descriptor.
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dF32(StridedMemRefType<float, 1> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref2dF32(StridedMemRefType<float, 2> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref3dF32(StridedMemRefType<float, 3> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref4dF32(StridedMemRefType<float, 4> *m);
413 
// Exported C-linkage print entry points for rank-1 memrefs, one per element
// type; each takes a pointer to a rank-1 StridedMemRefType descriptor.
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dI8(StridedMemRefType<int8_t, 1> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dI32(StridedMemRefType<int32_t, 1> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dI64(StridedMemRefType<int64_t, 1> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dF64(StridedMemRefType<double, 1> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dInd(StridedMemRefType<impl::index_type, 1> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dC32(StridedMemRefType<impl::complex32, 1> *m);
extern "C" MLIR_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemref1dC64(StridedMemRefType<impl::complex64, 1> *m);
428 
// Exported C-linkage print entry point for a rank-2 memref whose elements
// are 4x4 float vectors (`Vector2D<4, 4, float>`).
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefVector4x4xf32(
    StridedMemRefType<Vector2D<4, 4, float>, 2> *m);
431 
// Exported C-linkage entry points named `verifyMemRef<Elt>`, one per element
// type. Each takes pointers to two unranked memref descriptors (`actual`,
// `expected`) and returns a 64-bit integer.
// NOTE(review): presumably the return value follows impl::verifyMemRef above
// (number of mismatching elements, or -1 on a shape mismatch) — confirm in
// the implementation file.
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI8(
    UnrankedMemRefType<int8_t> *actual, UnrankedMemRefType<int8_t> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI16(
    UnrankedMemRefType<int16_t> *actual, UnrankedMemRefType<int16_t> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI32(
    UnrankedMemRefType<int32_t> *actual, UnrankedMemRefType<int32_t> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI64(
    UnrankedMemRefType<int64_t> *actual, UnrankedMemRefType<int64_t> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefBF16(
    UnrankedMemRefType<bf16> *actual, UnrankedMemRefType<bf16> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF16(
    UnrankedMemRefType<f16> *actual, UnrankedMemRefType<f16> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF32(
    UnrankedMemRefType<float> *actual, UnrankedMemRefType<float> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF64(
    UnrankedMemRefType<double> *actual, UnrankedMemRefType<double> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
_mlir_ciface_verifyMemRefInd(UnrankedMemRefType<impl::index_type> *actual,
                             UnrankedMemRefType<impl::index_type> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
_mlir_ciface_verifyMemRefC32(UnrankedMemRefType<impl::complex32> *actual,
                             UnrankedMemRefType<impl::complex32> *expected);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
_mlir_ciface_verifyMemRefC64(UnrankedMemRefType<impl::complex64> *actual,
                             UnrankedMemRefType<impl::complex64> *expected);
457 
// Exported C-linkage `verifyMemRef<Elt>` entry points that take one rank and
// two untyped pointers instead of descriptor structs, and return a 64-bit
// integer.
// NOTE(review): presumably `rank` applies to both memrefs and each pointer is
// the descriptor pointer of an unranked memref — confirm in the
// implementation file.
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefI32(int64_t rank,
                                                           void *actualPtr,
                                                           void *expectedPtr);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF32(int64_t rank,
                                                           void *actualPtr,
                                                           void *expectedPtr);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF64(int64_t rank,
                                                           void *actualPtr,
                                                           void *expectedPtr);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefInd(int64_t rank,
                                                           void *actualPtr,
                                                           void *expectedPtr);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC32(int64_t rank,
                                                           void *actualPtr,
                                                           void *expectedPtr);
extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC64(int64_t rank,
                                                           void *actualPtr,
                                                           void *expectedPtr);
476 
477 #endif // MLIR_EXECUTIONENGINE_RUNNERUTILS_H
478