//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
88a10ee75SMichele Scuttari
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/SmallVector.h"

#include "gmock/gmock.h"

#include <array>
#include <numeric>

using namespace ::mlir;
using namespace ::testing;
168a10ee75SMichele Scuttari
TEST(DynamicMemRef, rankZero) {
  // A rank-0 memref describes a single scalar; point both pointers at it.
  int scalar = 57;

  StridedMemRefType<int, 0> descriptor;
  descriptor.data = &scalar;
  descriptor.basePtr = &scalar;
  descriptor.offset = 0;

  DynamicMemRefType<int> dynRef(descriptor);

  // Iterating the dynamic view must visit exactly that one element.
  llvm::SmallVector<int, 1> visited(dynRef.begin(), dynRef.end());
  EXPECT_THAT(visited, ElementsAre(57));
}
308a10ee75SMichele Scuttari
TEST(DynamicMemRef, rankOne) {
  constexpr int64_t extent = 3;

  // Backing storage holds 0, 1, 2 so every element is distinguishable.
  std::array<int, extent> data;
  for (size_t idx = 0; idx != data.size(); ++idx) {
    data[idx] = idx;
  }

  // Describe the whole array as a contiguous rank-1 memref.
  StridedMemRefType<int, 1> descriptor;
  descriptor.basePtr = data.data();
  descriptor.data = data.data();
  descriptor.offset = 0;
  descriptor.sizes[0] = extent;
  descriptor.strides[0] = 1;

  DynamicMemRefType<int> dynRef(descriptor);

  // Iteration visits the elements in storage order...
  llvm::SmallVector<int, 3> visited(dynRef.begin(), dynRef.end());
  EXPECT_THAT(visited, ElementsAreArray(data));

  // ...and subscripting down to rank 0 reaches the same elements.
  for (int64_t idx = 0; idx != extent; ++idx) {
    EXPECT_EQ(*dynRef[idx], data[idx]);
  }
}
548a10ee75SMichele Scuttari
TEST(DynamicMemRef, rankTwo) {
  // Row-major 2x3 buffer holding 0..5, so element (i, j) is 3 * i + j.
  std::array<int, 6> data;
  std::iota(data.begin(), data.end(), 0);

  StridedMemRefType<int, 2> memRef;
  memRef.basePtr = data.data();
  memRef.data = data.data();
  memRef.offset = 0;
  memRef.sizes[0] = 2;
  memRef.sizes[1] = 3;
  memRef.strides[0] = 3;
  memRef.strides[1] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  // Iteration must visit the elements in row-major order.
  llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end());
  EXPECT_THAT(values, ElementsAreArray(data));

  // Each subscript peels off one dimension; once every dimension has been
  // indexed the rank-0 result can be dereferenced. Check every element, as
  // the rank-1 test does, so subscripting is covered for rank > 1 too.
  for (int64_t i = 0; i < memRef.sizes[0]; ++i) {
    for (int64_t j = 0; j < memRef.sizes[1]; ++j) {
      EXPECT_EQ(*dynamicMemRef[i][j],
                data[i * memRef.strides[0] + j * memRef.strides[1]]);
    }
  }
}
768a10ee75SMichele Scuttari
TEST(DynamicMemRef, rankThree) {
  // Row-major 2x3x4 buffer holding 0..23, so element (i, j, k) is
  // 12 * i + 4 * j + k.
  std::array<int, 24> data;
  std::iota(data.begin(), data.end(), 0);

  StridedMemRefType<int, 3> memRef;
  memRef.basePtr = data.data();
  memRef.data = data.data();
  memRef.offset = 0;
  memRef.sizes[0] = 2;
  memRef.sizes[1] = 3;
  memRef.sizes[2] = 4;
  memRef.strides[0] = 12;
  memRef.strides[1] = 4;
  memRef.strides[2] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  // Iteration must visit the elements in row-major order.
  llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
  EXPECT_THAT(values, ElementsAreArray(data));

  // Chained subscripts reduce the rank one dimension at a time; after three
  // of them the rank-0 result can be dereferenced. Check every element.
  for (int64_t i = 0; i < memRef.sizes[0]; ++i) {
    for (int64_t j = 0; j < memRef.sizes[1]; ++j) {
      for (int64_t k = 0; k < memRef.sizes[2]; ++k) {
        EXPECT_EQ(*dynamicMemRef[i][j][k],
                  data[i * memRef.strides[0] + j * memRef.strides[1] +
                       k * memRef.strides[2]]);
      }
    }
  }
}
100*8d5c1b45SFelix Schneider
TEST(DynamicMemRef, rankOneWithOffset) {
  // The memref views the last three elements of a larger buffer. Both
  // pointers stay at the start of the buffer; the view is shifted only
  // through the `offset` field, which is what this test exercises.
  constexpr int offset = 4;
  std::array<int, 3 + offset> buffer;
  std::iota(buffer.begin(), buffer.end(), 0);

  StridedMemRefType<int, 1> memRef;
  memRef.basePtr = buffer.data();
  memRef.data = buffer.data();
  memRef.offset = offset;
  memRef.sizes[0] = 3;
  memRef.strides[0] = 1;

  DynamicMemRefType<int> dynamicMemRef(memRef);

  // Compare the whole iterated range (length included) against the viewed
  // slice. Indexing `values[i]` directly would read past the end of the
  // vector if the iterator produced too few elements, instead of failing.
  llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
  EXPECT_THAT(values, ElementsAreArray(buffer.begin() + offset, buffer.end()));

  // Subscripting must honour the offset as well.
  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]);
  }
}
125