1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 13 14 using namespace mlir; 15 16 namespace mlir { 17 namespace arith { 18 namespace { 19 20 struct AddIOpInterface 21 : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> { 22 void populateBoundsForIndexValue(Operation *op, Value value, 23 ValueBoundsConstraintSet &cstr) const { 24 auto addIOp = cast<AddIOp>(op); 25 assert(value == addIOp.getResult() && "invalid value"); 26 27 // Note: `getExpr` has a side effect: it may add a new column to the 28 // constraint system. The evaluation order of addition operands is 29 // unspecified in C++. To make sure that all compilers produce the exact 30 // same results (that can be FileCheck'd), it is important that `getExpr` 31 // is called first and assigned to temporary variables, and the addition 32 // is performed afterwards. 33 AffineExpr lhs = cstr.getExpr(addIOp.getLhs()); 34 AffineExpr rhs = cstr.getExpr(addIOp.getRhs()); 35 cstr.bound(value) == lhs + rhs; 36 } 37 }; 38 39 struct ConstantOpInterface 40 : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface, 41 ConstantOp> { 42 void populateBoundsForIndexValue(Operation *op, Value value, 43 ValueBoundsConstraintSet &cstr) const { 44 auto constantOp = cast<ConstantOp>(op); 45 assert(value == constantOp.getResult() && "invalid value"); 46 47 if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue())) 48 cstr.bound(value) == attr.getInt(); 49 } 50 }; 51 52 struct SubIOpInterface 53 : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> { 54 void populateBoundsForIndexValue(Operation *op, Value value, 55 ValueBoundsConstraintSet &cstr) const { 56 auto subIOp = cast<SubIOp>(op); 57 assert(value == subIOp.getResult() && "invalid value"); 58 59 AffineExpr lhs = cstr.getExpr(subIOp.getLhs()); 60 AffineExpr rhs = cstr.getExpr(subIOp.getRhs()); 61 cstr.bound(value) == lhs - rhs; 62 } 63 }; 64 65 struct MulIOpInterface 66 : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> { 67 void populateBoundsForIndexValue(Operation *op, Value value, 68 ValueBoundsConstraintSet &cstr) const { 69 auto mulIOp = cast<MulIOp>(op); 70 assert(value == mulIOp.getResult() && "invalid value"); 71 72 AffineExpr lhs = cstr.getExpr(mulIOp.getLhs()); 73 AffineExpr rhs = cstr.getExpr(mulIOp.getRhs()); 74 cstr.bound(value) == lhs *rhs; 75 } 76 }; 77 78 struct SelectOpInterface 79 : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface, 80 SelectOp> { 81 82 static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim, 83 ValueBoundsConstraintSet &cstr) { 84 Value value = selectOp.getResult(); 85 Value condition = selectOp.getCondition(); 86 Value trueValue = selectOp.getTrueValue(); 87 Value falseValue = selectOp.getFalseValue(); 88 89 if (isa<ShapedType>(condition.getType())) { 90 // If the condition is a shaped type, the condition is applied 91 // element-wise. All three operands must have the same shape. 92 cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim); 93 cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim); 94 cstr.bound(value)[*dim] == cstr.getExpr(condition, dim); 95 return; 96 } 97 98 // Populate constraints for the true/false values (and all values on the 99 // backward slice, as long as the current stop condition is not satisfied). 100 cstr.populateConstraints(trueValue, dim); 101 cstr.populateConstraints(falseValue, dim); 102 auto boundsBuilder = cstr.bound(value); 103 if (dim) 104 boundsBuilder[*dim]; 105 106 // Compare yielded values. 107 // If trueValue <= falseValue: 108 // * result <= falseValue 109 // * result >= trueValue 110 if (cstr.populateAndCompare( 111 /*lhs=*/{trueValue, dim}, 112 ValueBoundsConstraintSet::ComparisonOperator::LE, 113 /*rhs=*/{falseValue, dim})) { 114 if (dim) { 115 cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim); 116 cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim); 117 } else { 118 cstr.bound(value) >= trueValue; 119 cstr.bound(value) <= falseValue; 120 } 121 } 122 // If falseValue <= trueValue: 123 // * result <= trueValue 124 // * result >= falseValue 125 if (cstr.populateAndCompare( 126 /*lhs=*/{falseValue, dim}, 127 ValueBoundsConstraintSet::ComparisonOperator::LE, 128 /*rhs=*/{trueValue, dim})) { 129 if (dim) { 130 cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim); 131 cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim); 132 } else { 133 cstr.bound(value) >= falseValue; 134 cstr.bound(value) <= trueValue; 135 } 136 } 137 } 138 139 void populateBoundsForIndexValue(Operation *op, Value value, 140 ValueBoundsConstraintSet &cstr) const { 141 populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr); 142 } 143 144 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, 145 ValueBoundsConstraintSet &cstr) const { 146 populateBounds(cast<SelectOp>(op), dim, cstr); 147 } 148 }; 149 } // namespace 150 } // namespace arith 151 } // namespace mlir 152 153 void mlir::arith::registerValueBoundsOpInterfaceExternalModels( 154 DialectRegistry ®istry) { 155 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { 156 arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx); 157 arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx); 158 arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx); 159 arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx); 160 arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx); 161 }); 162 } 163