xref: /llvm-project/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp (revision 2791162b01e3199b24f2d18f7b370157e2c57daf)
1858be166SMatthias Springer //===- InferShapeTest.cpp - unit tests for shape inference ----------------===//
2858be166SMatthias Springer //
3858be166SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4858be166SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5858be166SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6858be166SMatthias Springer //
7858be166SMatthias Springer //===----------------------------------------------------------------------===//
8858be166SMatthias Springer 
9858be166SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
10858be166SMatthias Springer #include "mlir/IR/AffineMap.h"
11858be166SMatthias Springer #include "mlir/IR/Builders.h"
12858be166SMatthias Springer #include "mlir/IR/BuiltinTypes.h"
13858be166SMatthias Springer #include "gtest/gtest.h"
14858be166SMatthias Springer 
15858be166SMatthias Springer using namespace mlir;
16858be166SMatthias Springer using namespace mlir::memref;
17858be166SMatthias Springer 
18858be166SMatthias Springer // Source memref has identity layout.
TEST(InferShapeTest,inferRankReducedShapeIdentity)19858be166SMatthias Springer TEST(InferShapeTest, inferRankReducedShapeIdentity) {
20858be166SMatthias Springer   MLIRContext ctx;
21858be166SMatthias Springer   OpBuilder b(&ctx);
22858be166SMatthias Springer   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType());
23858be166SMatthias Springer   auto reducedType = SubViewOp::inferRankReducedResultType(
246c3c5f80SMatthias Springer       /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
25*2791162bSAlex Zinenko   auto expectedType = MemRefType::get(
26*2791162bSAlex Zinenko       {2}, b.getIndexType(),
27*2791162bSAlex Zinenko       StridedLayoutAttr::get(&ctx, /*offset=*/13, /*strides=*/{1}));
28858be166SMatthias Springer   EXPECT_EQ(reducedType, expectedType);
29858be166SMatthias Springer }
30858be166SMatthias Springer 
31858be166SMatthias Springer // Source memref has non-identity layout.
TEST(InferShapeTest,inferRankReducedShapeNonIdentity)32858be166SMatthias Springer TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
33858be166SMatthias Springer   MLIRContext ctx;
34858be166SMatthias Springer   OpBuilder b(&ctx);
35858be166SMatthias Springer   AffineExpr dim0, dim1;
36858be166SMatthias Springer   bindDims(&ctx, dim0, dim1);
37858be166SMatthias Springer   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
38858be166SMatthias Springer                                       AffineMap::get(2, 0, 1000 * dim0 + dim1));
39858be166SMatthias Springer   auto reducedType = SubViewOp::inferRankReducedResultType(
406c3c5f80SMatthias Springer       /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
41*2791162bSAlex Zinenko   auto expectedType = MemRefType::get(
42*2791162bSAlex Zinenko       {2}, b.getIndexType(),
43*2791162bSAlex Zinenko       StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{1}));
44858be166SMatthias Springer   EXPECT_EQ(reducedType, expectedType);
45858be166SMatthias Springer }
46858be166SMatthias Springer 
// Taking a 1x1 window at (2, 3) and dropping both unit dimensions yields a
// 0-D memref. With the source layout (d0, d1) -> (1000 * d0 + d1), the single
// remaining element sits at offset 2 * 1000 + 3 = 2003 and there are no
// strides left to record.
TEST(InferShapeTest, inferRankReducedShapeToScalar) {
  MLIRContext context;
  OpBuilder builder(&context);
  auto indexTy = builder.getIndexType();

  AffineExpr d0, d1;
  bindDims(&context, d0, d1);
  auto sourceLayout = AffineMap::get(2, 0, 1000 * d0 + d1);
  auto sourceMemref = MemRefType::get({10, 5}, indexTy, sourceLayout);

  // Subview parameters: 1x1 window starting at (2, 3), unit steps.
  const int64_t offsets[] = {2, 3};
  const int64_t sizes[] = {1, 1};
  const int64_t strides[] = {1, 1};
  auto reducedType = SubViewOp::inferRankReducedResultType(
      /*resultShape=*/{}, sourceMemref, offsets, sizes, strides);

  auto expectedLayout =
      StridedLayoutAttr::get(&context, /*offset=*/2003, /*strides=*/{});
  auto expectedType = MemRefType::get({}, indexTy, expectedLayout);
  EXPECT_EQ(reducedType, expectedType);
}
61