xref: /llvm-project/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp (revision d622b66a820a0e5e61c131e9ae5b4db35292aa14)
1 //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert scf.for, scf.if and scf.index_switch
// ops into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/EmitC/IR/EmitC.h"
17 #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/Passes.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCFTOEMITC
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::scf;
34 
35 namespace {
36 
37 struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
38   void runOnOperation() override;
39 };
40 
41 // Lower scf::for to emitc::for, implementing result values using
42 // emitc::variable's updated within the loop body.
43 struct ForLowering : public OpConversionPattern<ForOp> {
44   using OpConversionPattern<ForOp>::OpConversionPattern;
45 
46   LogicalResult
47   matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
48                   ConversionPatternRewriter &rewriter) const override;
49 };
50 
51 // Create an uninitialized emitc::variable op for each result of the given op.
52 template <typename T>
53 static LogicalResult
54 createVariablesForResults(T op, const TypeConverter *typeConverter,
55                           ConversionPatternRewriter &rewriter,
56                           SmallVector<Value> &resultVariables) {
57   if (!op.getNumResults())
58     return success();
59 
60   Location loc = op->getLoc();
61   MLIRContext *context = op.getContext();
62 
63   OpBuilder::InsertionGuard guard(rewriter);
64   rewriter.setInsertionPoint(op);
65 
66   for (OpResult result : op.getResults()) {
67     Type resultType = typeConverter->convertType(result.getType());
68     if (!resultType)
69       return rewriter.notifyMatchFailure(op, "result type conversion failed");
70     Type varType = emitc::LValueType::get(resultType);
71     emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
72     emitc::VariableOp var =
73         rewriter.create<emitc::VariableOp>(loc, varType, noInit);
74     resultVariables.push_back(var);
75   }
76 
77   return success();
78 }
79 
80 // Create a series of assign ops assigning given values to given variables at
81 // the current insertion point of given rewriter.
82 static void assignValues(ValueRange values, ValueRange variables,
83                          ConversionPatternRewriter &rewriter, Location loc) {
84   for (auto [value, var] : llvm::zip(values, variables))
85     rewriter.create<emitc::AssignOp>(loc, var, value);
86 }
87 
88 SmallVector<Value> loadValues(const SmallVector<Value> &variables,
89                               PatternRewriter &rewriter, Location loc) {
90   return llvm::map_to_vector<>(variables, [&](Value var) {
91     Type type = cast<emitc::LValueType>(var.getType()).getValueType();
92     return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
93   });
94 }
95 
96 static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
97                                 ConversionPatternRewriter &rewriter,
98                                 scf::YieldOp yield) {
99   Location loc = yield.getLoc();
100 
101   OpBuilder::InsertionGuard guard(rewriter);
102   rewriter.setInsertionPoint(yield);
103 
104   SmallVector<Value> yieldOperands;
105   if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
106     return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
107   }
108 
109   assignValues(yieldOperands, resultVariables, rewriter, loc);
110 
111   rewriter.create<emitc::YieldOp>(loc);
112   rewriter.eraseOp(yield);
113 
114   return success();
115 }
116 
117 // Lower the contents of an scf::if/scf::index_switch regions to an
118 // emitc::if/emitc::switch region. The contents of the lowering region is
119 // moved into the respective lowered region, but the scf::yield is replaced not
120 // only with an emitc::yield, but also with a sequence of emitc::assign ops that
121 // set the yielded values into the result variables.
122 static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
123                                  ConversionPatternRewriter &rewriter,
124                                  Region &region, Region &loweredRegion) {
125   rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
126   Operation *terminator = loweredRegion.back().getTerminator();
127   return lowerYield(op, resultVariables, rewriter,
128                     cast<scf::YieldOp>(terminator));
129 }
130 
131 LogicalResult
132 ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
133                              ConversionPatternRewriter &rewriter) const {
134   Location loc = forOp.getLoc();
135 
136   // Create an emitc::variable op for each result. These variables will be
137   // assigned to by emitc::assign ops within the loop body.
138   SmallVector<Value> resultVariables;
139   if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
140                                        resultVariables)))
141     return rewriter.notifyMatchFailure(forOp,
142                                        "create variables for results failed");
143 
144   assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
145 
146   emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
147       loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
148 
149   Block *loweredBody = loweredFor.getBody();
150 
151   // Erase the auto-generated terminator for the lowered for op.
152   rewriter.eraseOp(loweredBody->getTerminator());
153 
154   IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
155   rewriter.setInsertionPointToEnd(loweredBody);
156 
157   SmallVector<Value> iterArgsValues =
158       loadValues(resultVariables, rewriter, loc);
159 
160   rewriter.restoreInsertionPoint(ip);
161 
162   // Convert the original region types into the new types by adding unrealized
163   // casts in the beginning of the loop. This performs the conversion in place.
164   if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
165                                          *getTypeConverter(), nullptr))) {
166     return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
167   }
168 
169   // Register the replacements for the block arguments and inline the body of
170   // the scf.for loop into the body of the emitc::for loop.
171   Block *scfBody = &(forOp.getRegion().front());
172   SmallVector<Value> replacingValues;
173   replacingValues.push_back(loweredFor.getInductionVar());
174   replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
175   rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
176 
177   auto result = lowerYield(forOp, resultVariables, rewriter,
178                            cast<scf::YieldOp>(loweredBody->getTerminator()));
179 
180   if (failed(result)) {
181     return result;
182   }
183 
184   // Load variables into SSA values after the for loop.
185   SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
186 
187   rewriter.replaceOp(forOp, resultValues);
188   return success();
189 }
190 
191 // Lower scf::if to emitc::if, implementing result values as emitc::variable's
192 // updated within the then and else regions.
193 struct IfLowering : public OpConversionPattern<IfOp> {
194   using OpConversionPattern<IfOp>::OpConversionPattern;
195 
196   LogicalResult
197   matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override;
199 };
200 
201 } // namespace
202 
203 LogicalResult
204 IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
205                             ConversionPatternRewriter &rewriter) const {
206   Location loc = ifOp.getLoc();
207 
208   // Create an emitc::variable op for each result. These variables will be
209   // assigned to by emitc::assign ops within the then & else regions.
210   SmallVector<Value> resultVariables;
211   if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
212                                        resultVariables)))
213     return rewriter.notifyMatchFailure(ifOp,
214                                        "create variables for results failed");
215 
216   // Utility function to lower the contents of an scf::if region to an emitc::if
217   // region. The contents of the scf::if regions is moved into the respective
218   // emitc::if regions, but the scf::yield is replaced not only with an
219   // emitc::yield, but also with a sequence of emitc::assign ops that set the
220   // yielded values into the result variables.
221   auto lowerRegion = [&resultVariables, &rewriter,
222                       &ifOp](Region &region, Region &loweredRegion) {
223     rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
224     Operation *terminator = loweredRegion.back().getTerminator();
225     auto result = lowerYield(ifOp, resultVariables, rewriter,
226                              cast<scf::YieldOp>(terminator));
227     if (failed(result)) {
228       return result;
229     }
230     return success();
231   };
232 
233   Region &thenRegion = adaptor.getThenRegion();
234   Region &elseRegion = adaptor.getElseRegion();
235 
236   bool hasElseBlock = !elseRegion.empty();
237 
238   auto loweredIf =
239       rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
240 
241   Region &loweredThenRegion = loweredIf.getThenRegion();
242   auto result = lowerRegion(thenRegion, loweredThenRegion);
243   if (failed(result)) {
244     return result;
245   }
246 
247   if (hasElseBlock) {
248     Region &loweredElseRegion = loweredIf.getElseRegion();
249     auto result = lowerRegion(elseRegion, loweredElseRegion);
250     if (failed(result)) {
251       return result;
252     }
253   }
254 
255   rewriter.setInsertionPointAfter(ifOp);
256   SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
257 
258   rewriter.replaceOp(ifOp, results);
259   return success();
260 }
261 
262 // Lower scf::index_switch to emitc::switch, implementing result values as
263 // emitc::variable's updated within the case and default regions.
264 struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
265   using OpConversionPattern::OpConversionPattern;
266 
267   LogicalResult
268   matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269                   ConversionPatternRewriter &rewriter) const override;
270 };
271 
272 LogicalResult IndexSwitchOpLowering::matchAndRewrite(
273     IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
274     ConversionPatternRewriter &rewriter) const {
275   Location loc = indexSwitchOp.getLoc();
276 
277   // Create an emitc::variable op for each result. These variables will be
278   // assigned to by emitc::assign ops within the case and default regions.
279   SmallVector<Value> resultVariables;
280   if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
281                                        rewriter, resultVariables))) {
282     return rewriter.notifyMatchFailure(indexSwitchOp,
283                                        "create variables for results failed");
284   }
285 
286   auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
287       loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
288 
289   // Lowering all case regions.
290   for (auto pair :
291        llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
292     if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
293                            *std::get<0>(pair), std::get<1>(pair)))) {
294       return failure();
295     }
296   }
297 
298   // Lowering default region.
299   if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
300                          adaptor.getDefaultRegion(),
301                          loweredSwitch.getDefaultRegion()))) {
302     return failure();
303   }
304 
305   rewriter.setInsertionPointAfter(indexSwitchOp);
306   SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
307 
308   rewriter.replaceOp(indexSwitchOp, results);
309   return success();
310 }
311 
312 void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
313                                                 TypeConverter &typeConverter) {
314   patterns.add<ForLowering>(typeConverter, patterns.getContext());
315   patterns.add<IfLowering>(typeConverter, patterns.getContext());
316   patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
317 }
318 
319 void SCFToEmitCPass::runOnOperation() {
320   RewritePatternSet patterns(&getContext());
321   TypeConverter typeConverter;
322   // Fallback converter
323   // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324   // Type converters are called most to least recently inserted
325   typeConverter.addConversion([](Type t) { return t; });
326   populateEmitCSizeTTypeConversions(typeConverter);
327   populateSCFToEmitCConversionPatterns(patterns, typeConverter);
328 
329   // Configure conversion to lower out SCF operations.
330   ConversionTarget target(getContext());
331   target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
332   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
333   if (failed(
334           applyPartialConversion(getOperation(), target, std::move(patterns))))
335     signalPassFailure();
336 }
337