18c885658SMatthias Springer //===- TestReifyValueBounds.cpp - Test value bounds reification -----------===// 28c885658SMatthias Springer // 38c885658SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48c885658SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 58c885658SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68c885658SMatthias Springer // 78c885658SMatthias Springer //===----------------------------------------------------------------------===// 88c885658SMatthias Springer 9f8d314f0SMatthias Springer #include "TestDialect.h" 10e95e94adSJeff Niu #include "TestOps.h" 118c885658SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 12ebaf8d49SMatthias Springer #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" 138c885658SMatthias Springer #include "mlir/Dialect/Affine/Transforms/Transforms.h" 14d5b63124SMatthias Springer #include "mlir/Dialect/Arith/Transforms/Transforms.h" 158c885658SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 16c3f5fd76SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 17c3f5fd76SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 182861856bSBenjamin Maxwell #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" 198c885658SMatthias Springer #include "mlir/IR/PatternMatch.h" 200aa831e0SKrzysztof Drewniak #include "mlir/Interfaces/FunctionInterfaces.h" 218c885658SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 228c885658SMatthias Springer #include "mlir/Pass/Pass.h" 238c885658SMatthias Springer 248c885658SMatthias Springer #define PASS_NAME "test-affine-reify-value-bounds" 258c885658SMatthias Springer 268c885658SMatthias Springer using namespace mlir; 274c48f016SMatthias Springer using namespace mlir::affine; 28041bc485SMatthias Springer using mlir::presburger::BoundType; 298c885658SMatthias Springer 308c885658SMatthias Springer namespace { 318c885658SMatthias Springer 328c885658SMatthias Springer /// This pass applies the permutation on the first maximal perfect nest. 338c885658SMatthias Springer struct TestReifyValueBounds 340aa831e0SKrzysztof Drewniak : public PassWrapper<TestReifyValueBounds, 350aa831e0SKrzysztof Drewniak InterfacePass<FunctionOpInterface>> { 368c885658SMatthias Springer MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds) 378c885658SMatthias Springer 388c885658SMatthias Springer StringRef getArgument() const final { return PASS_NAME; } 398c885658SMatthias Springer StringRef getDescription() const final { 408c885658SMatthias Springer return "Tests ValueBoundsOpInterface with affine dialect reification"; 418c885658SMatthias Springer } 428c885658SMatthias Springer TestReifyValueBounds() = default; 438c885658SMatthias Springer TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){}; 448c885658SMatthias Springer 453fecf9a5SMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override { 464c48f016SMatthias Springer registry.insert<affine::AffineDialect, tensor::TensorDialect, 474c48f016SMatthias Springer memref::MemRefDialect>(); 483fecf9a5SMatthias Springer } 493fecf9a5SMatthias Springer 508c885658SMatthias Springer void runOnOperation() override; 518c885658SMatthias Springer 528c885658SMatthias Springer private: 538c885658SMatthias Springer Option<bool> reifyToFuncArgs{ 548c885658SMatthias Springer *this, "reify-to-func-args", 558c885658SMatthias Springer llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)}; 56d5b63124SMatthias Springer 57d5b63124SMatthias Springer Option<bool> useArithOps{*this, "use-arith-ops", 58d5b63124SMatthias Springer llvm::cl::desc("Reify with arith dialect ops"), 59d5b63124SMatthias Springer llvm::cl::init(false)}; 608c885658SMatthias Springer }; 618c885658SMatthias Springer 628c885658SMatthias Springer } // namespace 638c885658SMatthias Springer 64297eca98SMatthias Springer static ValueBoundsConstraintSet::ComparisonOperator 65297eca98SMatthias Springer invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) { 66297eca98SMatthias Springer if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT) 67297eca98SMatthias Springer return ValueBoundsConstraintSet::ComparisonOperator::GE; 68297eca98SMatthias Springer if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE) 69297eca98SMatthias Springer return ValueBoundsConstraintSet::ComparisonOperator::GT; 70297eca98SMatthias Springer if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT) 71297eca98SMatthias Springer return ValueBoundsConstraintSet::ComparisonOperator::LE; 72297eca98SMatthias Springer if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE) 73297eca98SMatthias Springer return ValueBoundsConstraintSet::ComparisonOperator::LT; 74297eca98SMatthias Springer llvm_unreachable("unsupported comparison operator"); 75297eca98SMatthias Springer } 76297eca98SMatthias Springer 778c885658SMatthias Springer /// Look for "test.reify_bound" ops in the input and replace their results with 788c885658SMatthias Springer /// the reified values. 790aa831e0SKrzysztof Drewniak static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, 80d5b63124SMatthias Springer bool reifyToFuncArgs, 81d5b63124SMatthias Springer bool useArithOps) { 828c885658SMatthias Springer IRRewriter rewriter(funcOp.getContext()); 83f8d314f0SMatthias Springer WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) { 84f8d314f0SMatthias Springer auto boundType = op.getBoundType(); 85f8d314f0SMatthias Springer Value value = op.getVar(); 86f8d314f0SMatthias Springer std::optional<int64_t> dim = op.getDim(); 87f8d314f0SMatthias Springer bool constant = op.getConstant(); 88f8d314f0SMatthias Springer bool scalable = op.getScalable(); 892861856bSBenjamin Maxwell 900dc9087aSMatthias Springer // Prepare stop condition. By default, reify in terms of the op's 910dc9087aSMatthias Springer // operands. No stop condition is used when a constant was requested. 925e4a4438SMatthias Springer std::function<bool(Value, std::optional<int64_t>, 935e4a4438SMatthias Springer ValueBoundsConstraintSet & cstr)> 945e4a4438SMatthias Springer stopCondition = [&](Value v, std::optional<int64_t> d, 955e4a4438SMatthias Springer ValueBoundsConstraintSet &cstr) { 960dc9087aSMatthias Springer // Reify in terms of SSA values that are different from `value`. 970dc9087aSMatthias Springer return v != value; 980dc9087aSMatthias Springer }; 990dc9087aSMatthias Springer if (reifyToFuncArgs) { 1008c885658SMatthias Springer // Reify in terms of function block arguments. 1015e4a4438SMatthias Springer stopCondition = [](Value v, std::optional<int64_t> d, 1025e4a4438SMatthias Springer ValueBoundsConstraintSet &cstr) { 1035550c821STres Popp auto bbArg = dyn_cast<BlockArgument>(v); 1048c885658SMatthias Springer if (!bbArg) 1058c885658SMatthias Springer return false; 106f8d314f0SMatthias Springer return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp()); 1078c885658SMatthias Springer }; 1080dc9087aSMatthias Springer } 1090dc9087aSMatthias Springer 1100dc9087aSMatthias Springer // Reify value bound 1110dc9087aSMatthias Springer rewriter.setInsertionPointAfter(op); 1120dc9087aSMatthias Springer FailureOr<OpFoldResult> reified = failure(); 1130dc9087aSMatthias Springer if (constant) { 1140dc9087aSMatthias Springer auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound( 11540dd3aa9SMatthias Springer boundType, {value, dim}, /*stopCondition=*/nullptr); 1160dc9087aSMatthias Springer if (succeeded(reifiedConst)) 117f8d314f0SMatthias Springer reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst)); 1182861856bSBenjamin Maxwell } else if (scalable) { 1192861856bSBenjamin Maxwell auto loc = op->getLoc(); 1202861856bSBenjamin Maxwell auto reifiedScalable = 1212861856bSBenjamin Maxwell vector::ScalableValueBoundsConstraintSet::computeScalableBound( 122f8d314f0SMatthias Springer value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType); 1232861856bSBenjamin Maxwell if (succeeded(reifiedScalable)) { 124f8d314f0SMatthias Springer SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand; 1252861856bSBenjamin Maxwell if (reifiedScalable->map.getNumInputs() == 1) { 1262861856bSBenjamin Maxwell // The only possible input to the bound is vscale. 1272861856bSBenjamin Maxwell vscaleOperand.push_back(std::make_pair( 1282861856bSBenjamin Maxwell rewriter.create<vector::VectorScaleOp>(loc), std::nullopt)); 1292861856bSBenjamin Maxwell } 1302861856bSBenjamin Maxwell reified = affine::materializeComputedBound( 1312861856bSBenjamin Maxwell rewriter, loc, reifiedScalable->map, vscaleOperand); 1322861856bSBenjamin Maxwell } 1330dc9087aSMatthias Springer } else { 134d5b63124SMatthias Springer if (useArithOps) { 13540dd3aa9SMatthias Springer reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType, 13640dd3aa9SMatthias Springer op.getVariable(), stopCondition); 137d5b63124SMatthias Springer } else { 13840dd3aa9SMatthias Springer reified = reifyValueBound(rewriter, op->getLoc(), boundType, 13940dd3aa9SMatthias Springer op.getVariable(), stopCondition); 1408c885658SMatthias Springer } 141d5b63124SMatthias Springer } 1428c885658SMatthias Springer if (failed(reified)) { 1438c885658SMatthias Springer op->emitOpError("could not reify bound"); 1448c885658SMatthias Springer return WalkResult::interrupt(); 1458c885658SMatthias Springer } 1468c885658SMatthias Springer 1478c885658SMatthias Springer // Replace the op with the reified bound. 14868f58812STres Popp if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) { 1498c885658SMatthias Springer rewriter.replaceOp(op, val); 1508c885658SMatthias Springer return WalkResult::skip(); 1518c885658SMatthias Springer } 1528c885658SMatthias Springer Value constOp = rewriter.create<arith::ConstantIndexOp>( 153*35e89897SKazu Hirata op->getLoc(), cast<IntegerAttr>(cast<Attribute>(*reified)).getInt()); 1548c885658SMatthias Springer rewriter.replaceOp(op, constOp); 1558c885658SMatthias Springer return WalkResult::skip(); 1568c885658SMatthias Springer }); 1578c885658SMatthias Springer return failure(result.wasInterrupted()); 1588c885658SMatthias Springer } 1598c885658SMatthias Springer 160297eca98SMatthias Springer /// Look for "test.compare" ops and emit errors/remarks. 1610aa831e0SKrzysztof Drewniak static LogicalResult testEquality(FunctionOpInterface funcOp) { 162ff930645SMatthias Springer IRRewriter rewriter(funcOp.getContext()); 163f8d314f0SMatthias Springer WalkResult result = funcOp.walk([&](test::CompareOp op) { 164f8d314f0SMatthias Springer auto cmpType = op.getComparisonOperator(); 165f8d314f0SMatthias Springer if (op.getCompose()) { 166297eca98SMatthias Springer if (cmpType != ValueBoundsConstraintSet::EQ) { 167297eca98SMatthias Springer op->emitOpError( 168297eca98SMatthias Springer "comparison operator must be EQ when 'composed' is specified"); 169297eca98SMatthias Springer return WalkResult::interrupt(); 170297eca98SMatthias Springer } 17197c9f9a2SLei Zhang FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta( 1723049ac44SLei Zhang op->getOperand(0), op->getOperand(1)); 17397c9f9a2SLei Zhang if (failed(delta)) { 1743049ac44SLei Zhang op->emitError("could not determine equality"); 17597c9f9a2SLei Zhang } else if (*delta == 0) { 1763049ac44SLei Zhang op->emitRemark("equal"); 177ebaf8d49SMatthias Springer } else { 1783049ac44SLei Zhang op->emitRemark("different"); 179ebaf8d49SMatthias Springer } 180297eca98SMatthias Springer return WalkResult::advance(); 181ff930645SMatthias Springer } 182297eca98SMatthias Springer 183297eca98SMatthias Springer auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) { 18440dd3aa9SMatthias Springer return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs()); 185297eca98SMatthias Springer }; 186f8d314f0SMatthias Springer if (compare(cmpType)) { 187297eca98SMatthias Springer op->emitRemark("true"); 188f8d314f0SMatthias Springer } else if (cmpType != ValueBoundsConstraintSet::EQ && 189f8d314f0SMatthias Springer compare(invertComparisonOperator(cmpType))) { 190297eca98SMatthias Springer op->emitRemark("false"); 191f8d314f0SMatthias Springer } else if (cmpType == ValueBoundsConstraintSet::EQ && 192297eca98SMatthias Springer (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) || 193297eca98SMatthias Springer compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) { 194297eca98SMatthias Springer op->emitRemark("false"); 195297eca98SMatthias Springer } else { 196297eca98SMatthias Springer op->emitError("unknown"); 197ff930645SMatthias Springer } 198ff930645SMatthias Springer return WalkResult::advance(); 199ff930645SMatthias Springer }); 200ff930645SMatthias Springer return failure(result.wasInterrupted()); 201ff930645SMatthias Springer } 202ff930645SMatthias Springer 2038c885658SMatthias Springer void TestReifyValueBounds::runOnOperation() { 204d5b63124SMatthias Springer if (failed( 205d5b63124SMatthias Springer testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps))) 2068c885658SMatthias Springer signalPassFailure(); 207ff930645SMatthias Springer if (failed(testEquality(getOperation()))) 208ff930645SMatthias Springer signalPassFailure(); 2098c885658SMatthias Springer } 2108c885658SMatthias Springer 2118c885658SMatthias Springer namespace mlir { 2128c885658SMatthias Springer void registerTestAffineReifyValueBoundsPass() { 2138c885658SMatthias Springer PassRegistration<TestReifyValueBounds>(); 2148c885658SMatthias Springer } 2158c885658SMatthias Springer } // namespace mlir 216