1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/SCF/IR/SCF.h" 12 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 13 14 using namespace mlir; 15 16 namespace mlir { 17 namespace scf { 18 namespace { 19 20 struct ForOpInterface 21 : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> { 22 23 /// Populate bounds of values/dimensions for iter_args/OpResults. If the 24 /// value/dimension size does not change in an iteration, we can deduce that 25 /// it the same as the initial value/dimension. 26 /// 27 /// Example 1: 28 /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> { 29 /// ... 30 /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32> 31 /// scf.yield %1 : tensor<?xf32> 32 /// } 33 /// --> bound(%0)[0] == bound(%t)[0] 34 /// --> bound(%arg0)[0] == bound(%t)[0] 35 /// 36 /// Example 2: 37 /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> { 38 /// %sz = tensor.dim %arg0 : tensor<?xf32> 39 /// %incr = arith.addi %sz, %c1 : index 40 /// %1 = tensor.empty(%incr) : tensor<?xf32> 41 /// scf.yield %1 : tensor<?xf32> 42 /// } 43 /// --> The yielded tensor dimension size changes with each iteration. Such 44 /// loops are not supported and no constraints are added. 45 static void populateIterArgBounds(scf::ForOp forOp, Value value, 46 std::optional<int64_t> dim, 47 ValueBoundsConstraintSet &cstr) { 48 // `value` is an iter_arg or an OpResult. 49 int64_t iterArgIdx; 50 if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) { 51 iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars(); 52 } else { 53 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber(); 54 } 55 56 Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator()) 57 .getOperand(iterArgIdx); 58 Value iterArg = forOp.getRegionIterArg(iterArgIdx); 59 Value initArg = forOp.getInitArgs()[iterArgIdx]; 60 61 // An EQ constraint can be added if the yielded value (dimension size) 62 // equals the corresponding block argument (dimension size). 63 if (cstr.populateAndCompare( 64 /*lhs=*/{yieldedValue, dim}, 65 ValueBoundsConstraintSet::ComparisonOperator::EQ, 66 /*rhs=*/{iterArg, dim})) { 67 if (dim.has_value()) { 68 cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); 69 } else { 70 cstr.bound(value) == cstr.getExpr(initArg); 71 } 72 } 73 74 if (dim.has_value() || isa<BlockArgument>(value)) 75 return; 76 77 // `value` is result of `forOp`, we can prove that: 78 // %result == %init_arg + trip_count * (%yielded_value - %iter_arg). 79 // Where trip_count is (ub - lb) / step. 80 AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound()); 81 AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound()); 82 AffineExpr stepExpr = cstr.getExpr(forOp.getStep()); 83 AffineExpr tripCountExpr = 84 AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step 85 AffineExpr oneIterAdvanceExpr = 86 cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg); 87 cstr.bound(value) == 88 cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr); 89 } 90 91 void populateBoundsForIndexValue(Operation *op, Value value, 92 ValueBoundsConstraintSet &cstr) const { 93 auto forOp = cast<ForOp>(op); 94 95 if (value == forOp.getInductionVar()) { 96 // TODO: Take into account step size. 97 cstr.bound(value) >= forOp.getLowerBound(); 98 cstr.bound(value) < forOp.getUpperBound(); 99 return; 100 } 101 102 // Handle iter_args and OpResults. 103 populateIterArgBounds(forOp, value, std::nullopt, cstr); 104 } 105 106 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 107 ValueBoundsConstraintSet &cstr) const { 108 auto forOp = cast<ForOp>(op); 109 // Handle iter_args and OpResults. 110 populateIterArgBounds(forOp, value, dim, cstr); 111 } 112 }; 113 114 struct ForallOpInterface 115 : public ValueBoundsOpInterface::ExternalModel<ForallOpInterface, 116 ForallOp> { 117 118 void populateBoundsForIndexValue(Operation *op, Value value, 119 ValueBoundsConstraintSet &cstr) const { 120 auto forallOp = cast<ForallOp>(op); 121 122 // Index values should be induction variables, since the semantics of 123 // tensor::ParallelInsertSliceOp requires forall outputs to be ranked 124 // tensors. 125 auto blockArg = cast<BlockArgument>(value); 126 assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() && 127 "expected index value to be an induction var"); 128 int64_t idx = blockArg.getArgNumber(); 129 // TODO: Take into account step size. 130 AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]); 131 AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]); 132 cstr.bound(value) >= lb; 133 cstr.bound(value) < ub; 134 } 135 136 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 137 ValueBoundsConstraintSet &cstr) const { 138 auto forallOp = cast<ForallOp>(op); 139 140 // `value` is an iter_arg or an OpResult. 141 int64_t iterArgIdx; 142 if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) { 143 iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size(); 144 } else { 145 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber(); 146 } 147 148 // The forall results and output arguments have the same sizes as the output 149 // operands. 150 Value outputOperand = forallOp.getOutputs()[iterArgIdx]; 151 cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim); 152 } 153 }; 154 155 struct IfOpInterface 156 : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> { 157 158 static void populateBounds(scf::IfOp ifOp, Value value, 159 std::optional<int64_t> dim, 160 ValueBoundsConstraintSet &cstr) { 161 unsigned int resultNum = cast<OpResult>(value).getResultNumber(); 162 Value thenValue = ifOp.thenYield().getResults()[resultNum]; 163 Value elseValue = ifOp.elseYield().getResults()[resultNum]; 164 165 auto boundsBuilder = cstr.bound(value); 166 if (dim) 167 boundsBuilder[*dim]; 168 169 // Compare yielded values. 170 // If thenValue <= elseValue: 171 // * result <= elseValue 172 // * result >= thenValue 173 if (cstr.populateAndCompare( 174 /*lhs=*/{thenValue, dim}, 175 ValueBoundsConstraintSet::ComparisonOperator::LE, 176 /*rhs=*/{elseValue, dim})) { 177 if (dim) { 178 cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); 179 cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); 180 } else { 181 cstr.bound(value) >= thenValue; 182 cstr.bound(value) <= elseValue; 183 } 184 } 185 // If elseValue <= thenValue: 186 // * result <= thenValue 187 // * result >= elseValue 188 if (cstr.populateAndCompare( 189 /*lhs=*/{elseValue, dim}, 190 ValueBoundsConstraintSet::ComparisonOperator::LE, 191 /*rhs=*/{thenValue, dim})) { 192 if (dim) { 193 cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); 194 cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); 195 } else { 196 cstr.bound(value) >= elseValue; 197 cstr.bound(value) <= thenValue; 198 } 199 } 200 } 201 202 void populateBoundsForIndexValue(Operation *op, Value value, 203 ValueBoundsConstraintSet &cstr) const { 204 populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr); 205 } 206 207 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 208 ValueBoundsConstraintSet &cstr) const { 209 populateBounds(cast<IfOp>(op), value, dim, cstr); 210 } 211 }; 212 213 } // namespace 214 } // namespace scf 215 } // namespace mlir 216 217 void mlir::scf::registerValueBoundsOpInterfaceExternalModels( 218 DialectRegistry ®istry) { 219 registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { 220 scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx); 221 scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx); 222 scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx); 223 }); 224 } 225