1 //===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===// 2 // 3 // Licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The patterns in this file are heavily inspired (and copied from) 10 // lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N 11 // type conversions. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 16 17 #include "mlir/Dialect/SCF/IR/SCF.h" 18 #include "mlir/Transforms/OneToNTypeConversion.h" 19 20 using namespace mlir; 21 using namespace mlir::scf; 22 23 class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> { 24 public: 25 using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern; 26 27 LogicalResult 28 matchAndRewrite(IfOp op, OpAdaptor adaptor, 29 OneToNPatternRewriter &rewriter) const override { 30 Location loc = op->getLoc(); 31 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); 32 33 // Nothing to do if there is no non-identity conversion. 34 if (!resultMapping.hasNonIdentityConversion()) 35 return failure(); 36 37 // Create new IfOp. 38 TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); 39 auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes, 40 op.getCondition(), true); 41 newOp->setAttrs(op->getAttrs()); 42 43 // We do not need the empty blocks created by rewriter. 44 rewriter.eraseBlock(newOp.elseBlock()); 45 rewriter.eraseBlock(newOp.thenBlock()); 46 47 // Inlines block from the original operation. 48 rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), 49 newOp.getThenRegion().end()); 50 rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), 51 newOp.getElseRegion().end()); 52 53 rewriter.replaceOp(op, newOp->getResults(), resultMapping); 54 return success(); 55 } 56 }; 57 58 class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> { 59 public: 60 using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern; 61 62 LogicalResult 63 matchAndRewrite(WhileOp op, OpAdaptor adaptor, 64 OneToNPatternRewriter &rewriter) const override { 65 Location loc = op->getLoc(); 66 67 const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); 68 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); 69 70 // Nothing to do if the op doesn't have any non-identity conversions for its 71 // operands or results. 72 if (!operandMapping.hasNonIdentityConversion() && 73 !resultMapping.hasNonIdentityConversion()) 74 return failure(); 75 76 // Create new WhileOp. 77 TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); 78 79 auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes, 80 adaptor.getFlatOperands()); 81 newOp->setAttrs(op->getAttrs()); 82 83 // Update block signatures. 84 std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping, 85 resultMapping}; 86 for (unsigned int i : {0u, 1u}) { 87 Region *region = &op.getRegion(i); 88 Block *block = ®ion->front(); 89 90 rewriter.applySignatureConversion(block, blockMappings[i]); 91 92 // Move updated region to new WhileOp. 93 Region &dstRegion = newOp.getRegion(i); 94 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); 95 } 96 97 rewriter.replaceOp(op, newOp->getResults(), resultMapping); 98 return success(); 99 } 100 }; 101 102 class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> { 103 public: 104 using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern; 105 106 LogicalResult 107 matchAndRewrite(YieldOp op, OpAdaptor adaptor, 108 OneToNPatternRewriter &rewriter) const override { 109 // Nothing to do if there is no non-identity conversion. 110 if (!adaptor.getOperandMapping().hasNonIdentityConversion()) 111 return failure(); 112 113 // Convert operands. 114 rewriter.modifyOpInPlace( 115 op, [&] { op->setOperands(adaptor.getFlatOperands()); }); 116 117 return success(); 118 } 119 }; 120 121 class ConvertTypesInSCFConditionOp 122 : public OneToNOpConversionPattern<ConditionOp> { 123 public: 124 using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern; 125 126 LogicalResult 127 matchAndRewrite(ConditionOp op, OpAdaptor adaptor, 128 OneToNPatternRewriter &rewriter) const override { 129 // Nothing to do if there is no non-identity conversion. 130 if (!adaptor.getOperandMapping().hasNonIdentityConversion()) 131 return failure(); 132 133 // Convert operands. 134 rewriter.modifyOpInPlace( 135 op, [&] { op->setOperands(adaptor.getFlatOperands()); }); 136 137 return success(); 138 } 139 }; 140 141 class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> { 142 public: 143 using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern; 144 145 LogicalResult 146 matchAndRewrite(ForOp forOp, OpAdaptor adaptor, 147 OneToNPatternRewriter &rewriter) const override { 148 const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); 149 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); 150 151 // Nothing to do if there is no non-identity conversion. 152 if (!operandMapping.hasNonIdentityConversion() && 153 !resultMapping.hasNonIdentityConversion()) 154 return failure(); 155 156 // If the lower-bound, upper-bound, or step were expanded, abort the 157 // conversion. This conversion does not know what to do in such cases. 158 ValueRange lbs = adaptor.getLowerBound(); 159 ValueRange ubs = adaptor.getUpperBound(); 160 ValueRange steps = adaptor.getStep(); 161 if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1) 162 return rewriter.notifyMatchFailure( 163 forOp, "index operands converted to multiple values"); 164 165 Location loc = forOp.getLoc(); 166 167 Region *region = &forOp.getRegion(); 168 Block *block = ®ion->front(); 169 170 // Construct the new for-op with an empty body. 171 ValueRange newInits = adaptor.getFlatOperands().drop_front(3); 172 auto newOp = 173 rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits); 174 newOp->setAttrs(forOp->getAttrs()); 175 176 // We do not need the empty blocks created by rewriter. 177 rewriter.eraseBlock(newOp.getBody()); 178 179 // Convert the signature of the body region. 180 OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes()); 181 if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), 182 bodyTypeMapping))) 183 return failure(); 184 185 // Perform signature conversion on the body block. 186 rewriter.applySignatureConversion(block, bodyTypeMapping); 187 188 // Splice the old body region into the new for-op. 189 Region &dstRegion = newOp.getBodyRegion(); 190 rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end()); 191 192 rewriter.replaceOp(forOp, newOp.getResults(), resultMapping); 193 194 return success(); 195 } 196 }; 197 198 namespace mlir { 199 namespace scf { 200 201 void populateSCFStructuralOneToNTypeConversions( 202 const TypeConverter &typeConverter, RewritePatternSet &patterns) { 203 patterns.add< 204 // clang-format off 205 ConvertTypesInSCFConditionOp, 206 ConvertTypesInSCFForOp, 207 ConvertTypesInSCFIfOp, 208 ConvertTypesInSCFWhileOp, 209 ConvertTypesInSCFYieldOp 210 // clang-format on 211 >(typeConverter, patterns.getContext()); 212 } 213 214 } // namespace scf 215 } // namespace mlir 216