1 //=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" 10 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 11 #include "mlir/Dialect/Affine/Analysis/Utils.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 14 #include "mlir/Dialect/Affine/LoopUtils.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace mlir::affine; 21 using namespace mlir::transform; 22 23 //===----------------------------------------------------------------------===// 24 // SimplifyBoundedAffineOpsOp 25 //===----------------------------------------------------------------------===// 26 27 LogicalResult SimplifyBoundedAffineOpsOp::verify() { 28 if (getLowerBounds().size() != getBoundedValues().size()) 29 return emitOpError() << "incorrect number of lower bounds, expected " 30 << getBoundedValues().size() << " but found " 31 << getLowerBounds().size(); 32 if (getUpperBounds().size() != getBoundedValues().size()) 33 return emitOpError() << "incorrect number of upper bounds, expected " 34 << getBoundedValues().size() << " but found " 35 << getUpperBounds().size(); 36 return success(); 37 } 38 39 namespace { 40 /// Simplify affine.min / affine.max ops with the given constraints. They are 41 /// either rewritten to affine.apply or left unchanged. 42 template <typename OpTy> 43 struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> { 44 using OpRewritePattern<OpTy>::OpRewritePattern; 45 SimplifyAffineMinMaxOp(MLIRContext *ctx, 46 const FlatAffineValueConstraints &constraints, 47 PatternBenefit benefit = 1) 48 : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {} 49 50 LogicalResult matchAndRewrite(OpTy op, 51 PatternRewriter &rewriter) const override { 52 FailureOr<AffineValueMap> simplified = 53 simplifyConstrainedMinMaxOp(op, constraints); 54 if (failed(simplified)) 55 return failure(); 56 rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(), 57 simplified->getOperands()); 58 return success(); 59 } 60 61 const FlatAffineValueConstraints &constraints; 62 }; 63 } // namespace 64 65 DiagnosedSilenceableFailure 66 SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, 67 TransformResults &results, 68 TransformState &state) { 69 // Get constraints for bounded values. 70 SmallVector<int64_t> lbs; 71 SmallVector<int64_t> ubs; 72 SmallVector<Value> boundedValues; 73 DenseSet<Operation *> boundedOps; 74 for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(), 75 getUpperBounds())) { 76 Value handle = std::get<0>(it); 77 for (Operation *op : state.getPayloadOps(handle)) { 78 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { 79 auto diag = 80 emitDefiniteFailure() 81 << "expected bounded value handle to point to one or multiple " 82 "single-result index-typed ops"; 83 diag.attachNote(op->getLoc()) << "multiple/non-index result"; 84 return diag; 85 } 86 boundedValues.push_back(op->getResult(0)); 87 boundedOps.insert(op); 88 lbs.push_back(std::get<1>(it)); 89 ubs.push_back(std::get<2>(it)); 90 } 91 } 92 93 // Build constraint set. 94 FlatAffineValueConstraints cstr; 95 for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) { 96 unsigned pos; 97 if (!cstr.findVar(std::get<0>(it), &pos)) 98 pos = cstr.appendSymbolVar(std::get<0>(it)); 99 cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it)); 100 // Note: addBound bounds are inclusive, but specified UB is exclusive. 101 cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1); 102 } 103 104 // Transform all targets. 105 SmallVector<Operation *> targets; 106 for (Operation *target : state.getPayloadOps(getTarget())) { 107 if (!isa<AffineMinOp, AffineMaxOp>(target)) { 108 auto diag = emitDefiniteFailure() 109 << "target must be affine.min or affine.max"; 110 diag.attachNote(target->getLoc()) << "target op"; 111 return diag; 112 } 113 if (boundedOps.contains(target)) { 114 auto diag = emitDefiniteFailure() 115 << "target op result must not be constrainted"; 116 diag.attachNote(target->getLoc()) << "target/constrained op"; 117 return diag; 118 } 119 targets.push_back(target); 120 } 121 SmallVector<Operation *> transformed; 122 RewritePatternSet patterns(getContext()); 123 // Canonicalization patterns are needed so that affine.apply ops are composed 124 // with the remaining affine.min/max ops. 125 AffineMaxOp::getCanonicalizationPatterns(patterns, getContext()); 126 AffineMinOp::getCanonicalizationPatterns(patterns, getContext()); 127 patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>, 128 SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr); 129 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 130 GreedyRewriteConfig config; 131 config.listener = 132 static_cast<RewriterBase::Listener *>(rewriter.getListener()); 133 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 134 // Apply the simplification pattern to a fixpoint. 135 if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) { 136 auto diag = emitDefiniteFailure() 137 << "affine.min/max simplification did not converge"; 138 return diag; 139 } 140 return DiagnosedSilenceableFailure::success(); 141 } 142 143 void SimplifyBoundedAffineOpsOp::getEffects( 144 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 145 consumesHandle(getTargetMutable(), effects); 146 for (OpOperand &operand : getBoundedValuesMutable()) 147 onlyReadsHandle(operand, effects); 148 modifiesPayload(effects); 149 } 150 151 //===----------------------------------------------------------------------===// 152 // Transform op registration 153 //===----------------------------------------------------------------------===// 154 155 namespace { 156 class AffineTransformDialectExtension 157 : public transform::TransformDialectExtension< 158 AffineTransformDialectExtension> { 159 public: 160 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension) 161 162 using Base::Base; 163 164 void init() { 165 declareGeneratedDialect<AffineDialect>(); 166 167 registerTransformOps< 168 #define GET_OP_LIST 169 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc" 170 >(); 171 } 172 }; 173 } // namespace 174 175 #define GET_OP_CLASSES 176 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc" 177 178 void mlir::affine::registerTransformDialectExtension( 179 DialectRegistry ®istry) { 180 registry.addExtensions<AffineTransformDialectExtension>(); 181 } 182