1 //===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/SCF/IR/SCF.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/OwningOpRef.h"
14 #include "gtest/gtest.h"
15
16 using namespace mlir;
17 using namespace mlir::scf;
18
19 //===----------------------------------------------------------------------===//
20 // Test Fixture
21 //===----------------------------------------------------------------------===//
22
23 class SCFLoopLikeTest : public ::testing::Test {
24 protected:
SCFLoopLikeTest()25 SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26 context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
27 }
28
checkUnidimensional(LoopLikeOpInterface loopLikeOp)29 void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
30 std::optional<OpFoldResult> maybeSingleLb =
31 loopLikeOp.getSingleLowerBound();
32 EXPECT_TRUE(maybeSingleLb.has_value());
33 std::optional<OpFoldResult> maybeSingleUb =
34 loopLikeOp.getSingleUpperBound();
35 EXPECT_TRUE(maybeSingleUb.has_value());
36 std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
37 EXPECT_TRUE(maybeSingleStep.has_value());
38 std::optional<OpFoldResult> maybeSingleIndVar =
39 loopLikeOp.getSingleInductionVar();
40 EXPECT_TRUE(maybeSingleIndVar.has_value());
41
42 std::optional<SmallVector<OpFoldResult>> maybeLb =
43 loopLikeOp.getLoopLowerBounds();
44 ASSERT_TRUE(maybeLb.has_value());
45 EXPECT_EQ((*maybeLb).size(), 1u);
46 std::optional<SmallVector<OpFoldResult>> maybeUb =
47 loopLikeOp.getLoopUpperBounds();
48 ASSERT_TRUE(maybeUb.has_value());
49 EXPECT_EQ((*maybeUb).size(), 1u);
50 std::optional<SmallVector<OpFoldResult>> maybeStep =
51 loopLikeOp.getLoopSteps();
52 ASSERT_TRUE(maybeStep.has_value());
53 EXPECT_EQ((*maybeStep).size(), 1u);
54 std::optional<SmallVector<Value>> maybeInductionVars =
55 loopLikeOp.getLoopInductionVars();
56 ASSERT_TRUE(maybeInductionVars.has_value());
57 EXPECT_EQ((*maybeInductionVars).size(), 1u);
58 }
59
checkMultidimensional(LoopLikeOpInterface loopLikeOp)60 void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
61 std::optional<OpFoldResult> maybeSingleLb =
62 loopLikeOp.getSingleLowerBound();
63 EXPECT_FALSE(maybeSingleLb.has_value());
64 std::optional<OpFoldResult> maybeSingleUb =
65 loopLikeOp.getSingleUpperBound();
66 EXPECT_FALSE(maybeSingleUb.has_value());
67 std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
68 EXPECT_FALSE(maybeSingleStep.has_value());
69 std::optional<OpFoldResult> maybeSingleIndVar =
70 loopLikeOp.getSingleInductionVar();
71 EXPECT_FALSE(maybeSingleIndVar.has_value());
72
73 std::optional<SmallVector<OpFoldResult>> maybeLb =
74 loopLikeOp.getLoopLowerBounds();
75 ASSERT_TRUE(maybeLb.has_value());
76 EXPECT_EQ((*maybeLb).size(), 2u);
77 std::optional<SmallVector<OpFoldResult>> maybeUb =
78 loopLikeOp.getLoopUpperBounds();
79 ASSERT_TRUE(maybeUb.has_value());
80 EXPECT_EQ((*maybeUb).size(), 2u);
81 std::optional<SmallVector<OpFoldResult>> maybeStep =
82 loopLikeOp.getLoopSteps();
83 ASSERT_TRUE(maybeStep.has_value());
84 EXPECT_EQ((*maybeStep).size(), 2u);
85 std::optional<SmallVector<Value>> maybeInductionVars =
86 loopLikeOp.getLoopInductionVars();
87 ASSERT_TRUE(maybeInductionVars.has_value());
88 EXPECT_EQ((*maybeInductionVars).size(), 2u);
89 }
90
91 MLIRContext context;
92 OpBuilder b;
93 Location loc;
94 };
95
// Each SCF loop op built from a single (lower bound, upper bound, step) triple
// must be reported as one-dimensional through LoopLikeOpInterface.
TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
  // Index constants providing the bounds. The fixture builder has no
  // insertion point, so every op created here is owned by its OwningOpRef.
  // The constants are declared first so that, as locals are destroyed in
  // reverse order, the loops using their results are destroyed before them.
  OwningOpRef<arith::ConstantIndexOp> lbCst =
      b.create<arith::ConstantIndexOp>(loc, 0);
  OwningOpRef<arith::ConstantIndexOp> ubCst =
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> stepCst =
      b.create<arith::ConstantIndexOp>(loc, 2);
  auto lbVal = lbCst->getResult();
  auto ubVal = ubCst->getResult();
  auto stepVal = stepCst->getResult();

  // scf.for: bounds are passed as individual values.
  OwningOpRef<scf::ForOp> forLoop =
      b.create<scf::ForOp>(loc, lbVal, ubVal, stepVal);
  checkUnidimensional(forLoop.get());

  // scf.forall: bounds are passed as one-element OpFoldResult lists, with an
  // empty output range and no mapping.
  OwningOpRef<scf::ForallOp> forallLoop = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>(lbVal), ArrayRef<OpFoldResult>(ubVal),
      ArrayRef<OpFoldResult>(stepVal), ValueRange(), std::nullopt);
  checkUnidimensional(forallLoop.get());

  // scf.parallel: bounds are passed as one-element value ranges, followed by
  // an empty trailing ValueRange.
  OwningOpRef<scf::ParallelOp> parallelLoop = b.create<scf::ParallelOp>(
      loc, ValueRange(lbVal), ValueRange(ubVal), ValueRange(stepVal),
      ValueRange());
  checkUnidimensional(parallelLoop.get());
}
119
// SCF loop ops that accept several bounds must be reported as
// two-dimensional through LoopLikeOpInterface when given two of each.
TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
  // Index constants providing the bounds, reused for both dimensions. The
  // fixture builder has no insertion point, so every op created here is owned
  // by its OwningOpRef. The constants are declared first so the loops using
  // their results are destroyed before them.
  OwningOpRef<arith::ConstantIndexOp> lbCst =
      b.create<arith::ConstantIndexOp>(loc, 0);
  OwningOpRef<arith::ConstantIndexOp> ubCst =
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> stepCst =
      b.create<arith::ConstantIndexOp>(loc, 2);
  auto lbVal = lbCst->getResult();
  auto ubVal = ubCst->getResult();
  auto stepVal = stepCst->getResult();

  // scf.forall: two-element OpFoldResult lists, an empty output range and no
  // mapping.
  OwningOpRef<scf::ForallOp> forallLoop = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>({lbVal, lbVal}),
      ArrayRef<OpFoldResult>({ubVal, ubVal}),
      ArrayRef<OpFoldResult>({stepVal, stepVal}), ValueRange(), std::nullopt);
  checkMultidimensional(forallLoop.get());

  // scf.parallel: two-element value ranges followed by an empty trailing
  // ValueRange.
  OwningOpRef<scf::ParallelOp> parallelLoop = b.create<scf::ParallelOp>(
      loc, ValueRange({lbVal, lbVal}), ValueRange({ubVal, ubVal}),
      ValueRange({stepVal, stepVal}), ValueRange());
  checkMultidimensional(parallelLoop.get());
}
141