xref: /llvm-project/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (revision 35e89897a4086f5adbab10b4b90aa63ef5b35514)
1 //===- TestReifyValueBounds.cpp - Test value bounds reification -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
23 
24 #define PASS_NAME "test-affine-reify-value-bounds"
25 
26 using namespace mlir;
27 using namespace mlir::affine;
28 using mlir::presburger::BoundType;
29 
30 namespace {
31 
32 /// This pass applies the permutation on the first maximal perfect nest.
33 struct TestReifyValueBounds
34     : public PassWrapper<TestReifyValueBounds,
35                          InterfacePass<FunctionOpInterface>> {
36   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
37 
38   StringRef getArgument() const final { return PASS_NAME; }
39   StringRef getDescription() const final {
40     return "Tests ValueBoundsOpInterface with affine dialect reification";
41   }
42   TestReifyValueBounds() = default;
43   TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
44 
45   void getDependentDialects(DialectRegistry &registry) const override {
46     registry.insert<affine::AffineDialect, tensor::TensorDialect,
47                     memref::MemRefDialect>();
48   }
49 
50   void runOnOperation() override;
51 
52 private:
53   Option<bool> reifyToFuncArgs{
54       *this, "reify-to-func-args",
55       llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
56 
57   Option<bool> useArithOps{*this, "use-arith-ops",
58                            llvm::cl::desc("Reify with arith dialect ops"),
59                            llvm::cl::init(false)};
60 };
61 
62 } // namespace
63 
64 static ValueBoundsConstraintSet::ComparisonOperator
65 invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
66   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
67     return ValueBoundsConstraintSet::ComparisonOperator::GE;
68   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
69     return ValueBoundsConstraintSet::ComparisonOperator::GT;
70   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
71     return ValueBoundsConstraintSet::ComparisonOperator::LE;
72   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
73     return ValueBoundsConstraintSet::ComparisonOperator::LT;
74   llvm_unreachable("unsupported comparison operator");
75 }
76 
77 /// Look for "test.reify_bound" ops in the input and replace their results with
78 /// the reified values.
79 static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
80                                           bool reifyToFuncArgs,
81                                           bool useArithOps) {
82   IRRewriter rewriter(funcOp.getContext());
83   WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
84     auto boundType = op.getBoundType();
85     Value value = op.getVar();
86     std::optional<int64_t> dim = op.getDim();
87     bool constant = op.getConstant();
88     bool scalable = op.getScalable();
89 
90     // Prepare stop condition. By default, reify in terms of the op's
91     // operands. No stop condition is used when a constant was requested.
92     std::function<bool(Value, std::optional<int64_t>,
93                        ValueBoundsConstraintSet & cstr)>
94         stopCondition = [&](Value v, std::optional<int64_t> d,
95                             ValueBoundsConstraintSet &cstr) {
96           // Reify in terms of SSA values that are different from `value`.
97           return v != value;
98         };
99     if (reifyToFuncArgs) {
100       // Reify in terms of function block arguments.
101       stopCondition = [](Value v, std::optional<int64_t> d,
102                          ValueBoundsConstraintSet &cstr) {
103         auto bbArg = dyn_cast<BlockArgument>(v);
104         if (!bbArg)
105           return false;
106         return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
107       };
108     }
109 
110     // Reify value bound
111     rewriter.setInsertionPointAfter(op);
112     FailureOr<OpFoldResult> reified = failure();
113     if (constant) {
114       auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
115           boundType, {value, dim}, /*stopCondition=*/nullptr);
116       if (succeeded(reifiedConst))
117         reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
118     } else if (scalable) {
119       auto loc = op->getLoc();
120       auto reifiedScalable =
121           vector::ScalableValueBoundsConstraintSet::computeScalableBound(
122               value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType);
123       if (succeeded(reifiedScalable)) {
124         SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
125         if (reifiedScalable->map.getNumInputs() == 1) {
126           // The only possible input to the bound is vscale.
127           vscaleOperand.push_back(std::make_pair(
128               rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
129         }
130         reified = affine::materializeComputedBound(
131             rewriter, loc, reifiedScalable->map, vscaleOperand);
132       }
133     } else {
134       if (useArithOps) {
135         reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
136                                          op.getVariable(), stopCondition);
137       } else {
138         reified = reifyValueBound(rewriter, op->getLoc(), boundType,
139                                   op.getVariable(), stopCondition);
140       }
141     }
142     if (failed(reified)) {
143       op->emitOpError("could not reify bound");
144       return WalkResult::interrupt();
145     }
146 
147     // Replace the op with the reified bound.
148     if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
149       rewriter.replaceOp(op, val);
150       return WalkResult::skip();
151     }
152     Value constOp = rewriter.create<arith::ConstantIndexOp>(
153         op->getLoc(), cast<IntegerAttr>(cast<Attribute>(*reified)).getInt());
154     rewriter.replaceOp(op, constOp);
155     return WalkResult::skip();
156   });
157   return failure(result.wasInterrupted());
158 }
159 
160 /// Look for "test.compare" ops and emit errors/remarks.
161 static LogicalResult testEquality(FunctionOpInterface funcOp) {
162   IRRewriter rewriter(funcOp.getContext());
163   WalkResult result = funcOp.walk([&](test::CompareOp op) {
164     auto cmpType = op.getComparisonOperator();
165     if (op.getCompose()) {
166       if (cmpType != ValueBoundsConstraintSet::EQ) {
167         op->emitOpError(
168             "comparison operator must be EQ when 'composed' is specified");
169         return WalkResult::interrupt();
170       }
171       FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
172           op->getOperand(0), op->getOperand(1));
173       if (failed(delta)) {
174         op->emitError("could not determine equality");
175       } else if (*delta == 0) {
176         op->emitRemark("equal");
177       } else {
178         op->emitRemark("different");
179       }
180       return WalkResult::advance();
181     }
182 
183     auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
184       return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
185     };
186     if (compare(cmpType)) {
187       op->emitRemark("true");
188     } else if (cmpType != ValueBoundsConstraintSet::EQ &&
189                compare(invertComparisonOperator(cmpType))) {
190       op->emitRemark("false");
191     } else if (cmpType == ValueBoundsConstraintSet::EQ &&
192                (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
193                 compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
194       op->emitRemark("false");
195     } else {
196       op->emitError("unknown");
197     }
198     return WalkResult::advance();
199   });
200   return failure(result.wasInterrupted());
201 }
202 
203 void TestReifyValueBounds::runOnOperation() {
204   if (failed(
205           testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
206     signalPassFailure();
207   if (failed(testEquality(getOperation())))
208     signalPassFailure();
209 }
210 
211 namespace mlir {
212 void registerTestAffineReifyValueBoundsPass() {
213   PassRegistration<TestReifyValueBounds>();
214 }
215 } // namespace mlir
216