xref: /llvm-project/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (revision 35e89897a4086f5adbab10b4b90aa63ef5b35514)
18c885658SMatthias Springer //===- TestReifyValueBounds.cpp - Test value bounds reification -----------===//
28c885658SMatthias Springer //
38c885658SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48c885658SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
58c885658SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68c885658SMatthias Springer //
78c885658SMatthias Springer //===----------------------------------------------------------------------===//
88c885658SMatthias Springer 
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
238c885658SMatthias Springer 
248c885658SMatthias Springer #define PASS_NAME "test-affine-reify-value-bounds"
258c885658SMatthias Springer 
268c885658SMatthias Springer using namespace mlir;
274c48f016SMatthias Springer using namespace mlir::affine;
28041bc485SMatthias Springer using mlir::presburger::BoundType;
298c885658SMatthias Springer 
308c885658SMatthias Springer namespace {
318c885658SMatthias Springer 
328c885658SMatthias Springer /// This pass applies the permutation on the first maximal perfect nest.
338c885658SMatthias Springer struct TestReifyValueBounds
340aa831e0SKrzysztof Drewniak     : public PassWrapper<TestReifyValueBounds,
350aa831e0SKrzysztof Drewniak                          InterfacePass<FunctionOpInterface>> {
368c885658SMatthias Springer   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
378c885658SMatthias Springer 
388c885658SMatthias Springer   StringRef getArgument() const final { return PASS_NAME; }
398c885658SMatthias Springer   StringRef getDescription() const final {
408c885658SMatthias Springer     return "Tests ValueBoundsOpInterface with affine dialect reification";
418c885658SMatthias Springer   }
428c885658SMatthias Springer   TestReifyValueBounds() = default;
438c885658SMatthias Springer   TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
448c885658SMatthias Springer 
453fecf9a5SMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
464c48f016SMatthias Springer     registry.insert<affine::AffineDialect, tensor::TensorDialect,
474c48f016SMatthias Springer                     memref::MemRefDialect>();
483fecf9a5SMatthias Springer   }
493fecf9a5SMatthias Springer 
508c885658SMatthias Springer   void runOnOperation() override;
518c885658SMatthias Springer 
528c885658SMatthias Springer private:
538c885658SMatthias Springer   Option<bool> reifyToFuncArgs{
548c885658SMatthias Springer       *this, "reify-to-func-args",
558c885658SMatthias Springer       llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
56d5b63124SMatthias Springer 
57d5b63124SMatthias Springer   Option<bool> useArithOps{*this, "use-arith-ops",
58d5b63124SMatthias Springer                            llvm::cl::desc("Reify with arith dialect ops"),
59d5b63124SMatthias Springer                            llvm::cl::init(false)};
608c885658SMatthias Springer };
618c885658SMatthias Springer 
628c885658SMatthias Springer } // namespace
638c885658SMatthias Springer 
64297eca98SMatthias Springer static ValueBoundsConstraintSet::ComparisonOperator
65297eca98SMatthias Springer invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
66297eca98SMatthias Springer   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
67297eca98SMatthias Springer     return ValueBoundsConstraintSet::ComparisonOperator::GE;
68297eca98SMatthias Springer   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
69297eca98SMatthias Springer     return ValueBoundsConstraintSet::ComparisonOperator::GT;
70297eca98SMatthias Springer   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
71297eca98SMatthias Springer     return ValueBoundsConstraintSet::ComparisonOperator::LE;
72297eca98SMatthias Springer   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
73297eca98SMatthias Springer     return ValueBoundsConstraintSet::ComparisonOperator::LT;
74297eca98SMatthias Springer   llvm_unreachable("unsupported comparison operator");
75297eca98SMatthias Springer }
76297eca98SMatthias Springer 
778c885658SMatthias Springer /// Look for "test.reify_bound" ops in the input and replace their results with
788c885658SMatthias Springer /// the reified values.
790aa831e0SKrzysztof Drewniak static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
80d5b63124SMatthias Springer                                           bool reifyToFuncArgs,
81d5b63124SMatthias Springer                                           bool useArithOps) {
828c885658SMatthias Springer   IRRewriter rewriter(funcOp.getContext());
83f8d314f0SMatthias Springer   WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
84f8d314f0SMatthias Springer     auto boundType = op.getBoundType();
85f8d314f0SMatthias Springer     Value value = op.getVar();
86f8d314f0SMatthias Springer     std::optional<int64_t> dim = op.getDim();
87f8d314f0SMatthias Springer     bool constant = op.getConstant();
88f8d314f0SMatthias Springer     bool scalable = op.getScalable();
892861856bSBenjamin Maxwell 
900dc9087aSMatthias Springer     // Prepare stop condition. By default, reify in terms of the op's
910dc9087aSMatthias Springer     // operands. No stop condition is used when a constant was requested.
925e4a4438SMatthias Springer     std::function<bool(Value, std::optional<int64_t>,
935e4a4438SMatthias Springer                        ValueBoundsConstraintSet & cstr)>
945e4a4438SMatthias Springer         stopCondition = [&](Value v, std::optional<int64_t> d,
955e4a4438SMatthias Springer                             ValueBoundsConstraintSet &cstr) {
960dc9087aSMatthias Springer           // Reify in terms of SSA values that are different from `value`.
970dc9087aSMatthias Springer           return v != value;
980dc9087aSMatthias Springer         };
990dc9087aSMatthias Springer     if (reifyToFuncArgs) {
1008c885658SMatthias Springer       // Reify in terms of function block arguments.
1015e4a4438SMatthias Springer       stopCondition = [](Value v, std::optional<int64_t> d,
1025e4a4438SMatthias Springer                          ValueBoundsConstraintSet &cstr) {
1035550c821STres Popp         auto bbArg = dyn_cast<BlockArgument>(v);
1048c885658SMatthias Springer         if (!bbArg)
1058c885658SMatthias Springer           return false;
106f8d314f0SMatthias Springer         return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
1078c885658SMatthias Springer       };
1080dc9087aSMatthias Springer     }
1090dc9087aSMatthias Springer 
1100dc9087aSMatthias Springer     // Reify value bound
1110dc9087aSMatthias Springer     rewriter.setInsertionPointAfter(op);
1120dc9087aSMatthias Springer     FailureOr<OpFoldResult> reified = failure();
1130dc9087aSMatthias Springer     if (constant) {
1140dc9087aSMatthias Springer       auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
11540dd3aa9SMatthias Springer           boundType, {value, dim}, /*stopCondition=*/nullptr);
1160dc9087aSMatthias Springer       if (succeeded(reifiedConst))
117f8d314f0SMatthias Springer         reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
1182861856bSBenjamin Maxwell     } else if (scalable) {
1192861856bSBenjamin Maxwell       auto loc = op->getLoc();
1202861856bSBenjamin Maxwell       auto reifiedScalable =
1212861856bSBenjamin Maxwell           vector::ScalableValueBoundsConstraintSet::computeScalableBound(
122f8d314f0SMatthias Springer               value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType);
1232861856bSBenjamin Maxwell       if (succeeded(reifiedScalable)) {
124f8d314f0SMatthias Springer         SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
1252861856bSBenjamin Maxwell         if (reifiedScalable->map.getNumInputs() == 1) {
1262861856bSBenjamin Maxwell           // The only possible input to the bound is vscale.
1272861856bSBenjamin Maxwell           vscaleOperand.push_back(std::make_pair(
1282861856bSBenjamin Maxwell               rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
1292861856bSBenjamin Maxwell         }
1302861856bSBenjamin Maxwell         reified = affine::materializeComputedBound(
1312861856bSBenjamin Maxwell             rewriter, loc, reifiedScalable->map, vscaleOperand);
1322861856bSBenjamin Maxwell       }
1330dc9087aSMatthias Springer     } else {
134d5b63124SMatthias Springer       if (useArithOps) {
13540dd3aa9SMatthias Springer         reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
13640dd3aa9SMatthias Springer                                          op.getVariable(), stopCondition);
137d5b63124SMatthias Springer       } else {
13840dd3aa9SMatthias Springer         reified = reifyValueBound(rewriter, op->getLoc(), boundType,
13940dd3aa9SMatthias Springer                                   op.getVariable(), stopCondition);
1408c885658SMatthias Springer       }
141d5b63124SMatthias Springer     }
1428c885658SMatthias Springer     if (failed(reified)) {
1438c885658SMatthias Springer       op->emitOpError("could not reify bound");
1448c885658SMatthias Springer       return WalkResult::interrupt();
1458c885658SMatthias Springer     }
1468c885658SMatthias Springer 
1478c885658SMatthias Springer     // Replace the op with the reified bound.
14868f58812STres Popp     if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
1498c885658SMatthias Springer       rewriter.replaceOp(op, val);
1508c885658SMatthias Springer       return WalkResult::skip();
1518c885658SMatthias Springer     }
1528c885658SMatthias Springer     Value constOp = rewriter.create<arith::ConstantIndexOp>(
153*35e89897SKazu Hirata         op->getLoc(), cast<IntegerAttr>(cast<Attribute>(*reified)).getInt());
1548c885658SMatthias Springer     rewriter.replaceOp(op, constOp);
1558c885658SMatthias Springer     return WalkResult::skip();
1568c885658SMatthias Springer   });
1578c885658SMatthias Springer   return failure(result.wasInterrupted());
1588c885658SMatthias Springer }
1598c885658SMatthias Springer 
160297eca98SMatthias Springer /// Look for "test.compare" ops and emit errors/remarks.
1610aa831e0SKrzysztof Drewniak static LogicalResult testEquality(FunctionOpInterface funcOp) {
162ff930645SMatthias Springer   IRRewriter rewriter(funcOp.getContext());
163f8d314f0SMatthias Springer   WalkResult result = funcOp.walk([&](test::CompareOp op) {
164f8d314f0SMatthias Springer     auto cmpType = op.getComparisonOperator();
165f8d314f0SMatthias Springer     if (op.getCompose()) {
166297eca98SMatthias Springer       if (cmpType != ValueBoundsConstraintSet::EQ) {
167297eca98SMatthias Springer         op->emitOpError(
168297eca98SMatthias Springer             "comparison operator must be EQ when 'composed' is specified");
169297eca98SMatthias Springer         return WalkResult::interrupt();
170297eca98SMatthias Springer       }
17197c9f9a2SLei Zhang       FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
1723049ac44SLei Zhang           op->getOperand(0), op->getOperand(1));
17397c9f9a2SLei Zhang       if (failed(delta)) {
1743049ac44SLei Zhang         op->emitError("could not determine equality");
17597c9f9a2SLei Zhang       } else if (*delta == 0) {
1763049ac44SLei Zhang         op->emitRemark("equal");
177ebaf8d49SMatthias Springer       } else {
1783049ac44SLei Zhang         op->emitRemark("different");
179ebaf8d49SMatthias Springer       }
180297eca98SMatthias Springer       return WalkResult::advance();
181ff930645SMatthias Springer     }
182297eca98SMatthias Springer 
183297eca98SMatthias Springer     auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
18440dd3aa9SMatthias Springer       return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
185297eca98SMatthias Springer     };
186f8d314f0SMatthias Springer     if (compare(cmpType)) {
187297eca98SMatthias Springer       op->emitRemark("true");
188f8d314f0SMatthias Springer     } else if (cmpType != ValueBoundsConstraintSet::EQ &&
189f8d314f0SMatthias Springer                compare(invertComparisonOperator(cmpType))) {
190297eca98SMatthias Springer       op->emitRemark("false");
191f8d314f0SMatthias Springer     } else if (cmpType == ValueBoundsConstraintSet::EQ &&
192297eca98SMatthias Springer                (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
193297eca98SMatthias Springer                 compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
194297eca98SMatthias Springer       op->emitRemark("false");
195297eca98SMatthias Springer     } else {
196297eca98SMatthias Springer       op->emitError("unknown");
197ff930645SMatthias Springer     }
198ff930645SMatthias Springer     return WalkResult::advance();
199ff930645SMatthias Springer   });
200ff930645SMatthias Springer   return failure(result.wasInterrupted());
201ff930645SMatthias Springer }
202ff930645SMatthias Springer 
2038c885658SMatthias Springer void TestReifyValueBounds::runOnOperation() {
204d5b63124SMatthias Springer   if (failed(
205d5b63124SMatthias Springer           testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
2068c885658SMatthias Springer     signalPassFailure();
207ff930645SMatthias Springer   if (failed(testEquality(getOperation())))
208ff930645SMatthias Springer     signalPassFailure();
2098c885658SMatthias Springer }
2108c885658SMatthias Springer 
2118c885658SMatthias Springer namespace mlir {
2128c885658SMatthias Springer void registerTestAffineReifyValueBoundsPass() {
2138c885658SMatthias Springer   PassRegistration<TestReifyValueBounds>();
2148c885658SMatthias Springer }
2158c885658SMatthias Springer } // namespace mlir
216