xref: /llvm-project/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp (revision 2791162b01e3199b24f2d18f7b370157e2c57daf)
1 //===- InferShapeTest.cpp - unit tests for shape inference ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/IR/AffineMap.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::memref;
17 
18 // Source memref has identity layout.
TEST(InferShapeTest,inferRankReducedShapeIdentity)19 TEST(InferShapeTest, inferRankReducedShapeIdentity) {
20   MLIRContext ctx;
21   OpBuilder b(&ctx);
22   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType());
23   auto reducedType = SubViewOp::inferRankReducedResultType(
24       /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
25   auto expectedType = MemRefType::get(
26       {2}, b.getIndexType(),
27       StridedLayoutAttr::get(&ctx, /*offset=*/13, /*strides=*/{1}));
28   EXPECT_EQ(reducedType, expectedType);
29 }
30 
31 // Source memref has non-identity layout.
TEST(InferShapeTest,inferRankReducedShapeNonIdentity)32 TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
33   MLIRContext ctx;
34   OpBuilder b(&ctx);
35   AffineExpr dim0, dim1;
36   bindDims(&ctx, dim0, dim1);
37   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
38                                       AffineMap::get(2, 0, 1000 * dim0 + dim1));
39   auto reducedType = SubViewOp::inferRankReducedResultType(
40       /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
41   auto expectedType = MemRefType::get(
42       {2}, b.getIndexType(),
43       StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{1}));
44   EXPECT_EQ(reducedType, expectedType);
45 }
46 
// Rank-reduce all the way to a 0-D (scalar) memref.
//
// The source uses the same (d0, d1) -> (1000 * d0 + d1) layout as above. The
// subview at offsets [2, 3] has sizes [1, 1], so every dimension is dropped
// (empty result shape). What remains is a 0-D memref with no strides and
// offset 2 * 1000 + 3 * 1 = 2003.
TEST(InferShapeTest, inferRankReducedShapeToScalar) {
  MLIRContext ctx;
  OpBuilder b(&ctx);
  AffineExpr dim0, dim1;
  bindDims(&ctx, dim0, dim1);
  auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
                                      AffineMap::get(2, 0, 1000 * dim0 + dim1));
  auto reducedType = SubViewOp::inferRankReducedResultType(
      /*resultShape=*/{}, sourceMemref, /*offsets=*/{2, 3},
      /*sizes=*/{1, 1}, /*strides=*/{1, 1});
  auto expectedType = MemRefType::get(
      {}, b.getIndexType(),
      StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{}));
  EXPECT_EQ(reducedType, expectedType);
}
61