//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert scf.for, scf.if and scf.index_switch
// ops into emitc ops. Values produced by those SCF ops are materialized as
// emitc.variable ops that are assigned inside the lowered regions and loaded
// back into SSA values after the lowered op.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
// Expand the generated pass definition; it provides impl::SCFToEmitCBase,
// which SCFToEmitCPass derives from below.
#define GEN_PASS_DEF_SCFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::scf;

namespace {

// The conversion pass. runOnOperation marks scf.for, scf.if and
// scf.index_switch illegal and applies the patterns of this file.
struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
  void runOnOperation() override;
};

// Lower scf::for to emitc::for, implementing result values using
// emitc::variable's updated within the loop body.
struct ForLowering : public OpConversionPattern<ForOp> {
  using OpConversionPattern<ForOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Create an uninitialized emitc::variable op (of emitc lvalue type) right
// before `op` for each result of `op`, appending them to `resultVariables`.
// Fails if a result type cannot be converted by the type converter.
52 template <typename T> 53 static LogicalResult 54 createVariablesForResults(T op, const TypeConverter *typeConverter, 55 ConversionPatternRewriter &rewriter, 56 SmallVector<Value> &resultVariables) { 57 if (!op.getNumResults()) 58 return success(); 59 60 Location loc = op->getLoc(); 61 MLIRContext *context = op.getContext(); 62 63 OpBuilder::InsertionGuard guard(rewriter); 64 rewriter.setInsertionPoint(op); 65 66 for (OpResult result : op.getResults()) { 67 Type resultType = typeConverter->convertType(result.getType()); 68 if (!resultType) 69 return rewriter.notifyMatchFailure(op, "result type conversion failed"); 70 Type varType = emitc::LValueType::get(resultType); 71 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); 72 emitc::VariableOp var = 73 rewriter.create<emitc::VariableOp>(loc, varType, noInit); 74 resultVariables.push_back(var); 75 } 76 77 return success(); 78 } 79 80 // Create a series of assign ops assigning given values to given variables at 81 // the current insertion point of given rewriter. 
82 static void assignValues(ValueRange values, ValueRange variables, 83 ConversionPatternRewriter &rewriter, Location loc) { 84 for (auto [value, var] : llvm::zip(values, variables)) 85 rewriter.create<emitc::AssignOp>(loc, var, value); 86 } 87 88 SmallVector<Value> loadValues(const SmallVector<Value> &variables, 89 PatternRewriter &rewriter, Location loc) { 90 return llvm::map_to_vector<>(variables, [&](Value var) { 91 Type type = cast<emitc::LValueType>(var.getType()).getValueType(); 92 return rewriter.create<emitc::LoadOp>(loc, type, var).getResult(); 93 }); 94 } 95 96 static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, 97 ConversionPatternRewriter &rewriter, 98 scf::YieldOp yield) { 99 Location loc = yield.getLoc(); 100 101 OpBuilder::InsertionGuard guard(rewriter); 102 rewriter.setInsertionPoint(yield); 103 104 SmallVector<Value> yieldOperands; 105 if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { 106 return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); 107 } 108 109 assignValues(yieldOperands, resultVariables, rewriter, loc); 110 111 rewriter.create<emitc::YieldOp>(loc); 112 rewriter.eraseOp(yield); 113 114 return success(); 115 } 116 117 // Lower the contents of an scf::if/scf::index_switch regions to an 118 // emitc::if/emitc::switch region. The contents of the lowering region is 119 // moved into the respective lowered region, but the scf::yield is replaced not 120 // only with an emitc::yield, but also with a sequence of emitc::assign ops that 121 // set the yielded values into the result variables. 
122 static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables, 123 ConversionPatternRewriter &rewriter, 124 Region ®ion, Region &loweredRegion) { 125 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); 126 Operation *terminator = loweredRegion.back().getTerminator(); 127 return lowerYield(op, resultVariables, rewriter, 128 cast<scf::YieldOp>(terminator)); 129 } 130 131 LogicalResult 132 ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, 133 ConversionPatternRewriter &rewriter) const { 134 Location loc = forOp.getLoc(); 135 136 // Create an emitc::variable op for each result. These variables will be 137 // assigned to by emitc::assign ops within the loop body. 138 SmallVector<Value> resultVariables; 139 if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter, 140 resultVariables))) 141 return rewriter.notifyMatchFailure(forOp, 142 "create variables for results failed"); 143 144 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); 145 146 emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>( 147 loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); 148 149 Block *loweredBody = loweredFor.getBody(); 150 151 // Erase the auto-generated terminator for the lowered for op. 152 rewriter.eraseOp(loweredBody->getTerminator()); 153 154 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint(); 155 rewriter.setInsertionPointToEnd(loweredBody); 156 157 SmallVector<Value> iterArgsValues = 158 loadValues(resultVariables, rewriter, loc); 159 160 rewriter.restoreInsertionPoint(ip); 161 162 // Convert the original region types into the new types by adding unrealized 163 // casts in the beginning of the loop. This performs the conversion in place. 
164 if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), 165 *getTypeConverter(), nullptr))) { 166 return rewriter.notifyMatchFailure(forOp, "region types conversion failed"); 167 } 168 169 // Register the replacements for the block arguments and inline the body of 170 // the scf.for loop into the body of the emitc::for loop. 171 Block *scfBody = &(forOp.getRegion().front()); 172 SmallVector<Value> replacingValues; 173 replacingValues.push_back(loweredFor.getInductionVar()); 174 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end()); 175 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues); 176 177 auto result = lowerYield(forOp, resultVariables, rewriter, 178 cast<scf::YieldOp>(loweredBody->getTerminator())); 179 180 if (failed(result)) { 181 return result; 182 } 183 184 // Load variables into SSA values after the for loop. 185 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc); 186 187 rewriter.replaceOp(forOp, resultValues); 188 return success(); 189 } 190 191 // Lower scf::if to emitc::if, implementing result values as emitc::variable's 192 // updated within the then and else regions. 193 struct IfLowering : public OpConversionPattern<IfOp> { 194 using OpConversionPattern<IfOp>::OpConversionPattern; 195 196 LogicalResult 197 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, 198 ConversionPatternRewriter &rewriter) const override; 199 }; 200 201 } // namespace 202 203 LogicalResult 204 IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, 205 ConversionPatternRewriter &rewriter) const { 206 Location loc = ifOp.getLoc(); 207 208 // Create an emitc::variable op for each result. These variables will be 209 // assigned to by emitc::assign ops within the then & else regions. 
210 SmallVector<Value> resultVariables; 211 if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter, 212 resultVariables))) 213 return rewriter.notifyMatchFailure(ifOp, 214 "create variables for results failed"); 215 216 // Utility function to lower the contents of an scf::if region to an emitc::if 217 // region. The contents of the scf::if regions is moved into the respective 218 // emitc::if regions, but the scf::yield is replaced not only with an 219 // emitc::yield, but also with a sequence of emitc::assign ops that set the 220 // yielded values into the result variables. 221 auto lowerRegion = [&resultVariables, &rewriter, 222 &ifOp](Region ®ion, Region &loweredRegion) { 223 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); 224 Operation *terminator = loweredRegion.back().getTerminator(); 225 auto result = lowerYield(ifOp, resultVariables, rewriter, 226 cast<scf::YieldOp>(terminator)); 227 if (failed(result)) { 228 return result; 229 } 230 return success(); 231 }; 232 233 Region &thenRegion = adaptor.getThenRegion(); 234 Region &elseRegion = adaptor.getElseRegion(); 235 236 bool hasElseBlock = !elseRegion.empty(); 237 238 auto loweredIf = 239 rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false); 240 241 Region &loweredThenRegion = loweredIf.getThenRegion(); 242 auto result = lowerRegion(thenRegion, loweredThenRegion); 243 if (failed(result)) { 244 return result; 245 } 246 247 if (hasElseBlock) { 248 Region &loweredElseRegion = loweredIf.getElseRegion(); 249 auto result = lowerRegion(elseRegion, loweredElseRegion); 250 if (failed(result)) { 251 return result; 252 } 253 } 254 255 rewriter.setInsertionPointAfter(ifOp); 256 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc); 257 258 rewriter.replaceOp(ifOp, results); 259 return success(); 260 } 261 262 // Lower scf::index_switch to emitc::switch, implementing result values as 263 // emitc::variable's updated within the case and default 
regions. 264 struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> { 265 using OpConversionPattern::OpConversionPattern; 266 267 LogicalResult 268 matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, 269 ConversionPatternRewriter &rewriter) const override; 270 }; 271 272 LogicalResult IndexSwitchOpLowering::matchAndRewrite( 273 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, 274 ConversionPatternRewriter &rewriter) const { 275 Location loc = indexSwitchOp.getLoc(); 276 277 // Create an emitc::variable op for each result. These variables will be 278 // assigned to by emitc::assign ops within the case and default regions. 279 SmallVector<Value> resultVariables; 280 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(), 281 rewriter, resultVariables))) { 282 return rewriter.notifyMatchFailure(indexSwitchOp, 283 "create variables for results failed"); 284 } 285 286 auto loweredSwitch = rewriter.create<emitc::SwitchOp>( 287 loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); 288 289 // Lowering all case regions. 290 for (auto pair : 291 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) { 292 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, 293 *std::get<0>(pair), std::get<1>(pair)))) { 294 return failure(); 295 } 296 } 297 298 // Lowering default region. 
299 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, 300 adaptor.getDefaultRegion(), 301 loweredSwitch.getDefaultRegion()))) { 302 return failure(); 303 } 304 305 rewriter.setInsertionPointAfter(indexSwitchOp); 306 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc); 307 308 rewriter.replaceOp(indexSwitchOp, results); 309 return success(); 310 } 311 312 void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, 313 TypeConverter &typeConverter) { 314 patterns.add<ForLowering>(typeConverter, patterns.getContext()); 315 patterns.add<IfLowering>(typeConverter, patterns.getContext()); 316 patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext()); 317 } 318 319 void SCFToEmitCPass::runOnOperation() { 320 RewritePatternSet patterns(&getContext()); 321 TypeConverter typeConverter; 322 // Fallback converter 323 // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter 324 // Type converters are called most to least recently inserted 325 typeConverter.addConversion([](Type t) { return t; }); 326 populateEmitCSizeTTypeConversions(typeConverter); 327 populateSCFToEmitCConversionPatterns(patterns, typeConverter); 328 329 // Configure conversion to lower out SCF operations. 330 ConversionTarget target(getContext()); 331 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(); 332 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 333 if (failed( 334 applyPartialConversion(getOperation(), target, std::move(patterns)))) 335 signalPassFailure(); 336 } 337