xref: /llvm-project/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp (revision 0441272c4501a72b86cd5c0b8ab96802d1abb8d9)
109349303SJacques Pienaar //===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===//
209349303SJacques Pienaar //
309349303SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409349303SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information.
509349303SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609349303SJacques Pienaar //
709349303SJacques Pienaar //===----------------------------------------------------------------------===//
809349303SJacques Pienaar 
909349303SJacques Pienaar #include "mlir/Interfaces/InferTypeOpInterface.h"
10abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1123aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1209349303SJacques Pienaar #include "mlir/IR/Builders.h"
1309349303SJacques Pienaar #include "mlir/IR/BuiltinOps.h"
1409349303SJacques Pienaar #include "mlir/IR/Dialect.h"
1509349303SJacques Pienaar #include "mlir/IR/DialectImplementation.h"
1609349303SJacques Pienaar #include "mlir/IR/ImplicitLocOpBuilder.h"
1709349303SJacques Pienaar #include "mlir/IR/OpDefinition.h"
1809349303SJacques Pienaar #include "mlir/IR/OpImplementation.h"
199eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
2009349303SJacques Pienaar 
2109349303SJacques Pienaar #include <gtest/gtest.h>
2209349303SJacques Pienaar 
2309349303SJacques Pienaar using namespace mlir;
2409349303SJacques Pienaar 
2509349303SJacques Pienaar class ValueShapeRangeTest : public testing::Test {
2609349303SJacques Pienaar protected:
SetUp()2709349303SJacques Pienaar   void SetUp() override {
2809349303SJacques Pienaar     const char *ir = R"MLIR(
2963237cddSRiver Riddle       func.func @map(%arg : tensor<1xi64>) {
30a54f4eaeSMogball         %0 = arith.constant dense<[10]> : tensor<1xi64>
31a54f4eaeSMogball         %1 = arith.addi %arg, %0 : tensor<1xi64>
3209349303SJacques Pienaar         return
3309349303SJacques Pienaar       }
3409349303SJacques Pienaar     )MLIR";
3509349303SJacques Pienaar 
36abc362a1SJakub Kuderski     registry.insert<func::FuncDialect, arith::ArithDialect>();
3709349303SJacques Pienaar     ctx.appendDialectRegistry(registry);
38dfaadf6bSChristian Sigg     module = parseSourceString<ModuleOp>(ir, &ctx);
39*0441272cSMehdi Amini     assert(module);
4058ceae95SRiver Riddle     mapFn = cast<func::FuncOp>(module->front());
4109349303SJacques Pienaar   }
4209349303SJacques Pienaar 
43a54f4eaeSMogball   // Create ValueShapeRange on the arith.addi operation.
addiRange()4409349303SJacques Pienaar   ValueShapeRange addiRange() {
45f8d5c73cSRiver Riddle     auto &fnBody = mapFn.getBody();
4609349303SJacques Pienaar     return std::next(fnBody.front().begin())->getOperands();
4709349303SJacques Pienaar   }
4809349303SJacques Pienaar 
4909349303SJacques Pienaar   DialectRegistry registry;
5009349303SJacques Pienaar   MLIRContext ctx;
518f66ab1cSSanjoy Das   OwningOpRef<ModuleOp> module;
5258ceae95SRiver Riddle   func::FuncOp mapFn;
5309349303SJacques Pienaar };
5409349303SJacques Pienaar 
// With no mapping installed, checks that:
//  - the constant operand can be read as a shape *value*;
//  - the block argument cannot be read that way;
//  - getShape reports the shape carried by an operand's type.
TEST_F(ValueShapeRangeTest, ShapesFromValues) {
  ValueShapeRange operands = addiRange();

  // Operand 0 is the block argument %arg, which has no constant value.
  EXPECT_FALSE(operands.getValueAsShape(0));

  // Operand 1 is `arith.constant dense<[10]>`; its value reads as shape [10].
  ShapeAdaptor constValue = operands.getValueAsShape(1);
  ASSERT_TRUE(constValue);
  EXPECT_TRUE(constValue.hasRank());
  EXPECT_EQ(constValue.getRank(), 1);
  EXPECT_EQ(constValue.getDimSize(0), 10);

  // The shape of operand 1 itself is that of its type, tensor<1xi64>.
  ShapeAdaptor constType = operands.getShape(1);
  EXPECT_EQ(constType.getRank(), 1);
  EXPECT_EQ(constType.getDimSize(0), 1);
}
6609349303SJacques Pienaar 
// Installs a value-to-shape mapping that recognizes only the function argument.
// The mapped operand must report the mapped shape, and the constant operand
// must still read as its own value.
TEST_F(ValueShapeRangeTest, MapValuesToShapes) {
  ValueShapeRange operands = addiRange();
  ShapedTypeComponents argShape(SmallVector<int64_t>{30});
  // Kept as a named local so the callback outlives every query on `operands`.
  // The setter's storage semantics are declared elsewhere (presumably a
  // non-owning reference; confirm).
  auto valueToShape = [&](Value val) -> ShapeAdaptor {
    if (val != mapFn.getArgument(0))
      return nullptr;
    return &argShape;
  };
  operands.setValueToShapeMapping(valueToShape);

  // Operand 0 (%arg) is served by the mapping: shape [30].
  ShapeAdaptor argValue = operands.getValueAsShape(0);
  ASSERT_TRUE(argValue);
  EXPECT_TRUE(argValue.hasRank());
  EXPECT_EQ(argValue.getRank(), 1);
  EXPECT_EQ(argValue.getDimSize(0), 30);

  // Operand 1 is unmapped and still reads as the constant's value, [10].
  ShapeAdaptor constValue = operands.getValueAsShape(1);
  ASSERT_TRUE(constValue);
  EXPECT_TRUE(constValue.hasRank());
  EXPECT_EQ(constValue.getRank(), 1);
  EXPECT_EQ(constValue.getDimSize(0), 10);
}
8609349303SJacques Pienaar 
// Installs an operand-shape mapping that overrides only the function argument.
// It then checks that:
//  - the mapped operand reports the override;
//  - the other operand keeps the shape of its type;
//  - an index past the end yields a null adaptor.
TEST_F(ValueShapeRangeTest, SettingShapes) {
  ShapedTypeComponents argShape(SmallVector<int64_t>{10, 20});
  ValueShapeRange operands = addiRange();
  // Named local so the callback lives for the whole test body.
  auto operandToShape = [&](Value val) -> ShapeAdaptor {
    if (val != mapFn.getArgument(0))
      return nullptr;
    return &argShape;
  };
  operands.setOperandShapeMapping(operandToShape);

  // Operand 0 (%arg) reports the override [10, 20] instead of tensor<1xi64>.
  ShapeAdaptor argResult = operands.getShape(0);
  ASSERT_TRUE(argResult);
  EXPECT_EQ(argResult.getRank(), 2);
  EXPECT_EQ(argResult.getDimSize(0), 10);
  EXPECT_EQ(argResult.getDimSize(1), 20);

  // Operand 1 is unmapped and keeps the shape of tensor<1xi64>.
  ShapeAdaptor constResult = operands.getShape(1);
  ASSERT_TRUE(constResult);
  EXPECT_EQ(constResult.getRank(), 1);
  EXPECT_EQ(constResult.getDimSize(0), 1);

  // The addi has only two operands, so index 2 is out of range.
  EXPECT_FALSE(operands.getShape(2));
}
106