xref: /llvm-project/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
10 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
11 #include "mlir/Dialect/Affine/Analysis/Utils.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
14 #include "mlir/Dialect/Affine/LoopUtils.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 using namespace mlir;
20 using namespace mlir::affine;
21 using namespace mlir::transform;
22 
23 //===----------------------------------------------------------------------===//
24 // SimplifyBoundedAffineOpsOp
25 //===----------------------------------------------------------------------===//
26 
27 LogicalResult SimplifyBoundedAffineOpsOp::verify() {
28   if (getLowerBounds().size() != getBoundedValues().size())
29     return emitOpError() << "incorrect number of lower bounds, expected "
30                          << getBoundedValues().size() << " but found "
31                          << getLowerBounds().size();
32   if (getUpperBounds().size() != getBoundedValues().size())
33     return emitOpError() << "incorrect number of upper bounds, expected "
34                          << getBoundedValues().size() << " but found "
35                          << getUpperBounds().size();
36   return success();
37 }
38 
39 namespace {
40 /// Simplify affine.min / affine.max ops with the given constraints. They are
41 /// either rewritten to affine.apply or left unchanged.
42 template <typename OpTy>
43 struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
44   using OpRewritePattern<OpTy>::OpRewritePattern;
45   SimplifyAffineMinMaxOp(MLIRContext *ctx,
46                          const FlatAffineValueConstraints &constraints,
47                          PatternBenefit benefit = 1)
48       : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
49 
50   LogicalResult matchAndRewrite(OpTy op,
51                                 PatternRewriter &rewriter) const override {
52     FailureOr<AffineValueMap> simplified =
53         simplifyConstrainedMinMaxOp(op, constraints);
54     if (failed(simplified))
55       return failure();
56     rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
57                                                simplified->getOperands());
58     return success();
59   }
60 
61   const FlatAffineValueConstraints &constraints;
62 };
63 } // namespace
64 
65 DiagnosedSilenceableFailure
66 SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
67                                   TransformResults &results,
68                                   TransformState &state) {
69   // Get constraints for bounded values.
70   SmallVector<int64_t> lbs;
71   SmallVector<int64_t> ubs;
72   SmallVector<Value> boundedValues;
73   DenseSet<Operation *> boundedOps;
74   for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
75                                         getUpperBounds())) {
76     Value handle = std::get<0>(it);
77     for (Operation *op : state.getPayloadOps(handle)) {
78       if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
79         auto diag =
80             emitDefiniteFailure()
81             << "expected bounded value handle to point to one or multiple "
82                "single-result index-typed ops";
83         diag.attachNote(op->getLoc()) << "multiple/non-index result";
84         return diag;
85       }
86       boundedValues.push_back(op->getResult(0));
87       boundedOps.insert(op);
88       lbs.push_back(std::get<1>(it));
89       ubs.push_back(std::get<2>(it));
90     }
91   }
92 
93   // Build constraint set.
94   FlatAffineValueConstraints cstr;
95   for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
96     unsigned pos;
97     if (!cstr.findVar(std::get<0>(it), &pos))
98       pos = cstr.appendSymbolVar(std::get<0>(it));
99     cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
100     // Note: addBound bounds are inclusive, but specified UB is exclusive.
101     cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
102   }
103 
104   // Transform all targets.
105   SmallVector<Operation *> targets;
106   for (Operation *target : state.getPayloadOps(getTarget())) {
107     if (!isa<AffineMinOp, AffineMaxOp>(target)) {
108       auto diag = emitDefiniteFailure()
109                   << "target must be affine.min or affine.max";
110       diag.attachNote(target->getLoc()) << "target op";
111       return diag;
112     }
113     if (boundedOps.contains(target)) {
114       auto diag = emitDefiniteFailure()
115                   << "target op result must not be constrainted";
116       diag.attachNote(target->getLoc()) << "target/constrained op";
117       return diag;
118     }
119     targets.push_back(target);
120   }
121   SmallVector<Operation *> transformed;
122   RewritePatternSet patterns(getContext());
123   // Canonicalization patterns are needed so that affine.apply ops are composed
124   // with the remaining affine.min/max ops.
125   AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
126   AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
127   patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128                   SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130   GreedyRewriteConfig config;
131   config.listener =
132       static_cast<RewriterBase::Listener *>(rewriter.getListener());
133   config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
134   // Apply the simplification pattern to a fixpoint.
135   if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
136     auto diag = emitDefiniteFailure()
137                 << "affine.min/max simplification did not converge";
138     return diag;
139   }
140   return DiagnosedSilenceableFailure::success();
141 }
142 
143 void SimplifyBoundedAffineOpsOp::getEffects(
144     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
145   consumesHandle(getTargetMutable(), effects);
146   for (OpOperand &operand : getBoundedValuesMutable())
147     onlyReadsHandle(operand, effects);
148   modifiesPayload(effects);
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Transform op registration
153 //===----------------------------------------------------------------------===//
154 
namespace {
/// Transform dialect extension that registers the Affine transform ops with
/// the transform dialect. The ops come from the ODS-generated op list in
/// AffineTransformOps.cpp.inc.
class AffineTransformDialectExtension
    : public transform::TransformDialectExtension<
          AffineTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)

  using Base::Base;

  /// Declares AffineDialect as a dialect whose ops these transforms may
  /// create. For example, SimplifyBoundedAffineOpsOp builds affine.apply ops.
  /// It then registers every op from the generated op list.
  void init() {
    declareGeneratedDialect<AffineDialect>();

    // With GET_OP_LIST defined, the generated file expands to the list of op
    // classes, which is used directly as the template argument list here.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
        >();
  }
};
} // namespace
174 
// With GET_OP_CLASSES defined, the generated file emits the ODS-generated op
// class definitions for the Affine transform ops.
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"

/// Adds the Affine transform dialect extension (AffineTransformDialectExtension)
/// to `registry`.
void mlir::affine::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<AffineTransformDialectExtension>();
}
182