xref: /llvm-project/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp (revision 8d5c1b4562f880a61c9d9a2bddad73f584cdf311)
//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/ExecutionEngine/CRunnerUtils.h"
10 #include "llvm/ADT/SmallVector.h"
11 
12 #include "gmock/gmock.h"
13 
14 using namespace ::mlir;
15 using namespace ::testing;
16 
// Wraps a single int in a rank-0 strided memref and checks that iterating the
// dynamic view built from it yields exactly that one value.
TEST(DynamicMemRef, rankZero) {
  int scalar = 57;

  // Both pointers refer to the same scalar; no offset is applied.
  StridedMemRefType<int, 0> source;
  source.offset = 0;
  source.data = &scalar;
  source.basePtr = &scalar;

  DynamicMemRefType<int> view(source);

  // Materialize the iteration range so it can be matched as a container.
  llvm::SmallVector<int, 1> collected(view.begin(), view.end());
  EXPECT_THAT(collected, ElementsAre(57));
}
30 
// Builds a rank-1 memref over three ints (size 3, stride 1) and checks both
// iteration over the dynamic view and element access through operator[].
TEST(DynamicMemRef, rankOne) {
  // Fill the backing storage with 0, 1, 2.
  std::array<int, 3> storage;
  int next = 0;
  for (int &slot : storage)
    slot = next++;

  StridedMemRefType<int, 1> source;
  source.offset = 0;
  source.data = storage.data();
  source.basePtr = storage.data();
  source.strides[0] = 1;
  source.sizes[0] = 3;

  DynamicMemRefType<int> view(source);

  // Iterating the view must reproduce the storage contents in order.
  llvm::SmallVector<int, 3> collected(view.begin(), view.end());
  EXPECT_THAT(collected, ElementsAreArray(storage));

  // Indexing the view and dereferencing the result must give the same values.
  for (int64_t pos = 0; pos != 3; ++pos) {
    EXPECT_EQ(*view[pos], storage[pos]);
  }
}
54 
// Builds a rank-2 memref with sizes {2, 3} and strides {3, 1} over six ints
// and checks that iterating the dynamic view visits them in storage order.
TEST(DynamicMemRef, rankTwo) {
  // Fill the backing storage with 0..5.
  std::array<int, 6> storage;
  int next = 0;
  for (int &slot : storage)
    slot = next++;

  StridedMemRefType<int, 2> source;
  source.offset = 0;
  source.data = storage.data();
  source.basePtr = storage.data();
  source.strides[0] = 3;
  source.strides[1] = 1;
  source.sizes[0] = 2;
  source.sizes[1] = 3;

  DynamicMemRefType<int> view(source);

  llvm::SmallVector<int, 6> collected(view.begin(), view.end());
  EXPECT_THAT(collected, ElementsAreArray(storage));
}
76 
// Builds a rank-3 memref with sizes {2, 3, 4} and strides {12, 4, 1} over 24
// ints and checks that iterating the dynamic view visits them in storage
// order.
TEST(DynamicMemRef, rankThree) {
  // Fill the backing storage with 0..23.
  std::array<int, 24> storage;
  int next = 0;
  for (int &slot : storage)
    slot = next++;

  StridedMemRefType<int, 3> source;
  source.offset = 0;
  source.data = storage.data();
  source.basePtr = storage.data();
  source.strides[0] = 12;
  source.strides[1] = 4;
  source.strides[2] = 1;
  source.sizes[0] = 2;
  source.sizes[1] = 3;
  source.sizes[2] = 4;

  DynamicMemRefType<int> view(source);

  llvm::SmallVector<int, 24> collected(view.begin(), view.end());
  EXPECT_THAT(collected, ElementsAreArray(storage));
}
100 
// Builds a rank-1 memref of size 3 whose first element sits `offset` entries
// into the backing buffer, then checks that both iteration and operator[] on
// the dynamic view read from the offset position rather than from the start
// of the buffer.
TEST(DynamicMemRef, rankOneWithOffset) {
  constexpr int offset = 4;
  std::array<int, 3 + offset> buffer;

  // Fill the buffer with 0..6 so every slot holds its own index.
  for (size_t i = 0; i < buffer.size(); ++i) {
    buffer[i] = i;
  }

  StridedMemRefType<int, 1> memRef;
  memRef.basePtr = buffer.data();
  memRef.data = buffer.data();
  memRef.offset = offset;
  memRef.sizes[0] = 3;
  memRef.strides[0] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());

  // Abort the test if iteration produced the wrong number of elements: the
  // loop below indexes `values` directly and would otherwise read out of
  // bounds instead of reporting a clean failure.
  ASSERT_EQ(values.size(), 3u);

  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(values[i], buffer[offset + i]);
    EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]);
  }
}
125