1 //===- TestReifyValueBounds.cpp - Test value bounds reification -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" 13 #include "mlir/Dialect/Affine/Transforms/Transforms.h" 14 #include "mlir/Dialect/Arith/Transforms/Transforms.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Tensor/IR/Tensor.h" 18 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Interfaces/FunctionInterfaces.h" 21 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 22 #include "mlir/Pass/Pass.h" 23 24 #define PASS_NAME "test-affine-reify-value-bounds" 25 26 using namespace mlir; 27 using namespace mlir::affine; 28 using mlir::presburger::BoundType; 29 30 namespace { 31 32 /// This pass applies the permutation on the first maximal perfect nest. 33 struct TestReifyValueBounds 34 : public PassWrapper<TestReifyValueBounds, 35 InterfacePass<FunctionOpInterface>> { 36 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds) 37 38 StringRef getArgument() const final { return PASS_NAME; } 39 StringRef getDescription() const final { 40 return "Tests ValueBoundsOpInterface with affine dialect reification"; 41 } 42 TestReifyValueBounds() = default; 43 TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){}; 44 45 void getDependentDialects(DialectRegistry ®istry) const override { 46 registry.insert<affine::AffineDialect, tensor::TensorDialect, 47 memref::MemRefDialect>(); 48 } 49 50 void runOnOperation() override; 51 52 private: 53 Option<bool> reifyToFuncArgs{ 54 *this, "reify-to-func-args", 55 llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)}; 56 57 Option<bool> useArithOps{*this, "use-arith-ops", 58 llvm::cl::desc("Reify with arith dialect ops"), 59 llvm::cl::init(false)}; 60 }; 61 62 } // namespace 63 64 static ValueBoundsConstraintSet::ComparisonOperator 65 invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) { 66 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT) 67 return ValueBoundsConstraintSet::ComparisonOperator::GE; 68 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE) 69 return ValueBoundsConstraintSet::ComparisonOperator::GT; 70 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT) 71 return ValueBoundsConstraintSet::ComparisonOperator::LE; 72 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE) 73 return ValueBoundsConstraintSet::ComparisonOperator::LT; 74 llvm_unreachable("unsupported comparison operator"); 75 } 76 77 /// Look for "test.reify_bound" ops in the input and replace their results with 78 /// the reified values. 79 static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, 80 bool reifyToFuncArgs, 81 bool useArithOps) { 82 IRRewriter rewriter(funcOp.getContext()); 83 WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) { 84 auto boundType = op.getBoundType(); 85 Value value = op.getVar(); 86 std::optional<int64_t> dim = op.getDim(); 87 bool constant = op.getConstant(); 88 bool scalable = op.getScalable(); 89 90 // Prepare stop condition. By default, reify in terms of the op's 91 // operands. No stop condition is used when a constant was requested. 92 std::function<bool(Value, std::optional<int64_t>, 93 ValueBoundsConstraintSet & cstr)> 94 stopCondition = [&](Value v, std::optional<int64_t> d, 95 ValueBoundsConstraintSet &cstr) { 96 // Reify in terms of SSA values that are different from `value`. 97 return v != value; 98 }; 99 if (reifyToFuncArgs) { 100 // Reify in terms of function block arguments. 101 stopCondition = [](Value v, std::optional<int64_t> d, 102 ValueBoundsConstraintSet &cstr) { 103 auto bbArg = dyn_cast<BlockArgument>(v); 104 if (!bbArg) 105 return false; 106 return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp()); 107 }; 108 } 109 110 // Reify value bound 111 rewriter.setInsertionPointAfter(op); 112 FailureOr<OpFoldResult> reified = failure(); 113 if (constant) { 114 auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound( 115 boundType, {value, dim}, /*stopCondition=*/nullptr); 116 if (succeeded(reifiedConst)) 117 reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst)); 118 } else if (scalable) { 119 auto loc = op->getLoc(); 120 auto reifiedScalable = 121 vector::ScalableValueBoundsConstraintSet::computeScalableBound( 122 value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType); 123 if (succeeded(reifiedScalable)) { 124 SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand; 125 if (reifiedScalable->map.getNumInputs() == 1) { 126 // The only possible input to the bound is vscale. 127 vscaleOperand.push_back(std::make_pair( 128 rewriter.create<vector::VectorScaleOp>(loc), std::nullopt)); 129 } 130 reified = affine::materializeComputedBound( 131 rewriter, loc, reifiedScalable->map, vscaleOperand); 132 } 133 } else { 134 if (useArithOps) { 135 reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType, 136 op.getVariable(), stopCondition); 137 } else { 138 reified = reifyValueBound(rewriter, op->getLoc(), boundType, 139 op.getVariable(), stopCondition); 140 } 141 } 142 if (failed(reified)) { 143 op->emitOpError("could not reify bound"); 144 return WalkResult::interrupt(); 145 } 146 147 // Replace the op with the reified bound. 148 if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) { 149 rewriter.replaceOp(op, val); 150 return WalkResult::skip(); 151 } 152 Value constOp = rewriter.create<arith::ConstantIndexOp>( 153 op->getLoc(), cast<IntegerAttr>(cast<Attribute>(*reified)).getInt()); 154 rewriter.replaceOp(op, constOp); 155 return WalkResult::skip(); 156 }); 157 return failure(result.wasInterrupted()); 158 } 159 160 /// Look for "test.compare" ops and emit errors/remarks. 161 static LogicalResult testEquality(FunctionOpInterface funcOp) { 162 IRRewriter rewriter(funcOp.getContext()); 163 WalkResult result = funcOp.walk([&](test::CompareOp op) { 164 auto cmpType = op.getComparisonOperator(); 165 if (op.getCompose()) { 166 if (cmpType != ValueBoundsConstraintSet::EQ) { 167 op->emitOpError( 168 "comparison operator must be EQ when 'composed' is specified"); 169 return WalkResult::interrupt(); 170 } 171 FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta( 172 op->getOperand(0), op->getOperand(1)); 173 if (failed(delta)) { 174 op->emitError("could not determine equality"); 175 } else if (*delta == 0) { 176 op->emitRemark("equal"); 177 } else { 178 op->emitRemark("different"); 179 } 180 return WalkResult::advance(); 181 } 182 183 auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) { 184 return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs()); 185 }; 186 if (compare(cmpType)) { 187 op->emitRemark("true"); 188 } else if (cmpType != ValueBoundsConstraintSet::EQ && 189 compare(invertComparisonOperator(cmpType))) { 190 op->emitRemark("false"); 191 } else if (cmpType == ValueBoundsConstraintSet::EQ && 192 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) || 193 compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) { 194 op->emitRemark("false"); 195 } else { 196 op->emitError("unknown"); 197 } 198 return WalkResult::advance(); 199 }); 200 return failure(result.wasInterrupted()); 201 } 202 203 void TestReifyValueBounds::runOnOperation() { 204 if (failed( 205 testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps))) 206 signalPassFailure(); 207 if (failed(testEquality(getOperation()))) 208 signalPassFailure(); 209 } 210 211 namespace mlir { 212 void registerTestAffineReifyValueBoundsPass() { 213 PassRegistration<TestReifyValueBounds>(); 214 } 215 } // namespace mlir 216