109349303SJacques Pienaar //===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===//
209349303SJacques Pienaar //
309349303SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409349303SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information.
509349303SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609349303SJacques Pienaar //
709349303SJacques Pienaar //===----------------------------------------------------------------------===//
809349303SJacques Pienaar
909349303SJacques Pienaar #include "mlir/Interfaces/InferTypeOpInterface.h"
10abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1123aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1209349303SJacques Pienaar #include "mlir/IR/Builders.h"
1309349303SJacques Pienaar #include "mlir/IR/BuiltinOps.h"
1409349303SJacques Pienaar #include "mlir/IR/Dialect.h"
1509349303SJacques Pienaar #include "mlir/IR/DialectImplementation.h"
1609349303SJacques Pienaar #include "mlir/IR/ImplicitLocOpBuilder.h"
1709349303SJacques Pienaar #include "mlir/IR/OpDefinition.h"
1809349303SJacques Pienaar #include "mlir/IR/OpImplementation.h"
199eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
2009349303SJacques Pienaar
2109349303SJacques Pienaar #include <gtest/gtest.h>
2209349303SJacques Pienaar
2309349303SJacques Pienaar using namespace mlir;
2409349303SJacques Pienaar
2509349303SJacques Pienaar class ValueShapeRangeTest : public testing::Test {
2609349303SJacques Pienaar protected:
SetUp()2709349303SJacques Pienaar void SetUp() override {
2809349303SJacques Pienaar const char *ir = R"MLIR(
2963237cddSRiver Riddle func.func @map(%arg : tensor<1xi64>) {
30a54f4eaeSMogball %0 = arith.constant dense<[10]> : tensor<1xi64>
31a54f4eaeSMogball %1 = arith.addi %arg, %0 : tensor<1xi64>
3209349303SJacques Pienaar return
3309349303SJacques Pienaar }
3409349303SJacques Pienaar )MLIR";
3509349303SJacques Pienaar
36abc362a1SJakub Kuderski registry.insert<func::FuncDialect, arith::ArithDialect>();
3709349303SJacques Pienaar ctx.appendDialectRegistry(registry);
38dfaadf6bSChristian Sigg module = parseSourceString<ModuleOp>(ir, &ctx);
39*0441272cSMehdi Amini assert(module);
4058ceae95SRiver Riddle mapFn = cast<func::FuncOp>(module->front());
4109349303SJacques Pienaar }
4209349303SJacques Pienaar
43a54f4eaeSMogball // Create ValueShapeRange on the arith.addi operation.
addiRange()4409349303SJacques Pienaar ValueShapeRange addiRange() {
45f8d5c73cSRiver Riddle auto &fnBody = mapFn.getBody();
4609349303SJacques Pienaar return std::next(fnBody.front().begin())->getOperands();
4709349303SJacques Pienaar }
4809349303SJacques Pienaar
4909349303SJacques Pienaar DialectRegistry registry;
5009349303SJacques Pienaar MLIRContext ctx;
518f66ab1cSSanjoy Das OwningOpRef<ModuleOp> module;
5258ceae95SRiver Riddle func::FuncOp mapFn;
5309349303SJacques Pienaar };
5409349303SJacques Pienaar
// With no mapping installed, query the arith.addi operands both as shape
// values (getValueAsShape) and as the shapes of their types (getShape).
TEST_F(ValueShapeRangeTest, ShapesFromValues) {
  ValueShapeRange range = addiRange();

  // Operand 0 is the block argument %arg; it yields no shape value.
  EXPECT_FALSE(range.getValueAsShape(0));

  // Operand 1 comes from `arith.constant dense<[10]>` and reads as the
  // rank-1 shape [10].
  ShapeAdaptor constValue = range.getValueAsShape(1);
  ASSERT_TRUE(constValue);
  EXPECT_TRUE(constValue.hasRank());
  EXPECT_EQ(constValue.getRank(), 1);
  EXPECT_EQ(constValue.getDimSize(0), 10);

  // The shape of operand 1's type, tensor<1xi64>, is [1].
  ShapeAdaptor constType = range.getShape(1);
  EXPECT_EQ(constType.getRank(), 1);
  EXPECT_EQ(constType.getDimSize(0), 1);
}
6609349303SJacques Pienaar
// A value-to-shape mapping supplies the shape value for the values it maps;
// the unmapped constant operand still reports its own value.
TEST_F(ValueShapeRangeTest, MapValuesToShapes) {
  ValueShapeRange range = addiRange();
  ShapedTypeComponents fixed(SmallVector<int64_t>{30});

  // Map only the function argument to the fixed shape [30].
  // NOTE: kept in a named local rather than passed as a temporary, since the
  // range may hold only a non-owning reference to the callback (verify
  // against ValueShapeRange's mapping type).
  auto mapping = [&](Value val) -> ShapeAdaptor {
    if (val != mapFn.getArgument(0))
      return nullptr;
    return &fixed;
  };
  range.setValueToShapeMapping(mapping);

  // Operand 0 (%arg) now reads as the mapped shape [30].
  ShapeAdaptor argValue = range.getValueAsShape(0);
  ASSERT_TRUE(argValue);
  EXPECT_TRUE(argValue.hasRank());
  EXPECT_EQ(argValue.getRank(), 1);
  EXPECT_EQ(argValue.getDimSize(0), 30);

  // Operand 1 is not mapped, so it still reads as the constant's [10].
  ShapeAdaptor constValue = range.getValueAsShape(1);
  ASSERT_TRUE(constValue);
  EXPECT_TRUE(constValue.hasRank());
  EXPECT_EQ(constValue.getRank(), 1);
  EXPECT_EQ(constValue.getDimSize(0), 10);
}
8609349303SJacques Pienaar
// An operand-shape mapping overrides getShape() for the operands it maps.
// Unmapped operands keep the shape of their type, and an index past the last
// operand yields a null adaptor.
TEST_F(ValueShapeRangeTest, SettingShapes) {
  ShapedTypeComponents shape(SmallVector<int64_t>{10, 20});
  ValueShapeRange range = addiRange();

  // Map only the function argument to the rank-2 shape [10, 20].
  // NOTE: kept in a named local rather than passed as a temporary, since the
  // range may hold only a non-owning reference to the callback (verify
  // against ValueShapeRange's mapping type).
  auto mapping = [&](Value val) -> ShapeAdaptor {
    if (val != mapFn.getArgument(0))
      return nullptr;
    return &shape;
  };
  range.setOperandShapeMapping(mapping);

  // Operand 0 (%arg) reports the mapped shape.
  ShapeAdaptor argShape = range.getShape(0);
  ASSERT_TRUE(argShape);
  EXPECT_EQ(argShape.getRank(), 2);
  EXPECT_EQ(argShape.getDimSize(0), 10);
  EXPECT_EQ(argShape.getDimSize(1), 20);

  // Operand 1 is unmapped and reports its type, tensor<1xi64>.
  ShapeAdaptor constShape = range.getShape(1);
  ASSERT_TRUE(constShape);
  EXPECT_EQ(constShape.getRank(), 1);
  EXPECT_EQ(constShape.getDimSize(0), 1);

  // arith.addi has exactly two operands, so index 2 is out of range.
  EXPECT_FALSE(range.getShape(2));
}
106