xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (revision 9df63b2651b2435c02a7d825953ca2ddc65c778e)
1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
11 #include "mlir/Transforms/DialectConversion.h"
12 #include <optional>
13 
14 using namespace mlir;
15 using namespace mlir::scf;
16 
17 namespace {
18 
19 /// Flatten the given value ranges into a single vector of values.
20 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
21   SmallVector<Value> result;
22   for (const auto &vals : values)
23     llvm::append_range(result, vals);
24   return result;
25 }
26 
27 /// Assert that the given value range contains a single value and return it.
28 static Value getSingleValue(ValueRange values) {
29   assert(values.size() == 1 && "expected single value");
30   return values.front();
31 }
32 
33 // CRTP
34 // A base class that takes care of 1:N type conversion, which maps the converted
35 // op results (computed by the derived class) and materializes 1:N conversion.
36 template <typename SourceOp, typename ConcretePattern>
37 class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
38 public:
39   using OpConversionPattern<SourceOp>::typeConverter;
40   using OpConversionPattern<SourceOp>::OpConversionPattern;
41   using OneToNOpAdaptor =
42       typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
43 
44   //
45   // Derived classes should provide the following method which performs the
46   // actual conversion. It should return std::nullopt upon conversion failure
47   // and return the converted operation upon success.
48   //
49   // std::optional<SourceOp> convertSourceOp(
50   //     SourceOp op, OneToNOpAdaptor adaptor,
51   //     ConversionPatternRewriter &rewriter,
52   //     TypeRange dstTypes) const;
53 
54   LogicalResult
55   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
56                   ConversionPatternRewriter &rewriter) const override {
57     SmallVector<Type> dstTypes;
58     SmallVector<unsigned> offsets;
59     offsets.push_back(0);
60     // Do the type conversion and record the offsets.
61     for (Type type : op.getResultTypes()) {
62       if (failed(typeConverter->convertTypes(type, dstTypes)))
63         return rewriter.notifyMatchFailure(op, "could not convert result type");
64       offsets.push_back(dstTypes.size());
65     }
66 
67     // Calls the actual converter implementation to convert the operation.
68     std::optional<SourceOp> newOp =
69         static_cast<const ConcretePattern *>(this)->convertSourceOp(
70             op, adaptor, rewriter, dstTypes);
71 
72     if (!newOp)
73       return rewriter.notifyMatchFailure(op, "could not convert operation");
74 
75     // Packs the return value.
76     SmallVector<ValueRange> packedRets;
77     for (unsigned i = 1, e = offsets.size(); i < e; i++) {
78       unsigned start = offsets[i - 1], end = offsets[i];
79       unsigned len = end - start;
80       ValueRange mappedValue = newOp->getResults().slice(start, len);
81       packedRets.push_back(mappedValue);
82     }
83 
84     rewriter.replaceOpWithMultiple(op, packedRets);
85     return success();
86   }
87 };
88 
89 class ConvertForOpTypes
90     : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
91 public:
92   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
93 
94   // The callback required by CRTP.
95   std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
96                                        ConversionPatternRewriter &rewriter,
97                                        TypeRange dstTypes) const {
98     // Create a empty new op and inline the regions from the old op.
99     //
100     // This is a little bit tricky. We have two concerns here:
101     //
102     // 1. We cannot update the op in place because the dialect conversion
103     // framework does not track type changes for ops updated in place, so it
104     // won't insert appropriate materializations on the changed result types.
105     // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
106     // to clone the op.
107     //
108     // 2. We need to resue the original region instead of cloning it, otherwise
109     // the dialect conversion framework thinks that we just inserted all the
110     // cloned child ops. But what we want is to "take" the child regions and let
111     // the dialect conversion framework continue recursively into ops inside
112     // those regions (which are already in its worklist; inlining them into the
113     // new op's regions doesn't remove the child ops from the worklist).
114 
115     // convertRegionTypes already takes care of 1:N conversion.
116     if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
117       return std::nullopt;
118 
119     // We can not do clone as the number of result types after conversion
120     // might be different.
121     ForOp newOp = rewriter.create<ForOp>(
122         op.getLoc(), getSingleValue(adaptor.getLowerBound()),
123         getSingleValue(adaptor.getUpperBound()),
124         getSingleValue(adaptor.getStep()),
125         flattenValues(adaptor.getInitArgs()));
126 
127     // Reserve whatever attributes in the original op.
128     newOp->setAttrs(op->getAttrs());
129 
130     // We do not need the empty block created by rewriter.
131     rewriter.eraseBlock(newOp.getBody(0));
132     // Inline the type converted region from the original operation.
133     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
134                                 newOp.getRegion().end());
135 
136     return newOp;
137   }
138 };
139 } // namespace
140 
141 namespace {
142 class ConvertIfOpTypes
143     : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
144 public:
145   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
146 
147   std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
148                                       ConversionPatternRewriter &rewriter,
149                                       TypeRange dstTypes) const {
150 
151     IfOp newOp = rewriter.create<IfOp>(
152         op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
153     newOp->setAttrs(op->getAttrs());
154 
155     // We do not need the empty blocks created by rewriter.
156     rewriter.eraseBlock(newOp.elseBlock());
157     rewriter.eraseBlock(newOp.thenBlock());
158 
159     // Inlines block from the original operation.
160     rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
161                                 newOp.getThenRegion().end());
162     rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
163                                 newOp.getElseRegion().end());
164 
165     return newOp;
166   }
167 };
168 } // namespace
169 
170 namespace {
171 class ConvertWhileOpTypes
172     : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
173 public:
174   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
175 
176   std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
177                                          ConversionPatternRewriter &rewriter,
178                                          TypeRange dstTypes) const {
179     auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
180                                           flattenValues(adaptor.getOperands()));
181 
182     for (auto i : {0u, 1u}) {
183       if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
184         return std::nullopt;
185       auto &dstRegion = newOp.getRegion(i);
186       rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
187     }
188     return newOp;
189   }
190 };
191 } // namespace
192 
193 namespace {
194 // When the result types of a ForOp/IfOp get changed, the operand types of the
195 // corresponding yield op need to be changed. In order to trigger the
196 // appropriate type conversions / materializations, we need a dummy pattern.
197 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
198 public:
199   using OpConversionPattern::OpConversionPattern;
200   LogicalResult
201   matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
202                   ConversionPatternRewriter &rewriter) const override {
203     rewriter.replaceOpWithNewOp<scf::YieldOp>(
204         op, flattenValues(adaptor.getOperands()));
205     return success();
206   }
207 };
208 } // namespace
209 
210 namespace {
211 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
212 public:
213   using OpConversionPattern<ConditionOp>::OpConversionPattern;
214   LogicalResult
215   matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
216                   ConversionPatternRewriter &rewriter) const override {
217     rewriter.modifyOpInPlace(
218         op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
219     return success();
220   }
221 };
222 } // namespace
223 
224 void mlir::scf::populateSCFStructuralTypeConversions(
225     const TypeConverter &typeConverter, RewritePatternSet &patterns) {
226   patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
227                ConvertWhileOpTypes, ConvertConditionOpTypes>(
228       typeConverter, patterns.getContext());
229 }
230 
231 void mlir::scf::populateSCFStructuralTypeConversionTarget(
232     const TypeConverter &typeConverter, ConversionTarget &target) {
233   target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
234     return typeConverter.isLegal(op->getResultTypes());
235   });
236   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
237     // We only have conversions for a subset of ops that use scf.yield
238     // terminators.
239     if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
240       return true;
241     return typeConverter.isLegal(op.getOperandTypes());
242   });
243   target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
244       [&](Operation *op) { return typeConverter.isLegal(op); });
245 }
246 
247 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
248     const TypeConverter &typeConverter, RewritePatternSet &patterns,
249     ConversionTarget &target) {
250   populateSCFStructuralTypeConversions(typeConverter, patterns);
251   populateSCFStructuralTypeConversionTarget(typeConverter, target);
252 }
253