1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/IR/SCF.h" 10 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 11 #include "mlir/Transforms/DialectConversion.h" 12 #include <optional> 13 14 using namespace mlir; 15 using namespace mlir::scf; 16 17 namespace { 18 19 /// Flatten the given value ranges into a single vector of values. 20 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { 21 SmallVector<Value> result; 22 for (const auto &vals : values) 23 llvm::append_range(result, vals); 24 return result; 25 } 26 27 /// Assert that the given value range contains a single value and return it. 28 static Value getSingleValue(ValueRange values) { 29 assert(values.size() == 1 && "expected single value"); 30 return values.front(); 31 } 32 33 // CRTP 34 // A base class that takes care of 1:N type conversion, which maps the converted 35 // op results (computed by the derived class) and materializes 1:N conversion. 36 template <typename SourceOp, typename ConcretePattern> 37 class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> { 38 public: 39 using OpConversionPattern<SourceOp>::typeConverter; 40 using OpConversionPattern<SourceOp>::OpConversionPattern; 41 using OneToNOpAdaptor = 42 typename OpConversionPattern<SourceOp>::OneToNOpAdaptor; 43 44 // 45 // Derived classes should provide the following method which performs the 46 // actual conversion. It should return std::nullopt upon conversion failure 47 // and return the converted operation upon success. 48 // 49 // std::optional<SourceOp> convertSourceOp( 50 // SourceOp op, OneToNOpAdaptor adaptor, 51 // ConversionPatternRewriter &rewriter, 52 // TypeRange dstTypes) const; 53 54 LogicalResult 55 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, 56 ConversionPatternRewriter &rewriter) const override { 57 SmallVector<Type> dstTypes; 58 SmallVector<unsigned> offsets; 59 offsets.push_back(0); 60 // Do the type conversion and record the offsets. 61 for (Type type : op.getResultTypes()) { 62 if (failed(typeConverter->convertTypes(type, dstTypes))) 63 return rewriter.notifyMatchFailure(op, "could not convert result type"); 64 offsets.push_back(dstTypes.size()); 65 } 66 67 // Calls the actual converter implementation to convert the operation. 68 std::optional<SourceOp> newOp = 69 static_cast<const ConcretePattern *>(this)->convertSourceOp( 70 op, adaptor, rewriter, dstTypes); 71 72 if (!newOp) 73 return rewriter.notifyMatchFailure(op, "could not convert operation"); 74 75 // Packs the return value. 76 SmallVector<ValueRange> packedRets; 77 for (unsigned i = 1, e = offsets.size(); i < e; i++) { 78 unsigned start = offsets[i - 1], end = offsets[i]; 79 unsigned len = end - start; 80 ValueRange mappedValue = newOp->getResults().slice(start, len); 81 packedRets.push_back(mappedValue); 82 } 83 84 rewriter.replaceOpWithMultiple(op, packedRets); 85 return success(); 86 } 87 }; 88 89 class ConvertForOpTypes 90 : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> { 91 public: 92 using Structural1ToNConversionPattern::Structural1ToNConversionPattern; 93 94 // The callback required by CRTP. 95 std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor, 96 ConversionPatternRewriter &rewriter, 97 TypeRange dstTypes) const { 98 // Create a empty new op and inline the regions from the old op. 99 // 100 // This is a little bit tricky. We have two concerns here: 101 // 102 // 1. We cannot update the op in place because the dialect conversion 103 // framework does not track type changes for ops updated in place, so it 104 // won't insert appropriate materializations on the changed result types. 105 // PR47938 tracks this issue, but it seems hard to fix. Instead, we need 106 // to clone the op. 107 // 108 // 2. We need to resue the original region instead of cloning it, otherwise 109 // the dialect conversion framework thinks that we just inserted all the 110 // cloned child ops. But what we want is to "take" the child regions and let 111 // the dialect conversion framework continue recursively into ops inside 112 // those regions (which are already in its worklist; inlining them into the 113 // new op's regions doesn't remove the child ops from the worklist). 114 115 // convertRegionTypes already takes care of 1:N conversion. 116 if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter))) 117 return std::nullopt; 118 119 // We can not do clone as the number of result types after conversion 120 // might be different. 121 ForOp newOp = rewriter.create<ForOp>( 122 op.getLoc(), getSingleValue(adaptor.getLowerBound()), 123 getSingleValue(adaptor.getUpperBound()), 124 getSingleValue(adaptor.getStep()), 125 flattenValues(adaptor.getInitArgs())); 126 127 // Reserve whatever attributes in the original op. 128 newOp->setAttrs(op->getAttrs()); 129 130 // We do not need the empty block created by rewriter. 131 rewriter.eraseBlock(newOp.getBody(0)); 132 // Inline the type converted region from the original operation. 133 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 134 newOp.getRegion().end()); 135 136 return newOp; 137 } 138 }; 139 } // namespace 140 141 namespace { 142 class ConvertIfOpTypes 143 : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> { 144 public: 145 using Structural1ToNConversionPattern::Structural1ToNConversionPattern; 146 147 std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor, 148 ConversionPatternRewriter &rewriter, 149 TypeRange dstTypes) const { 150 151 IfOp newOp = rewriter.create<IfOp>( 152 op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true); 153 newOp->setAttrs(op->getAttrs()); 154 155 // We do not need the empty blocks created by rewriter. 156 rewriter.eraseBlock(newOp.elseBlock()); 157 rewriter.eraseBlock(newOp.thenBlock()); 158 159 // Inlines block from the original operation. 160 rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), 161 newOp.getThenRegion().end()); 162 rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), 163 newOp.getElseRegion().end()); 164 165 return newOp; 166 } 167 }; 168 } // namespace 169 170 namespace { 171 class ConvertWhileOpTypes 172 : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> { 173 public: 174 using Structural1ToNConversionPattern::Structural1ToNConversionPattern; 175 176 std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor, 177 ConversionPatternRewriter &rewriter, 178 TypeRange dstTypes) const { 179 auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, 180 flattenValues(adaptor.getOperands())); 181 182 for (auto i : {0u, 1u}) { 183 if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) 184 return std::nullopt; 185 auto &dstRegion = newOp.getRegion(i); 186 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); 187 } 188 return newOp; 189 } 190 }; 191 } // namespace 192 193 namespace { 194 // When the result types of a ForOp/IfOp get changed, the operand types of the 195 // corresponding yield op need to be changed. In order to trigger the 196 // appropriate type conversions / materializations, we need a dummy pattern. 197 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> { 198 public: 199 using OpConversionPattern::OpConversionPattern; 200 LogicalResult 201 matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor, 202 ConversionPatternRewriter &rewriter) const override { 203 rewriter.replaceOpWithNewOp<scf::YieldOp>( 204 op, flattenValues(adaptor.getOperands())); 205 return success(); 206 } 207 }; 208 } // namespace 209 210 namespace { 211 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> { 212 public: 213 using OpConversionPattern<ConditionOp>::OpConversionPattern; 214 LogicalResult 215 matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor, 216 ConversionPatternRewriter &rewriter) const override { 217 rewriter.modifyOpInPlace( 218 op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); }); 219 return success(); 220 } 221 }; 222 } // namespace 223 224 void mlir::scf::populateSCFStructuralTypeConversions( 225 const TypeConverter &typeConverter, RewritePatternSet &patterns) { 226 patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes, 227 ConvertWhileOpTypes, ConvertConditionOpTypes>( 228 typeConverter, patterns.getContext()); 229 } 230 231 void mlir::scf::populateSCFStructuralTypeConversionTarget( 232 const TypeConverter &typeConverter, ConversionTarget &target) { 233 target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) { 234 return typeConverter.isLegal(op->getResultTypes()); 235 }); 236 target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) { 237 // We only have conversions for a subset of ops that use scf.yield 238 // terminators. 239 if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp())) 240 return true; 241 return typeConverter.isLegal(op.getOperandTypes()); 242 }); 243 target.addDynamicallyLegalOp<WhileOp, ConditionOp>( 244 [&](Operation *op) { return typeConverter.isLegal(op); }); 245 } 246 247 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality( 248 const TypeConverter &typeConverter, RewritePatternSet &patterns, 249 ConversionTarget &target) { 250 populateSCFStructuralTypeConversions(typeConverter, patterns); 251 populateSCFStructuralTypeConversionTarget(typeConverter, target); 252 } 253