xref: /llvm-project/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp (revision 8d5c1b4562f880a61c9d9a2bddad73f584cdf311)
18a10ee75SMichele Scuttari //===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
28a10ee75SMichele Scuttari //
38a10ee75SMichele Scuttari // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
48a10ee75SMichele Scuttari // See https://llvm.org/LICENSE.txt for license information.
58a10ee75SMichele Scuttari // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68a10ee75SMichele Scuttari //
78a10ee75SMichele Scuttari //===----------------------------------------------------------------------===//
88a10ee75SMichele Scuttari 
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/SmallVector.h"

#include "gmock/gmock.h"

#include <array>
#include <cstdint>
#include <numeric>
138a10ee75SMichele Scuttari 
148a10ee75SMichele Scuttari using namespace ::mlir;
158a10ee75SMichele Scuttari using namespace ::testing;
168a10ee75SMichele Scuttari 
// A zero-ranked memref wraps exactly one element. Iterating it must yield that
// single value, and dereferencing the memref itself must read the same storage.
TEST(DynamicMemRef, rankZero) {
  int data = 57;

  StridedMemRefType<int, 0> memRef;
  memRef.basePtr = &data;
  memRef.data = &data;
  memRef.offset = 0;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end());
  EXPECT_THAT(values, ElementsAre(57));

  // Element access on a zero-ranked memref goes through operator*.
  EXPECT_EQ(*dynamicMemRef, 57);
}
308a10ee75SMichele Scuttari 
// A contiguous 1-D memref: both iteration and per-index subscripting must visit
// the elements in storage order.
TEST(DynamicMemRef, rankOne) {
  std::array<int, 3> data;
  // Fill with 0, 1, 2, ... so every element's value equals its index.
  std::iota(data.begin(), data.end(), 0);

  StridedMemRefType<int, 1> memRef;
  memRef.basePtr = data.data();
  memRef.data = data.data();
  memRef.offset = 0;
  // Derive the extent from the backing storage instead of repeating it.
  memRef.sizes[0] = static_cast<int64_t>(data.size());
  memRef.strides[0] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
  EXPECT_THAT(values, ElementsAreArray(data));

  // Subscript by index, then dereference the result to read the element.
  for (int64_t i = 0; i < memRef.sizes[0]; ++i) {
    EXPECT_EQ(*dynamicMemRef[i], data[i]);
  }
}
548a10ee75SMichele Scuttari 
// A row-major 2-D memref: iteration must follow storage order, and chained
// subscripts m[i][j] must address element i * numCols + j.
TEST(DynamicMemRef, rankTwo) {
  constexpr int64_t numRows = 2;
  constexpr int64_t numCols = 3;
  constexpr int64_t numElements = numRows * numCols;

  std::array<int, numElements> data;
  // Fill with 0, 1, 2, ... so every element's value equals its linear index.
  std::iota(data.begin(), data.end(), 0);

  StridedMemRefType<int, 2> memRef;
  memRef.basePtr = data.data();
  memRef.data = data.data();
  memRef.offset = 0;
  memRef.sizes[0] = numRows;
  memRef.sizes[1] = numCols;
  // Row-major layout: consecutive rows are numCols elements apart.
  memRef.strides[0] = numCols;
  memRef.strides[1] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  llvm::SmallVector<int, numElements> values(dynamicMemRef.begin(),
                                             dynamicMemRef.end());
  EXPECT_THAT(values, ElementsAreArray(data));

  // Each subscript drops one dimension; dereference the final rank-0 view.
  for (int64_t i = 0; i < numRows; ++i) {
    for (int64_t j = 0; j < numCols; ++j) {
      EXPECT_EQ(*dynamicMemRef[i][j], data[i * numCols + j]);
    }
  }
}
768a10ee75SMichele Scuttari 
// A row-major 3-D memref: iteration must follow storage order, and chained
// subscripts m[i][j][k] must address element (i * dim1 + j) * dim2 + k.
TEST(DynamicMemRef, rankThree) {
  constexpr int64_t dim0 = 2;
  constexpr int64_t dim1 = 3;
  constexpr int64_t dim2 = 4;
  constexpr int64_t numElements = dim0 * dim1 * dim2;

  std::array<int, numElements> data;
  // Fill with 0, 1, 2, ... so every element's value equals its linear index.
  std::iota(data.begin(), data.end(), 0);

  StridedMemRefType<int, 3> memRef;
  memRef.basePtr = data.data();
  memRef.data = data.data();
  memRef.offset = 0;
  memRef.sizes[0] = dim0;
  memRef.sizes[1] = dim1;
  memRef.sizes[2] = dim2;
  // Row-major strides: each stride is the product of all inner extents.
  memRef.strides[0] = dim1 * dim2;
  memRef.strides[1] = dim2;
  memRef.strides[2] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  llvm::SmallVector<int, numElements> values(dynamicMemRef.begin(),
                                             dynamicMemRef.end());
  EXPECT_THAT(values, ElementsAreArray(data));

  // Each subscript drops one dimension; dereference the final rank-0 view.
  for (int64_t i = 0; i < dim0; ++i) {
    for (int64_t j = 0; j < dim1; ++j) {
      for (int64_t k = 0; k < dim2; ++k) {
        EXPECT_EQ(*dynamicMemRef[i][j][k], data[(i * dim1 + j) * dim2 + k]);
      }
    }
  }
}
100*8d5c1b45SFelix Schneider 
// A 1-D memref whose logical elements start `offset` slots into the buffer
// pointed to by `data`. Both iteration and subscripting must honour the offset,
// i.e. visit buffer[offset], ..., buffer[offset + size - 1].
TEST(DynamicMemRef, rankOneWithOffset) {
  constexpr int offset = 4;
  constexpr int size = 3;
  std::array<int, size + offset> buffer;
  // Fill with 0, 1, 2, ... so every slot's value equals its buffer index.
  std::iota(buffer.begin(), buffer.end(), 0);

  StridedMemRefType<int, 1> memRef;
  memRef.basePtr = buffer.data();
  memRef.data = buffer.data();
  memRef.offset = offset;
  memRef.sizes[0] = size;
  memRef.strides[0] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  llvm::SmallVector<int, size> values(dynamicMemRef.begin(),
                                      dynamicMemRef.end());
  // Match against the whole logical range [offset, offset + size). The
  // range matcher also checks the element count, so an iteration that yields
  // too few elements fails cleanly instead of indexing past the end of
  // `values`.
  EXPECT_THAT(values, ElementsAreArray(buffer.begin() + offset, buffer.end()));

  // Subscript by index, then dereference the result to read the element.
  for (int64_t i = 0; i < size; ++i) {
    EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]);
  }
}
125