xref: /llvm-project/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp (revision 6b4c12284795a3030e37b17047271a47a69bb587)
1 //===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/SCF/IR/SCF.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/OwningOpRef.h"
14 #include "gtest/gtest.h"
15 
16 using namespace mlir;
17 using namespace mlir::scf;
18 
19 //===----------------------------------------------------------------------===//
20 // Test Fixture
21 //===----------------------------------------------------------------------===//
22 
23 class SCFLoopLikeTest : public ::testing::Test {
24 protected:
SCFLoopLikeTest()25   SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26     context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
27   }
28 
checkUnidimensional(LoopLikeOpInterface loopLikeOp)29   void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
30     std::optional<OpFoldResult> maybeSingleLb =
31         loopLikeOp.getSingleLowerBound();
32     EXPECT_TRUE(maybeSingleLb.has_value());
33     std::optional<OpFoldResult> maybeSingleUb =
34         loopLikeOp.getSingleUpperBound();
35     EXPECT_TRUE(maybeSingleUb.has_value());
36     std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
37     EXPECT_TRUE(maybeSingleStep.has_value());
38     std::optional<OpFoldResult> maybeSingleIndVar =
39         loopLikeOp.getSingleInductionVar();
40     EXPECT_TRUE(maybeSingleIndVar.has_value());
41 
42     std::optional<SmallVector<OpFoldResult>> maybeLb =
43         loopLikeOp.getLoopLowerBounds();
44     ASSERT_TRUE(maybeLb.has_value());
45     EXPECT_EQ((*maybeLb).size(), 1u);
46     std::optional<SmallVector<OpFoldResult>> maybeUb =
47         loopLikeOp.getLoopUpperBounds();
48     ASSERT_TRUE(maybeUb.has_value());
49     EXPECT_EQ((*maybeUb).size(), 1u);
50     std::optional<SmallVector<OpFoldResult>> maybeStep =
51         loopLikeOp.getLoopSteps();
52     ASSERT_TRUE(maybeStep.has_value());
53     EXPECT_EQ((*maybeStep).size(), 1u);
54     std::optional<SmallVector<Value>> maybeInductionVars =
55         loopLikeOp.getLoopInductionVars();
56     ASSERT_TRUE(maybeInductionVars.has_value());
57     EXPECT_EQ((*maybeInductionVars).size(), 1u);
58   }
59 
checkMultidimensional(LoopLikeOpInterface loopLikeOp)60   void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
61     std::optional<OpFoldResult> maybeSingleLb =
62         loopLikeOp.getSingleLowerBound();
63     EXPECT_FALSE(maybeSingleLb.has_value());
64     std::optional<OpFoldResult> maybeSingleUb =
65         loopLikeOp.getSingleUpperBound();
66     EXPECT_FALSE(maybeSingleUb.has_value());
67     std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
68     EXPECT_FALSE(maybeSingleStep.has_value());
69     std::optional<OpFoldResult> maybeSingleIndVar =
70         loopLikeOp.getSingleInductionVar();
71     EXPECT_FALSE(maybeSingleIndVar.has_value());
72 
73     std::optional<SmallVector<OpFoldResult>> maybeLb =
74         loopLikeOp.getLoopLowerBounds();
75     ASSERT_TRUE(maybeLb.has_value());
76     EXPECT_EQ((*maybeLb).size(), 2u);
77     std::optional<SmallVector<OpFoldResult>> maybeUb =
78         loopLikeOp.getLoopUpperBounds();
79     ASSERT_TRUE(maybeUb.has_value());
80     EXPECT_EQ((*maybeUb).size(), 2u);
81     std::optional<SmallVector<OpFoldResult>> maybeStep =
82         loopLikeOp.getLoopSteps();
83     ASSERT_TRUE(maybeStep.has_value());
84     EXPECT_EQ((*maybeStep).size(), 2u);
85     std::optional<SmallVector<Value>> maybeInductionVars =
86         loopLikeOp.getLoopInductionVars();
87     ASSERT_TRUE(maybeInductionVars.has_value());
88     EXPECT_EQ((*maybeInductionVars).size(), 2u);
89   }
90 
91   MLIRContext context;
92   OpBuilder b;
93   Location loc;
94 };
95 
TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
  // Detached index constants shared by every loop below. They are declared
  // first so that they are destroyed last: C++ destroys locals in reverse
  // declaration order, so each OwningOpRef'd loop that uses these results is
  // destroyed before the constants themselves.
  OwningOpRef<arith::ConstantIndexOp> lb =
      b.create<arith::ConstantIndexOp>(loc, 0);
  OwningOpRef<arith::ConstantIndexOp> ub =
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> step =
      b.create<arith::ConstantIndexOp>(loc, 2);

  Value lbVal = lb->getResult();
  Value ubVal = ub->getResult();
  Value stepVal = step->getResult();

  // scf.for: built from a single lower bound, upper bound and step.
  OwningOpRef<scf::ForOp> forOp =
      b.create<scf::ForOp>(loc, lbVal, ubVal, stepVal);
  checkUnidimensional(forOp.get());

  // scf.forall with one-element bound/step lists, an empty ValueRange and no
  // (std::nullopt) trailing optional argument.
  OpFoldResult lbOfr = lbVal;
  OpFoldResult ubOfr = ubVal;
  OpFoldResult stepOfr = stepVal;
  OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>(lbOfr), ArrayRef<OpFoldResult>(ubOfr),
      ArrayRef<OpFoldResult>(stepOfr), ValueRange(), std::nullopt);
  checkUnidimensional(forallOp.get());

  // scf.parallel with one-element bound/step ranges and an empty trailing
  // ValueRange.
  OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
      loc, ValueRange(lbVal), ValueRange(ubVal), ValueRange(stepVal),
      ValueRange());
  checkUnidimensional(parallelOp.get());
}
119 
TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
  // Detached index constants reused for both dimensions of every loop below.
  // Declared before the loops so the loops (which use these results) are
  // destroyed first when the OwningOpRefs go out of scope.
  OwningOpRef<arith::ConstantIndexOp> lb =
      b.create<arith::ConstantIndexOp>(loc, 0);
  OwningOpRef<arith::ConstantIndexOp> ub =
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> step =
      b.create<arith::ConstantIndexOp>(loc, 2);

  Value lbVal = lb->getResult();
  Value ubVal = ub->getResult();
  Value stepVal = step->getResult();

  // Two-dimensional scf.forall: the same bound/step in each dimension.
  SmallVector<OpFoldResult> forallLbs = {lbVal, lbVal};
  SmallVector<OpFoldResult> forallUbs = {ubVal, ubVal};
  SmallVector<OpFoldResult> forallSteps = {stepVal, stepVal};
  OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>(forallLbs),
      ArrayRef<OpFoldResult>(forallUbs), ArrayRef<OpFoldResult>(forallSteps),
      ValueRange(), std::nullopt);
  checkMultidimensional(forallOp.get());

  // Two-dimensional scf.parallel built from the same values.
  SmallVector<Value> parallelLbs = {lbVal, lbVal};
  SmallVector<Value> parallelUbs = {ubVal, ubVal};
  SmallVector<Value> parallelSteps = {stepVal, stepVal};
  OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>(
      loc, ValueRange(parallelLbs), ValueRange(parallelUbs),
      ValueRange(parallelSteps), ValueRange());
  checkMultidimensional(parallelOp.get());
}
141