123aa5a74SRiver Riddle //===- FuncConversions.cpp - Function conversions -------------------------===// 223aa5a74SRiver Riddle // 323aa5a74SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 423aa5a74SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 523aa5a74SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 623aa5a74SRiver Riddle // 723aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 823aa5a74SRiver Riddle 923aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 1023aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 1123aa5a74SRiver Riddle #include "mlir/Transforms/DialectConversion.h" 1223aa5a74SRiver Riddle 1323aa5a74SRiver Riddle using namespace mlir; 1423aa5a74SRiver Riddle using namespace mlir::func; 1523aa5a74SRiver Riddle 169df63b26SMatthias Springer /// Flatten the given value ranges into a single vector of values. 179df63b26SMatthias Springer static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { 189df63b26SMatthias Springer SmallVector<Value> result; 199df63b26SMatthias Springer for (const auto &vals : values) 209df63b26SMatthias Springer llvm::append_range(result, vals); 219df63b26SMatthias Springer return result; 229df63b26SMatthias Springer } 239df63b26SMatthias Springer 2423aa5a74SRiver Riddle namespace { 2523aa5a74SRiver Riddle /// Converts the operand and result types of the CallOp, used together with the 2623aa5a74SRiver Riddle /// FuncOpSignatureConversion. 2723aa5a74SRiver Riddle struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { 2823aa5a74SRiver Riddle using OpConversionPattern<CallOp>::OpConversionPattern; 2923aa5a74SRiver Riddle 3023aa5a74SRiver Riddle /// Hook for derived classes to implement combined matching and rewriting. 3123aa5a74SRiver Riddle LogicalResult 329df63b26SMatthias Springer matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor, 3323aa5a74SRiver Riddle ConversionPatternRewriter &rewriter) const override { 3408e6566dSMatthias Springer // Convert the original function results. Keep track of how many result 3508e6566dSMatthias Springer // types an original result type is converted into. 3608e6566dSMatthias Springer SmallVector<size_t> numResultsReplacments; 3723aa5a74SRiver Riddle SmallVector<Type, 1> convertedResults; 3808e6566dSMatthias Springer size_t numFlattenedResults = 0; 3908e6566dSMatthias Springer for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) { 4008e6566dSMatthias Springer if (failed(typeConverter->convertTypes(type, convertedResults))) 4123aa5a74SRiver Riddle return failure(); 4208e6566dSMatthias Springer numResultsReplacments.push_back(convertedResults.size() - 4308e6566dSMatthias Springer numFlattenedResults); 4408e6566dSMatthias Springer numFlattenedResults = convertedResults.size(); 4508e6566dSMatthias Springer } 4688a3dc0eSAlex Zinenko 4723aa5a74SRiver Riddle // Substitute with the new result types from the corresponding FuncType 4823aa5a74SRiver Riddle // conversion. 499df63b26SMatthias Springer auto newCallOp = rewriter.create<CallOp>( 509df63b26SMatthias Springer callOp.getLoc(), callOp.getCallee(), convertedResults, 519df63b26SMatthias Springer flattenValues(adaptor.getOperands())); 5208e6566dSMatthias Springer SmallVector<ValueRange> replacements; 5308e6566dSMatthias Springer size_t offset = 0; 5408e6566dSMatthias Springer for (int i = 0, e = callOp->getNumResults(); i < e; ++i) { 5508e6566dSMatthias Springer replacements.push_back( 5608e6566dSMatthias Springer newCallOp->getResults().slice(offset, numResultsReplacments[i])); 5708e6566dSMatthias Springer offset += numResultsReplacments[i]; 5808e6566dSMatthias Springer } 5908e6566dSMatthias Springer assert(offset == convertedResults.size() && 6008e6566dSMatthias Springer "expected that all converted results are used"); 6108e6566dSMatthias Springer rewriter.replaceOpWithMultiple(callOp, replacements); 6223aa5a74SRiver Riddle return success(); 6323aa5a74SRiver Riddle } 6423aa5a74SRiver Riddle }; 6523aa5a74SRiver Riddle } // namespace 6623aa5a74SRiver Riddle 6723aa5a74SRiver Riddle void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns, 68206fad0eSMatthias Springer const TypeConverter &converter) { 6923aa5a74SRiver Riddle patterns.add<CallOpSignatureConversion>(converter, patterns.getContext()); 7023aa5a74SRiver Riddle } 7123aa5a74SRiver Riddle 7223aa5a74SRiver Riddle namespace { 7323aa5a74SRiver Riddle /// Only needed to support partial conversion of functions where this pattern 7423aa5a74SRiver Riddle /// ensures that the branch operation arguments matches up with the succesor 7523aa5a74SRiver Riddle /// block arguments. 7623aa5a74SRiver Riddle class BranchOpInterfaceTypeConversion 7723aa5a74SRiver Riddle : public OpInterfaceConversionPattern<BranchOpInterface> { 7823aa5a74SRiver Riddle public: 7923aa5a74SRiver Riddle using OpInterfaceConversionPattern< 8023aa5a74SRiver Riddle BranchOpInterface>::OpInterfaceConversionPattern; 8123aa5a74SRiver Riddle 8223aa5a74SRiver Riddle BranchOpInterfaceTypeConversion( 83206fad0eSMatthias Springer const TypeConverter &typeConverter, MLIRContext *ctx, 8423aa5a74SRiver Riddle function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) 8523aa5a74SRiver Riddle : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), 8623aa5a74SRiver Riddle shouldConvertBranchOperand(shouldConvertBranchOperand) {} 8723aa5a74SRiver Riddle 8823aa5a74SRiver Riddle LogicalResult 8923aa5a74SRiver Riddle matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands, 9023aa5a74SRiver Riddle ConversionPatternRewriter &rewriter) const final { 9123aa5a74SRiver Riddle // For a branch operation, only some operands go to the target blocks, so 9223aa5a74SRiver Riddle // only rewrite those. 9323aa5a74SRiver Riddle SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end()); 9423aa5a74SRiver Riddle for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); 9523aa5a74SRiver Riddle succIdx < succEnd; ++succIdx) { 960c789db5SMarkus Böck OperandRange forwardedOperands = 970c789db5SMarkus Böck op.getSuccessorOperands(succIdx).getForwardedOperands(); 980c789db5SMarkus Böck if (forwardedOperands.empty()) 9923aa5a74SRiver Riddle continue; 10023aa5a74SRiver Riddle 1010c789db5SMarkus Böck for (int idx = forwardedOperands.getBeginOperandIndex(), 1020c789db5SMarkus Böck eidx = idx + forwardedOperands.size(); 10323aa5a74SRiver Riddle idx < eidx; ++idx) { 10423aa5a74SRiver Riddle if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) 10523aa5a74SRiver Riddle newOperands[idx] = operands[idx]; 10623aa5a74SRiver Riddle } 10723aa5a74SRiver Riddle } 1085fcf907bSMatthias Springer rewriter.modifyOpInPlace( 10923aa5a74SRiver Riddle op, [newOperands, op]() { op->setOperands(newOperands); }); 11023aa5a74SRiver Riddle return success(); 11123aa5a74SRiver Riddle } 11223aa5a74SRiver Riddle 11323aa5a74SRiver Riddle private: 11423aa5a74SRiver Riddle function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand; 11523aa5a74SRiver Riddle }; 11623aa5a74SRiver Riddle } // namespace 11723aa5a74SRiver Riddle 11823aa5a74SRiver Riddle namespace { 11923aa5a74SRiver Riddle /// Only needed to support partial conversion of functions where this pattern 12023aa5a74SRiver Riddle /// ensures that the branch operation arguments matches up with the succesor 12123aa5a74SRiver Riddle /// block arguments. 12223aa5a74SRiver Riddle class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { 12323aa5a74SRiver Riddle public: 12423aa5a74SRiver Riddle using OpConversionPattern<ReturnOp>::OpConversionPattern; 12523aa5a74SRiver Riddle 12623aa5a74SRiver Riddle LogicalResult 127*7267c859SMatthias Springer matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, 12823aa5a74SRiver Riddle ConversionPatternRewriter &rewriter) const final { 129*7267c859SMatthias Springer rewriter.replaceOpWithNewOp<ReturnOp>(op, 130*7267c859SMatthias Springer flattenValues(adaptor.getOperands())); 13123aa5a74SRiver Riddle return success(); 13223aa5a74SRiver Riddle } 13323aa5a74SRiver Riddle }; 13423aa5a74SRiver Riddle } // namespace 13523aa5a74SRiver Riddle 13623aa5a74SRiver Riddle void mlir::populateBranchOpInterfaceTypeConversionPattern( 137206fad0eSMatthias Springer RewritePatternSet &patterns, const TypeConverter &typeConverter, 13823aa5a74SRiver Riddle function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) { 13923aa5a74SRiver Riddle patterns.add<BranchOpInterfaceTypeConversion>( 14023aa5a74SRiver Riddle typeConverter, patterns.getContext(), shouldConvertBranchOperand); 14123aa5a74SRiver Riddle } 14223aa5a74SRiver Riddle 14323aa5a74SRiver Riddle bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( 144206fad0eSMatthias Springer Operation *op, const TypeConverter &converter) { 14523aa5a74SRiver Riddle // All successor operands of branch like operations must be rewritten. 14623aa5a74SRiver Riddle if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 14723aa5a74SRiver Riddle for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { 14823aa5a74SRiver Riddle auto successorOperands = branchOp.getSuccessorOperands(p); 1490c789db5SMarkus Böck if (!converter.isLegal( 1500c789db5SMarkus Böck successorOperands.getForwardedOperands().getTypes())) 15123aa5a74SRiver Riddle return false; 15223aa5a74SRiver Riddle } 15323aa5a74SRiver Riddle return true; 15423aa5a74SRiver Riddle } 15523aa5a74SRiver Riddle 15623aa5a74SRiver Riddle return false; 15723aa5a74SRiver Riddle } 15823aa5a74SRiver Riddle 159206fad0eSMatthias Springer void mlir::populateReturnOpTypeConversionPattern( 160206fad0eSMatthias Springer RewritePatternSet &patterns, const TypeConverter &typeConverter) { 16123aa5a74SRiver Riddle patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext()); 16223aa5a74SRiver Riddle } 16323aa5a74SRiver Riddle 164206fad0eSMatthias Springer bool mlir::isLegalForReturnOpTypeConversionPattern( 165206fad0eSMatthias Springer Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) { 16623aa5a74SRiver Riddle // If this is a `return` and the user pass wants to convert/transform across 1677557530fSFangrui Song // function boundaries, then `converter` is invoked to check whether the 16823aa5a74SRiver Riddle // `return` op is legal. 169e4c39501SRahul Joshi if (isa<ReturnOp>(op) && !returnOpAlwaysLegal) 17023aa5a74SRiver Riddle return converter.isLegal(op); 17123aa5a74SRiver Riddle 17223aa5a74SRiver Riddle // ReturnLike operations have to be legalized with their parent. For 17323aa5a74SRiver Riddle // return this is handled, for other ops they remain as is. 17423aa5a74SRiver Riddle return op->hasTrait<OpTrait::ReturnLike>(); 17523aa5a74SRiver Riddle } 17623aa5a74SRiver Riddle 17723aa5a74SRiver Riddle bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) { 17823aa5a74SRiver Riddle // If it is not a terminator, ignore it. 17923aa5a74SRiver Riddle if (!op->mightHaveTrait<OpTrait::IsTerminator>()) 18023aa5a74SRiver Riddle return true; 18123aa5a74SRiver Riddle 18223aa5a74SRiver Riddle // If it is not the last operation in the block, also ignore it. We do 18323aa5a74SRiver Riddle // this to handle unknown operations, as well. 18423aa5a74SRiver Riddle Block *block = op->getBlock(); 18523aa5a74SRiver Riddle if (!block || &block->back() != op) 18623aa5a74SRiver Riddle return true; 18723aa5a74SRiver Riddle 18823aa5a74SRiver Riddle // We don't want to handle terminators in nested regions, assume they are 18923aa5a74SRiver Riddle // always legal. 19023aa5a74SRiver Riddle if (!isa_and_nonnull<FuncOp>(op->getParentOp())) 19123aa5a74SRiver Riddle return true; 19223aa5a74SRiver Riddle 19323aa5a74SRiver Riddle return false; 19423aa5a74SRiver Riddle } 195