//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" #include "gtest/gtest.h" using namespace mlir; using namespace mlir::scf; //===----------------------------------------------------------------------===// // Test Fixture //===----------------------------------------------------------------------===// class SCFLoopLikeTest : public ::testing::Test { protected: SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) { context.loadDialect(); } void checkUnidimensional(LoopLikeOpInterface loopLikeOp) { std::optional maybeSingleLb = loopLikeOp.getSingleLowerBound(); EXPECT_TRUE(maybeSingleLb.has_value()); std::optional maybeSingleUb = loopLikeOp.getSingleUpperBound(); EXPECT_TRUE(maybeSingleUb.has_value()); std::optional maybeSingleStep = loopLikeOp.getSingleStep(); EXPECT_TRUE(maybeSingleStep.has_value()); std::optional maybeSingleIndVar = loopLikeOp.getSingleInductionVar(); EXPECT_TRUE(maybeSingleIndVar.has_value()); std::optional> maybeLb = loopLikeOp.getLoopLowerBounds(); ASSERT_TRUE(maybeLb.has_value()); EXPECT_EQ((*maybeLb).size(), 1u); std::optional> maybeUb = loopLikeOp.getLoopUpperBounds(); ASSERT_TRUE(maybeUb.has_value()); EXPECT_EQ((*maybeUb).size(), 1u); std::optional> maybeStep = loopLikeOp.getLoopSteps(); ASSERT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 1u); std::optional> maybeInductionVars = loopLikeOp.getLoopInductionVars(); ASSERT_TRUE(maybeInductionVars.has_value()); EXPECT_EQ((*maybeInductionVars).size(), 1u); } void checkMultidimensional(LoopLikeOpInterface loopLikeOp) { std::optional maybeSingleLb = loopLikeOp.getSingleLowerBound(); EXPECT_FALSE(maybeSingleLb.has_value()); std::optional maybeSingleUb = loopLikeOp.getSingleUpperBound(); EXPECT_FALSE(maybeSingleUb.has_value()); std::optional maybeSingleStep = loopLikeOp.getSingleStep(); EXPECT_FALSE(maybeSingleStep.has_value()); std::optional maybeSingleIndVar = loopLikeOp.getSingleInductionVar(); EXPECT_FALSE(maybeSingleIndVar.has_value()); std::optional> maybeLb = loopLikeOp.getLoopLowerBounds(); ASSERT_TRUE(maybeLb.has_value()); EXPECT_EQ((*maybeLb).size(), 2u); std::optional> maybeUb = loopLikeOp.getLoopUpperBounds(); ASSERT_TRUE(maybeUb.has_value()); EXPECT_EQ((*maybeUb).size(), 2u); std::optional> maybeStep = loopLikeOp.getLoopSteps(); ASSERT_TRUE(maybeStep.has_value()); EXPECT_EQ((*maybeStep).size(), 2u); std::optional> maybeInductionVars = loopLikeOp.getLoopInductionVars(); ASSERT_TRUE(maybeInductionVars.has_value()); EXPECT_EQ((*maybeInductionVars).size(), 2u); } MLIRContext context; OpBuilder b; Location loc; }; TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) { OwningOpRef lb = b.create(loc, 0); OwningOpRef ub = b.create(loc, 10); OwningOpRef step = b.create(loc, 2); OwningOpRef forOp = b.create(loc, lb.get(), ub.get(), step.get()); checkUnidimensional(forOp.get()); OwningOpRef forallOp = b.create( loc, ArrayRef(lb->getResult()), ArrayRef(ub->getResult()), ArrayRef(step->getResult()), ValueRange(), std::nullopt); checkUnidimensional(forallOp.get()); OwningOpRef parallelOp = b.create( loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()), ValueRange(step->getResult()), ValueRange()); checkUnidimensional(parallelOp.get()); } TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) { OwningOpRef lb = b.create(loc, 0); OwningOpRef ub = b.create(loc, 10); OwningOpRef step = b.create(loc, 2); OwningOpRef forallOp = b.create( loc, ArrayRef({lb->getResult(), lb->getResult()}), ArrayRef({ub->getResult(), ub->getResult()}), ArrayRef({step->getResult(), step->getResult()}), ValueRange(), std::nullopt); checkMultidimensional(forallOp.get()); OwningOpRef parallelOp = b.create( loc, ValueRange({lb->getResult(), lb->getResult()}), ValueRange({ub->getResult(), ub->getResult()}), ValueRange({step->getResult(), step->getResult()}), ValueRange()); checkMultidimensional(parallelOp.get()); }