xref: /llvm-project/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (revision 3ace685105d3b50bca68328bf0c945af22d70f23)
1 //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 /// Creates a sequence of `test.get_tuple_element` ops for all elements of a
21 /// given tuple value. If some tuple elements are, in turn, tuples, the elements
22 /// of those are extracted recursively such that the returned values have the
23 /// same types as `resultTypes.getFlattenedTypes()`.
24 static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
25                                               TypeRange resultTypes,
26                                               ValueRange inputs, Location loc) {
27   // Skip materialization if the single input value is not a tuple.
28   if (inputs.size() != 1)
29     return {};
30   Value tuple = inputs.front();
31   auto tupleType = dyn_cast<TupleType>(tuple.getType());
32   if (!tupleType)
33     return {};
34   // Skip materialization if the flattened types do not match the requested
35   // result types.
36   SmallVector<Type> flattenedTypes;
37   tupleType.getFlattenedTypes(flattenedTypes);
38   if (TypeRange(resultTypes) != TypeRange(flattenedTypes))
39     return {};
40   // Recursively decompose the tuple.
41   SmallVector<Value> result;
42   std::function<void(Value)> decompose = [&](Value tuple) {
43     auto tupleType = dyn_cast<TupleType>(tuple.getType());
44     if (!tupleType) {
45       // This is not a tuple.
46       result.push_back(tuple);
47       return;
48     }
49     for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
50       Type elementType = tupleType.getType(i);
51       Value element = builder.create<test::GetTupleElementOp>(
52           loc, elementType, tuple, builder.getI32IntegerAttr(i));
53       decompose(element);
54     }
55   };
56   decompose(tuple);
57   return result;
58 }
59 
60 /// Creates a `test.make_tuple` op out of the given inputs building a tuple of
61 /// type `resultType`. If that type is nested, each nested tuple is built
62 /// recursively with another `test.make_tuple` op.
63 static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
64                               ValueRange inputs, Location loc) {
65   // Build one value for each element at this nesting level.
66   SmallVector<Value> elements;
67   elements.reserve(resultType.getTypes().size());
68   ValueRange::iterator inputIt = inputs.begin();
69   for (Type elementType : resultType.getTypes()) {
70     if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
71       // Determine how many input values are needed for the nested elements of
72       // the nested TupleType and advance inputIt by that number.
73       // TODO: We only need the *number* of nested types, not the types itself.
74       //       Maybe it's worth adding a more efficient overload?
75       SmallVector<Type> nestedFlattenedTypes;
76       nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
77       size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
78       ValueRange nestedFlattenedelements(inputIt,
79                                          inputIt + numNestedFlattenedTypes);
80       inputIt += numNestedFlattenedTypes;
81 
82       // Recurse on the values for the nested TupleType.
83       Value res = buildMakeTupleOp(builder, nestedTupleType,
84                                    nestedFlattenedelements, loc);
85       if (!res)
86         return Value();
87 
88       // The tuple constructed by the conversion is the element value.
89       elements.push_back(res);
90     } else {
91       // Base case: take one input as is.
92       elements.push_back(*inputIt++);
93     }
94   }
95 
96   // Assemble the tuple from the elements.
97   return builder.create<test::MakeTupleOp>(loc, resultType, elements);
98 }
99 
100 /// A pass for testing call graph type decomposition.
101 ///
102 /// This instantiates the patterns with a TypeConverter that splits tuple types
103 /// into their respective element types.
104 /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
105 struct TestDecomposeCallGraphTypes
106     : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
107   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes)
108 
109   void getDependentDialects(DialectRegistry &registry) const override {
110     registry.insert<test::TestDialect>();
111   }
112   StringRef getArgument() const final {
113     return "test-decompose-call-graph-types";
114   }
115   StringRef getDescription() const final {
116     return "Decomposes types at call graph boundaries.";
117   }
118   void runOnOperation() override {
119     ModuleOp module = getOperation();
120     auto *context = &getContext();
121     TypeConverter typeConverter;
122     ConversionTarget target(*context);
123     RewritePatternSet patterns(context);
124 
125     target.addLegalDialect<test::TestDialect>();
126 
127     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
128       return typeConverter.isLegal(op.getOperandTypes());
129     });
130     target.addDynamicallyLegalOp<func::CallOp>(
131         [&](func::CallOp op) { return typeConverter.isLegal(op); });
132     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
133       return typeConverter.isSignatureLegal(op.getFunctionType());
134     });
135 
136     typeConverter.addConversion([](Type type) { return type; });
137     typeConverter.addConversion(
138         [](TupleType tupleType, SmallVectorImpl<Type> &types) {
139           tupleType.getFlattenedTypes(types);
140           return success();
141         });
142     typeConverter.addSourceMaterialization(buildMakeTupleOp);
143     typeConverter.addTargetMaterialization(buildDecomposeTuple);
144 
145     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
146         patterns, typeConverter);
147     populateReturnOpTypeConversionPattern(patterns, typeConverter);
148     populateCallOpTypeConversionPattern(patterns, typeConverter);
149 
150     if (failed(applyPartialConversion(module, target, std::move(patterns))))
151       return signalPassFailure();
152   }
153 };
154 
155 } // namespace
156 
157 namespace mlir {
158 namespace test {
159 void registerTestDecomposeCallGraphTypes() {
160   PassRegistration<TestDecomposeCallGraphTypes>();
161 }
162 } // namespace test
163 } // namespace mlir
164