xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The patterns in this file are heavily inspired by (and copied from)
10 // lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
11 // type conversions.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
16 
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Transforms/OneToNTypeConversion.h"
19 
20 using namespace mlir;
21 using namespace mlir::scf;
22 
23 class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
24 public:
25   using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
26 
27   LogicalResult
28   matchAndRewrite(IfOp op, OpAdaptor adaptor,
29                   OneToNPatternRewriter &rewriter) const override {
30     Location loc = op->getLoc();
31     const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
32 
33     // Nothing to do if there is no non-identity conversion.
34     if (!resultMapping.hasNonIdentityConversion())
35       return failure();
36 
37     // Create new IfOp.
38     TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
39     auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
40                                        op.getCondition(), true);
41     newOp->setAttrs(op->getAttrs());
42 
43     // We do not need the empty blocks created by rewriter.
44     rewriter.eraseBlock(newOp.elseBlock());
45     rewriter.eraseBlock(newOp.thenBlock());
46 
47     // Inlines block from the original operation.
48     rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
49                                 newOp.getThenRegion().end());
50     rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
51                                 newOp.getElseRegion().end());
52 
53     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
54     return success();
55   }
56 };
57 
58 class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
59 public:
60   using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
61 
62   LogicalResult
63   matchAndRewrite(WhileOp op, OpAdaptor adaptor,
64                   OneToNPatternRewriter &rewriter) const override {
65     Location loc = op->getLoc();
66 
67     const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
68     const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
69 
70     // Nothing to do if the op doesn't have any non-identity conversions for its
71     // operands or results.
72     if (!operandMapping.hasNonIdentityConversion() &&
73         !resultMapping.hasNonIdentityConversion())
74       return failure();
75 
76     // Create new WhileOp.
77     TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
78 
79     auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes,
80                                           adaptor.getFlatOperands());
81     newOp->setAttrs(op->getAttrs());
82 
83     // Update block signatures.
84     std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping,
85                                                       resultMapping};
86     for (unsigned int i : {0u, 1u}) {
87       Region *region = &op.getRegion(i);
88       Block *block = &region->front();
89 
90       rewriter.applySignatureConversion(block, blockMappings[i]);
91 
92       // Move updated region to new WhileOp.
93       Region &dstRegion = newOp.getRegion(i);
94       rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
95     }
96 
97     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
98     return success();
99   }
100 };
101 
102 class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
103 public:
104   using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
105 
106   LogicalResult
107   matchAndRewrite(YieldOp op, OpAdaptor adaptor,
108                   OneToNPatternRewriter &rewriter) const override {
109     // Nothing to do if there is no non-identity conversion.
110     if (!adaptor.getOperandMapping().hasNonIdentityConversion())
111       return failure();
112 
113     // Convert operands.
114     rewriter.modifyOpInPlace(
115         op, [&] { op->setOperands(adaptor.getFlatOperands()); });
116 
117     return success();
118   }
119 };
120 
121 class ConvertTypesInSCFConditionOp
122     : public OneToNOpConversionPattern<ConditionOp> {
123 public:
124   using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
125 
126   LogicalResult
127   matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
128                   OneToNPatternRewriter &rewriter) const override {
129     // Nothing to do if there is no non-identity conversion.
130     if (!adaptor.getOperandMapping().hasNonIdentityConversion())
131       return failure();
132 
133     // Convert operands.
134     rewriter.modifyOpInPlace(
135         op, [&] { op->setOperands(adaptor.getFlatOperands()); });
136 
137     return success();
138   }
139 };
140 
141 class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> {
142 public:
143   using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern;
144 
145   LogicalResult
146   matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
147                   OneToNPatternRewriter &rewriter) const override {
148     const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
149     const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
150 
151     // Nothing to do if there is no non-identity conversion.
152     if (!operandMapping.hasNonIdentityConversion() &&
153         !resultMapping.hasNonIdentityConversion())
154       return failure();
155 
156     // If the lower-bound, upper-bound, or step were expanded, abort the
157     // conversion. This conversion does not know what to do in such cases.
158     ValueRange lbs = adaptor.getLowerBound();
159     ValueRange ubs = adaptor.getUpperBound();
160     ValueRange steps = adaptor.getStep();
161     if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1)
162       return rewriter.notifyMatchFailure(
163           forOp, "index operands converted to multiple values");
164 
165     Location loc = forOp.getLoc();
166 
167     Region *region = &forOp.getRegion();
168     Block *block = &region->front();
169 
170     // Construct the new for-op with an empty body.
171     ValueRange newInits = adaptor.getFlatOperands().drop_front(3);
172     auto newOp =
173         rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits);
174     newOp->setAttrs(forOp->getAttrs());
175 
176     // We do not need the empty blocks created by rewriter.
177     rewriter.eraseBlock(newOp.getBody());
178 
179     // Convert the signature of the body region.
180     OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes());
181     if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
182                                                    bodyTypeMapping)))
183       return failure();
184 
185     // Perform signature conversion on the body block.
186     rewriter.applySignatureConversion(block, bodyTypeMapping);
187 
188     // Splice the old body region into the new for-op.
189     Region &dstRegion = newOp.getBodyRegion();
190     rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end());
191 
192     rewriter.replaceOp(forOp, newOp.getResults(), resultMapping);
193 
194     return success();
195   }
196 };
197 
198 namespace mlir {
199 namespace scf {
200 
201 void populateSCFStructuralOneToNTypeConversions(
202     const TypeConverter &typeConverter, RewritePatternSet &patterns) {
203   patterns.add<
204       // clang-format off
205       ConvertTypesInSCFConditionOp,
206       ConvertTypesInSCFForOp,
207       ConvertTypesInSCFIfOp,
208       ConvertTypesInSCFWhileOp,
209       ConvertTypesInSCFYieldOp
210       // clang-format on
211       >(typeConverter, patterns.getContext());
212 }
213 
214 } // namespace scf
215 } // namespace mlir
216