1 //===- FuncConversions.cpp - Function conversions -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Transforms/DialectConversion.h" 12 13 using namespace mlir; 14 using namespace mlir::func; 15 16 /// Flatten the given value ranges into a single vector of values. 17 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { 18 SmallVector<Value> result; 19 for (const auto &vals : values) 20 llvm::append_range(result, vals); 21 return result; 22 } 23 24 namespace { 25 /// Converts the operand and result types of the CallOp, used together with the 26 /// FuncOpSignatureConversion. 27 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { 28 using OpConversionPattern<CallOp>::OpConversionPattern; 29 30 /// Hook for derived classes to implement combined matching and rewriting. 31 LogicalResult 32 matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor, 33 ConversionPatternRewriter &rewriter) const override { 34 // Convert the original function results. Keep track of how many result 35 // types an original result type is converted into. 36 SmallVector<size_t> numResultsReplacments; 37 SmallVector<Type, 1> convertedResults; 38 size_t numFlattenedResults = 0; 39 for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) { 40 if (failed(typeConverter->convertTypes(type, convertedResults))) 41 return failure(); 42 numResultsReplacments.push_back(convertedResults.size() - 43 numFlattenedResults); 44 numFlattenedResults = convertedResults.size(); 45 } 46 47 // Substitute with the new result types from the corresponding FuncType 48 // conversion. 49 auto newCallOp = rewriter.create<CallOp>( 50 callOp.getLoc(), callOp.getCallee(), convertedResults, 51 flattenValues(adaptor.getOperands())); 52 SmallVector<ValueRange> replacements; 53 size_t offset = 0; 54 for (int i = 0, e = callOp->getNumResults(); i < e; ++i) { 55 replacements.push_back( 56 newCallOp->getResults().slice(offset, numResultsReplacments[i])); 57 offset += numResultsReplacments[i]; 58 } 59 assert(offset == convertedResults.size() && 60 "expected that all converted results are used"); 61 rewriter.replaceOpWithMultiple(callOp, replacements); 62 return success(); 63 } 64 }; 65 } // namespace 66 67 void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns, 68 const TypeConverter &converter) { 69 patterns.add<CallOpSignatureConversion>(converter, patterns.getContext()); 70 } 71 72 namespace { 73 /// Only needed to support partial conversion of functions where this pattern 74 /// ensures that the branch operation arguments matches up with the succesor 75 /// block arguments. 76 class BranchOpInterfaceTypeConversion 77 : public OpInterfaceConversionPattern<BranchOpInterface> { 78 public: 79 using OpInterfaceConversionPattern< 80 BranchOpInterface>::OpInterfaceConversionPattern; 81 82 BranchOpInterfaceTypeConversion( 83 const TypeConverter &typeConverter, MLIRContext *ctx, 84 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) 85 : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), 86 shouldConvertBranchOperand(shouldConvertBranchOperand) {} 87 88 LogicalResult 89 matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands, 90 ConversionPatternRewriter &rewriter) const final { 91 // For a branch operation, only some operands go to the target blocks, so 92 // only rewrite those. 93 SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end()); 94 for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); 95 succIdx < succEnd; ++succIdx) { 96 OperandRange forwardedOperands = 97 op.getSuccessorOperands(succIdx).getForwardedOperands(); 98 if (forwardedOperands.empty()) 99 continue; 100 101 for (int idx = forwardedOperands.getBeginOperandIndex(), 102 eidx = idx + forwardedOperands.size(); 103 idx < eidx; ++idx) { 104 if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) 105 newOperands[idx] = operands[idx]; 106 } 107 } 108 rewriter.modifyOpInPlace( 109 op, [newOperands, op]() { op->setOperands(newOperands); }); 110 return success(); 111 } 112 113 private: 114 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand; 115 }; 116 } // namespace 117 118 namespace { 119 /// Only needed to support partial conversion of functions where this pattern 120 /// ensures that the branch operation arguments matches up with the succesor 121 /// block arguments. 122 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { 123 public: 124 using OpConversionPattern<ReturnOp>::OpConversionPattern; 125 126 LogicalResult 127 matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, 128 ConversionPatternRewriter &rewriter) const final { 129 rewriter.replaceOpWithNewOp<ReturnOp>(op, 130 flattenValues(adaptor.getOperands())); 131 return success(); 132 } 133 }; 134 } // namespace 135 136 void mlir::populateBranchOpInterfaceTypeConversionPattern( 137 RewritePatternSet &patterns, const TypeConverter &typeConverter, 138 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) { 139 patterns.add<BranchOpInterfaceTypeConversion>( 140 typeConverter, patterns.getContext(), shouldConvertBranchOperand); 141 } 142 143 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( 144 Operation *op, const TypeConverter &converter) { 145 // All successor operands of branch like operations must be rewritten. 146 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 147 for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { 148 auto successorOperands = branchOp.getSuccessorOperands(p); 149 if (!converter.isLegal( 150 successorOperands.getForwardedOperands().getTypes())) 151 return false; 152 } 153 return true; 154 } 155 156 return false; 157 } 158 159 void mlir::populateReturnOpTypeConversionPattern( 160 RewritePatternSet &patterns, const TypeConverter &typeConverter) { 161 patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext()); 162 } 163 164 bool mlir::isLegalForReturnOpTypeConversionPattern( 165 Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) { 166 // If this is a `return` and the user pass wants to convert/transform across 167 // function boundaries, then `converter` is invoked to check whether the 168 // `return` op is legal. 169 if (isa<ReturnOp>(op) && !returnOpAlwaysLegal) 170 return converter.isLegal(op); 171 172 // ReturnLike operations have to be legalized with their parent. For 173 // return this is handled, for other ops they remain as is. 174 return op->hasTrait<OpTrait::ReturnLike>(); 175 } 176 177 bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) { 178 // If it is not a terminator, ignore it. 179 if (!op->mightHaveTrait<OpTrait::IsTerminator>()) 180 return true; 181 182 // If it is not the last operation in the block, also ignore it. We do 183 // this to handle unknown operations, as well. 184 Block *block = op->getBlock(); 185 if (!block || &block->back() != op) 186 return true; 187 188 // We don't want to handle terminators in nested regions, assume they are 189 // always legal. 190 if (!isa_and_nonnull<FuncOp>(op->getParentOp())) 191 return true; 192 193 return false; 194 } 195