//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "gtest/gtest.h"

#include <optional>

16aa0208d1SFelix Schneider using namespace mlir;
17aa0208d1SFelix Schneider using namespace mlir::scf;
18aa0208d1SFelix Schneider
//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//

23aa0208d1SFelix Schneider class SCFLoopLikeTest : public ::testing::Test {
24aa0208d1SFelix Schneider protected:
SCFLoopLikeTest()25aa0208d1SFelix Schneider SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26aa0208d1SFelix Schneider context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
27aa0208d1SFelix Schneider }
28aa0208d1SFelix Schneider
checkUnidimensional(LoopLikeOpInterface loopLikeOp)29aa0208d1SFelix Schneider void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
30*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleLb =
31*6b4c1228Ssrcarroll loopLikeOp.getSingleLowerBound();
32*6b4c1228Ssrcarroll EXPECT_TRUE(maybeSingleLb.has_value());
33*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleUb =
34*6b4c1228Ssrcarroll loopLikeOp.getSingleUpperBound();
35*6b4c1228Ssrcarroll EXPECT_TRUE(maybeSingleUb.has_value());
36*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
37*6b4c1228Ssrcarroll EXPECT_TRUE(maybeSingleStep.has_value());
38*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleIndVar =
39aa0208d1SFelix Schneider loopLikeOp.getSingleInductionVar();
40*6b4c1228Ssrcarroll EXPECT_TRUE(maybeSingleIndVar.has_value());
41*6b4c1228Ssrcarroll
42*6b4c1228Ssrcarroll std::optional<SmallVector<OpFoldResult>> maybeLb =
43*6b4c1228Ssrcarroll loopLikeOp.getLoopLowerBounds();
44*6b4c1228Ssrcarroll ASSERT_TRUE(maybeLb.has_value());
45*6b4c1228Ssrcarroll EXPECT_EQ((*maybeLb).size(), 1u);
46*6b4c1228Ssrcarroll std::optional<SmallVector<OpFoldResult>> maybeUb =
47*6b4c1228Ssrcarroll loopLikeOp.getLoopUpperBounds();
48*6b4c1228Ssrcarroll ASSERT_TRUE(maybeUb.has_value());
49*6b4c1228Ssrcarroll EXPECT_EQ((*maybeUb).size(), 1u);
50*6b4c1228Ssrcarroll std::optional<SmallVector<OpFoldResult>> maybeStep =
51*6b4c1228Ssrcarroll loopLikeOp.getLoopSteps();
52*6b4c1228Ssrcarroll ASSERT_TRUE(maybeStep.has_value());
53*6b4c1228Ssrcarroll EXPECT_EQ((*maybeStep).size(), 1u);
54*6b4c1228Ssrcarroll std::optional<SmallVector<Value>> maybeInductionVars =
55*6b4c1228Ssrcarroll loopLikeOp.getLoopInductionVars();
56*6b4c1228Ssrcarroll ASSERT_TRUE(maybeInductionVars.has_value());
57*6b4c1228Ssrcarroll EXPECT_EQ((*maybeInductionVars).size(), 1u);
58aa0208d1SFelix Schneider }
59aa0208d1SFelix Schneider
checkMultidimensional(LoopLikeOpInterface loopLikeOp)60aa0208d1SFelix Schneider void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
61*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleLb =
62*6b4c1228Ssrcarroll loopLikeOp.getSingleLowerBound();
63*6b4c1228Ssrcarroll EXPECT_FALSE(maybeSingleLb.has_value());
64*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleUb =
65*6b4c1228Ssrcarroll loopLikeOp.getSingleUpperBound();
66*6b4c1228Ssrcarroll EXPECT_FALSE(maybeSingleUb.has_value());
67*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
68*6b4c1228Ssrcarroll EXPECT_FALSE(maybeSingleStep.has_value());
69*6b4c1228Ssrcarroll std::optional<OpFoldResult> maybeSingleIndVar =
70aa0208d1SFelix Schneider loopLikeOp.getSingleInductionVar();
71*6b4c1228Ssrcarroll EXPECT_FALSE(maybeSingleIndVar.has_value());
72*6b4c1228Ssrcarroll
73*6b4c1228Ssrcarroll std::optional<SmallVector<OpFoldResult>> maybeLb =
74*6b4c1228Ssrcarroll loopLikeOp.getLoopLowerBounds();
75*6b4c1228Ssrcarroll ASSERT_TRUE(maybeLb.has_value());
76*6b4c1228Ssrcarroll EXPECT_EQ((*maybeLb).size(), 2u);
77*6b4c1228Ssrcarroll std::optional<SmallVector<OpFoldResult>> maybeUb =
78*6b4c1228Ssrcarroll loopLikeOp.getLoopUpperBounds();
79*6b4c1228Ssrcarroll ASSERT_TRUE(maybeUb.has_value());
80*6b4c1228Ssrcarroll EXPECT_EQ((*maybeUb).size(), 2u);
81*6b4c1228Ssrcarroll std::optional<SmallVector<OpFoldResult>> maybeStep =
82*6b4c1228Ssrcarroll loopLikeOp.getLoopSteps();
83*6b4c1228Ssrcarroll ASSERT_TRUE(maybeStep.has_value());
84*6b4c1228Ssrcarroll EXPECT_EQ((*maybeStep).size(), 2u);
85*6b4c1228Ssrcarroll std::optional<SmallVector<Value>> maybeInductionVars =
86*6b4c1228Ssrcarroll loopLikeOp.getLoopInductionVars();
87*6b4c1228Ssrcarroll ASSERT_TRUE(maybeInductionVars.has_value());
88*6b4c1228Ssrcarroll EXPECT_EQ((*maybeInductionVars).size(), 2u);
89aa0208d1SFelix Schneider }
90aa0208d1SFelix Schneider
91aa0208d1SFelix Schneider MLIRContext context;
92aa0208d1SFelix Schneider OpBuilder b;
93aa0208d1SFelix Schneider Location loc;
94aa0208d1SFelix Schneider };
95aa0208d1SFelix Schneider
/// Builds `scf.for`, `scf.forall` and `scf.parallel` with a single
/// (lower bound, upper bound, step) triple each and checks that all three
/// report one loop dimension through `LoopLikeOpInterface`.
TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
  // Index constants 0, 10 and 2 serve as lower bound, upper bound and step.
  // Each created op is held in an OwningOpRef for the duration of the test.
  OwningOpRef<arith::ConstantIndexOp> lb =
      b.create<arith::ConstantIndexOp>(loc, 0);
  OwningOpRef<arith::ConstantIndexOp> ub =
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> step =
      b.create<arith::ConstantIndexOp>(loc, 2);

  OwningOpRef<scf::ForOp> forOp =
      b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get());
  checkUnidimensional(forOp.get());

  OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>(lb->getResult()),
      ArrayRef<OpFoldResult>(ub->getResult()),
      ArrayRef<OpFoldResult>(step->getResult()), ValueRange(), std::nullopt);
  checkUnidimensional(forallOp.get());

  OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
      loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()),
      ValueRange(step->getResult()), ValueRange());
  checkUnidimensional(parallelOp.get());
}

/// Builds `scf.forall` and `scf.parallel` with two (lower bound, upper bound,
/// step) triples each and checks that both report two loop dimensions through
/// `LoopLikeOpInterface`.
TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
  // Index constants 0, 10 and 2 serve as lower bound, upper bound and step;
  // the same constant is reused for both dimensions of each loop.
  OwningOpRef<arith::ConstantIndexOp> lb =
      b.create<arith::ConstantIndexOp>(loc, 0);
  OwningOpRef<arith::ConstantIndexOp> ub =
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> step =
      b.create<arith::ConstantIndexOp>(loc, 2);

  OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
      ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
      ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
      ValueRange(), std::nullopt);
  checkMultidimensional(forallOp.get());

  OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
      loc, ValueRange({lb->getResult(), lb->getResult()}),
      ValueRange({ub->getResult(), ub->getResult()}),
      ValueRange({step->getResult(), step->getResult()}), ValueRange());
  checkMultidimensional(parallelOp.get());
}