xref: /llvm-project/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (revision 8c4bc1e75de27adfbaead34b895b0efbaf17bd02)
1 //===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
12 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/OneToNTypeConversion.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 /// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms
20 /// in `applyPartialOneToNConversion` by converting built-in tuples to the
21 /// elements they consist of as well as some dummy ops operating on these
22 /// tuples.
23 struct TestOneToNTypeConversionPass
24     : public PassWrapper<TestOneToNTypeConversionPass,
25                          OperationPass<ModuleOp>> {
26   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass)
27 
28   TestOneToNTypeConversionPass() = default;
29   TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass)
30       : PassWrapper(pass) {}
31 
32   void getDependentDialects(DialectRegistry &registry) const override {
33     registry.insert<test::TestDialect>();
34   }
35 
36   StringRef getArgument() const final {
37     return "test-one-to-n-type-conversion";
38   }
39 
40   StringRef getDescription() const final {
41     return "Test pass for 1:N type conversion";
42   }
43 
44   Option<bool> convertFuncOps{*this, "convert-func-ops",
45                               llvm::cl::desc("Enable conversion on func ops"),
46                               llvm::cl::init(false)};
47 
48   Option<bool> convertSCFOps{*this, "convert-scf-ops",
49                              llvm::cl::desc("Enable conversion on scf ops"),
50                              llvm::cl::init(false)};
51 
52   Option<bool> convertTupleOps{*this, "convert-tuple-ops",
53                                llvm::cl::desc("Enable conversion on tuple ops"),
54                                llvm::cl::init(false)};
55 
56   void runOnOperation() override;
57 };
58 
59 } // namespace
60 
61 namespace mlir {
62 namespace test {
63 void registerTestOneToNTypeConversionPass() {
64   PassRegistration<TestOneToNTypeConversionPass>();
65 }
66 } // namespace test
67 } // namespace mlir
68 
69 namespace {
70 
71 /// Test pattern on for the `make_tuple` op from the test dialect that converts
72 /// this kind of op into it's "decomposed" form, i.e., the elements of the tuple
73 /// that is being produced by `test.make_tuple`, which are really just the
74 /// operands of this op.
75 class ConvertMakeTupleOp
76     : public OneToNOpConversionPattern<::test::MakeTupleOp> {
77 public:
78   using OneToNOpConversionPattern<
79       ::test::MakeTupleOp>::OneToNOpConversionPattern;
80 
81   LogicalResult
82   matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor,
83                   OneToNPatternRewriter &rewriter) const override {
84     // Simply replace the current op with the converted operands.
85     rewriter.replaceOp(op, adaptor.getFlatOperands(),
86                        adaptor.getResultMapping());
87     return success();
88   }
89 };
90 
91 /// Test pattern on for the `get_tuple_element` op from the test dialect that
92 /// converts this kind of op into it's "decomposed" form, i.e., instead of
93 /// "physically" extracting one element from the tuple, we forward the one
94 /// element of the decomposed form that is being extracted (or the several
95 /// elements in case that element is a nested tuple).
96 class ConvertGetTupleElementOp
97     : public OneToNOpConversionPattern<::test::GetTupleElementOp> {
98 public:
99   using OneToNOpConversionPattern<
100       ::test::GetTupleElementOp>::OneToNOpConversionPattern;
101 
102   LogicalResult
103   matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
104                   OneToNPatternRewriter &rewriter) const override {
105     // Construct mapping for tuple element types.
106     auto stateType = cast<TupleType>(op->getOperand(0).getType());
107     TypeRange originalElementTypes = stateType.getTypes();
108     OneToNTypeMapping elementMapping(originalElementTypes);
109     if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
110                                                    elementMapping)))
111       return failure();
112 
113     // Compute converted operands corresponding to original input tuple.
114     assert(adaptor.getOperands().size() == 1 &&
115            "expected 'get_tuple_element' to have one operand");
116     ValueRange convertedTuple = adaptor.getOperands()[0];
117 
118     // Got those converted operands that correspond to the index-th element ofq
119     // the original input tuple.
120     size_t index = op.getIndex();
121     ValueRange extractedElement =
122         elementMapping.getConvertedValues(convertedTuple, index);
123 
124     rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping());
125 
126     return success();
127   }
128 };
129 
130 } // namespace
131 
132 static void
133 populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
134                                     RewritePatternSet &patterns) {
135   patterns.add<
136       // clang-format off
137       ConvertMakeTupleOp,
138       ConvertGetTupleElementOp
139       // clang-format on
140       >(typeConverter, patterns.getContext());
141 }
142 
143 /// Creates a sequence of `test.get_tuple_element` ops for all elements of a
144 /// given tuple value. If some tuple elements are, in turn, tuples, the elements
145 /// of those are extracted recursively such that the returned values have the
146 /// same types as `resultTypes.getFlattenedTypes()`.
147 ///
148 /// This function has been copied (with small adaptions) from
149 /// TestDecomposeCallGraphTypes.cpp.
150 static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
151                                                   TypeRange resultTypes,
152                                                   ValueRange inputs,
153                                                   Location loc) {
154   if (inputs.size() != 1)
155     return {};
156   Value input = inputs.front();
157 
158   TupleType inputType = dyn_cast<TupleType>(input.getType());
159   if (!inputType)
160     return {};
161 
162   SmallVector<Value> values;
163   for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
164     Value element = builder.create<::test::GetTupleElementOp>(
165         loc, elementType, input, builder.getI32IntegerAttr(idx));
166     if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
167       // Recurse if the current element is also a tuple.
168       SmallVector<Type> flatRecursiveTypes;
169       nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
170       std::optional<SmallVector<Value>> resursiveValues =
171           buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
172       if (!resursiveValues.has_value())
173         return {};
174       values.append(resursiveValues.value());
175     } else {
176       values.push_back(element);
177     }
178   }
179   return values;
180 }
181 
182 /// Creates a `test.make_tuple` op out of the given inputs building a tuple of
183 /// type `resultType`. If that type is nested, each nested tuple is built
184 /// recursively with another `test.make_tuple` op.
185 ///
186 /// This function has been copied (with small adaptions) from
187 /// TestDecomposeCallGraphTypes.cpp.
188 static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
189                               ValueRange inputs, Location loc) {
190   // Build one value for each element at this nesting level.
191   SmallVector<Value> elements;
192   elements.reserve(resultType.getTypes().size());
193   ValueRange::iterator inputIt = inputs.begin();
194   for (Type elementType : resultType.getTypes()) {
195     if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
196       // Determine how many input values are needed for the nested elements of
197       // the nested TupleType and advance inputIt by that number.
198       // TODO: We only need the *number* of nested types, not the types itself.
199       //       Maybe it's worth adding a more efficient overload?
200       SmallVector<Type> nestedFlattenedTypes;
201       nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
202       size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
203       ValueRange nestedFlattenedelements(inputIt,
204                                          inputIt + numNestedFlattenedTypes);
205       inputIt += numNestedFlattenedTypes;
206 
207       // Recurse on the values for the nested TupleType.
208       Value res = buildMakeTupleOp(builder, nestedTupleType,
209                                    nestedFlattenedelements, loc);
210       if (!res)
211         return Value();
212 
213       // The tuple constructed by the conversion is the element value.
214       elements.push_back(res);
215     } else {
216       // Base case: take one input as is.
217       elements.push_back(*inputIt++);
218     }
219   }
220 
221   // Assemble the tuple from the elements.
222   return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
223 }
224 
225 void TestOneToNTypeConversionPass::runOnOperation() {
226   ModuleOp module = getOperation();
227   auto *context = &getContext();
228 
229   // Assemble type converter.
230   TypeConverter typeConverter;
231 
232   typeConverter.addConversion([](Type type) { return type; });
233   typeConverter.addConversion(
234       [](TupleType tupleType, SmallVectorImpl<Type> &types) {
235         tupleType.getFlattenedTypes(types);
236         return success();
237       });
238 
239   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
240   typeConverter.addSourceMaterialization(buildMakeTupleOp);
241   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
242   // Test the other target materialization variant that takes the original type
243   // as additional argument. This materialization function always fails.
244   typeConverter.addTargetMaterialization(
245       [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
246          Location loc, Type originalType) -> SmallVector<Value> { return {}; });
247 
248   // Assemble patterns.
249   RewritePatternSet patterns(context);
250   if (convertTupleOps)
251     populateDecomposeTuplesTestPatterns(typeConverter, patterns);
252   if (convertFuncOps)
253     populateFuncTypeConversionPatterns(typeConverter, patterns);
254   if (convertSCFOps)
255     scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns);
256 
257   // Run conversion.
258   if (failed(applyPartialOneToNConversion(module, typeConverter,
259                                           std::move(patterns))))
260     return signalPassFailure();
261 }
262