//===- InferShapeTest.cpp - unit tests for shape inference ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::memref;

18858be166SMatthias Springer // Source memref has identity layout.
TEST(InferShapeTest,inferRankReducedShapeIdentity)19858be166SMatthias Springer TEST(InferShapeTest, inferRankReducedShapeIdentity) {
20858be166SMatthias Springer MLIRContext ctx;
21858be166SMatthias Springer OpBuilder b(&ctx);
22858be166SMatthias Springer auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType());
23858be166SMatthias Springer auto reducedType = SubViewOp::inferRankReducedResultType(
246c3c5f80SMatthias Springer /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
25*2791162bSAlex Zinenko auto expectedType = MemRefType::get(
26*2791162bSAlex Zinenko {2}, b.getIndexType(),
27*2791162bSAlex Zinenko StridedLayoutAttr::get(&ctx, /*offset=*/13, /*strides=*/{1}));
28858be166SMatthias Springer EXPECT_EQ(reducedType, expectedType);
29858be166SMatthias Springer }
30858be166SMatthias Springer
31858be166SMatthias Springer // Source memref has non-identity layout.
TEST(InferShapeTest,inferRankReducedShapeNonIdentity)32858be166SMatthias Springer TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
33858be166SMatthias Springer MLIRContext ctx;
34858be166SMatthias Springer OpBuilder b(&ctx);
35858be166SMatthias Springer AffineExpr dim0, dim1;
36858be166SMatthias Springer bindDims(&ctx, dim0, dim1);
37858be166SMatthias Springer auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
38858be166SMatthias Springer AffineMap::get(2, 0, 1000 * dim0 + dim1));
39858be166SMatthias Springer auto reducedType = SubViewOp::inferRankReducedResultType(
406c3c5f80SMatthias Springer /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
41*2791162bSAlex Zinenko auto expectedType = MemRefType::get(
42*2791162bSAlex Zinenko {2}, b.getIndexType(),
43*2791162bSAlex Zinenko StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{1}));
44858be166SMatthias Springer EXPECT_EQ(reducedType, expectedType);
45858be166SMatthias Springer }
46858be166SMatthias Springer
// Source memref has non-identity layout and the subview is reduced to rank 0.
//
// Both subview sizes are 1, so every dimension is dropped and the result is a
// 0-D memref that keeps only the offset of the selected element.
TEST(InferShapeTest, inferRankReducedShapeToScalar) {
  MLIRContext ctx;
  OpBuilder b(&ctx);
  AffineExpr dim0, dim1;
  bindDims(&ctx, dim0, dim1);
  auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
                                      AffineMap::get(2, 0, 1000 * dim0 + dim1));
  auto reducedType = SubViewOp::inferRankReducedResultType(
      /*resultShape=*/{}, sourceMemref, /*offsets=*/{2, 3}, /*sizes=*/{1, 1},
      /*strides=*/{1, 1});
  // Offsets (2, 3) under the layout map give 1000 * 2 + 3 = 2003; no
  // dimensions remain, hence the empty stride list.
  auto expectedType = MemRefType::get(
      {}, b.getIndexType(),
      StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{}));
  EXPECT_EQ(reducedType, expectedType);
}