1 //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 19 namespace { 20 /// Creates a sequence of `test.get_tuple_element` ops for all elements of a 21 /// given tuple value. If some tuple elements are, in turn, tuples, the elements 22 /// of those are extracted recursively such that the returned values have the 23 /// same types as `resultTypes.getFlattenedTypes()`. 24 static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder, 25 TypeRange resultTypes, 26 ValueRange inputs, Location loc) { 27 // Skip materialization if the single input value is not a tuple. 28 if (inputs.size() != 1) 29 return {}; 30 Value tuple = inputs.front(); 31 auto tupleType = dyn_cast<TupleType>(tuple.getType()); 32 if (!tupleType) 33 return {}; 34 // Skip materialization if the flattened types do not match the requested 35 // result types. 36 SmallVector<Type> flattenedTypes; 37 tupleType.getFlattenedTypes(flattenedTypes); 38 if (TypeRange(resultTypes) != TypeRange(flattenedTypes)) 39 return {}; 40 // Recursively decompose the tuple. 41 SmallVector<Value> result; 42 std::function<void(Value)> decompose = [&](Value tuple) { 43 auto tupleType = dyn_cast<TupleType>(tuple.getType()); 44 if (!tupleType) { 45 // This is not a tuple. 46 result.push_back(tuple); 47 return; 48 } 49 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { 50 Type elementType = tupleType.getType(i); 51 Value element = builder.create<test::GetTupleElementOp>( 52 loc, elementType, tuple, builder.getI32IntegerAttr(i)); 53 decompose(element); 54 } 55 }; 56 decompose(tuple); 57 return result; 58 } 59 60 /// Creates a `test.make_tuple` op out of the given inputs building a tuple of 61 /// type `resultType`. If that type is nested, each nested tuple is built 62 /// recursively with another `test.make_tuple` op. 63 static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType, 64 ValueRange inputs, Location loc) { 65 // Build one value for each element at this nesting level. 66 SmallVector<Value> elements; 67 elements.reserve(resultType.getTypes().size()); 68 ValueRange::iterator inputIt = inputs.begin(); 69 for (Type elementType : resultType.getTypes()) { 70 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { 71 // Determine how many input values are needed for the nested elements of 72 // the nested TupleType and advance inputIt by that number. 73 // TODO: We only need the *number* of nested types, not the types itself. 74 // Maybe it's worth adding a more efficient overload? 75 SmallVector<Type> nestedFlattenedTypes; 76 nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); 77 size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); 78 ValueRange nestedFlattenedelements(inputIt, 79 inputIt + numNestedFlattenedTypes); 80 inputIt += numNestedFlattenedTypes; 81 82 // Recurse on the values for the nested TupleType. 83 Value res = buildMakeTupleOp(builder, nestedTupleType, 84 nestedFlattenedelements, loc); 85 if (!res) 86 return Value(); 87 88 // The tuple constructed by the conversion is the element value. 89 elements.push_back(res); 90 } else { 91 // Base case: take one input as is. 92 elements.push_back(*inputIt++); 93 } 94 } 95 96 // Assemble the tuple from the elements. 97 return builder.create<test::MakeTupleOp>(loc, resultType, elements); 98 } 99 100 /// A pass for testing call graph type decomposition. 101 /// 102 /// This instantiates the patterns with a TypeConverter that splits tuple types 103 /// into their respective element types. 104 /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`. 105 struct TestDecomposeCallGraphTypes 106 : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> { 107 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes) 108 109 void getDependentDialects(DialectRegistry ®istry) const override { 110 registry.insert<test::TestDialect>(); 111 } 112 StringRef getArgument() const final { 113 return "test-decompose-call-graph-types"; 114 } 115 StringRef getDescription() const final { 116 return "Decomposes types at call graph boundaries."; 117 } 118 void runOnOperation() override { 119 ModuleOp module = getOperation(); 120 auto *context = &getContext(); 121 TypeConverter typeConverter; 122 ConversionTarget target(*context); 123 RewritePatternSet patterns(context); 124 125 target.addLegalDialect<test::TestDialect>(); 126 127 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 128 return typeConverter.isLegal(op.getOperandTypes()); 129 }); 130 target.addDynamicallyLegalOp<func::CallOp>( 131 [&](func::CallOp op) { return typeConverter.isLegal(op); }); 132 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 133 return typeConverter.isSignatureLegal(op.getFunctionType()); 134 }); 135 136 typeConverter.addConversion([](Type type) { return type; }); 137 typeConverter.addConversion( 138 [](TupleType tupleType, SmallVectorImpl<Type> &types) { 139 tupleType.getFlattenedTypes(types); 140 return success(); 141 }); 142 typeConverter.addSourceMaterialization(buildMakeTupleOp); 143 typeConverter.addTargetMaterialization(buildDecomposeTuple); 144 145 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( 146 patterns, typeConverter); 147 populateReturnOpTypeConversionPattern(patterns, typeConverter); 148 populateCallOpTypeConversionPattern(patterns, typeConverter); 149 150 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 151 return signalPassFailure(); 152 } 153 }; 154 155 } // namespace 156 157 namespace mlir { 158 namespace test { 159 void registerTestDecomposeCallGraphTypes() { 160 PassRegistration<TestDecomposeCallGraphTypes>(); 161 } 162 } // namespace test 163 } // namespace mlir 164