1 //===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" 12 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Transforms/OneToNTypeConversion.h" 15 16 using namespace mlir; 17 18 namespace { 19 /// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms 20 /// in `applyPartialOneToNConversion` by converting built-in tuples to the 21 /// elements they consist of as well as some dummy ops operating on these 22 /// tuples. 23 struct TestOneToNTypeConversionPass 24 : public PassWrapper<TestOneToNTypeConversionPass, 25 OperationPass<ModuleOp>> { 26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass) 27 28 TestOneToNTypeConversionPass() = default; 29 TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass) 30 : PassWrapper(pass) {} 31 32 void getDependentDialects(DialectRegistry ®istry) const override { 33 registry.insert<test::TestDialect>(); 34 } 35 36 StringRef getArgument() const final { 37 return "test-one-to-n-type-conversion"; 38 } 39 40 StringRef getDescription() const final { 41 return "Test pass for 1:N type conversion"; 42 } 43 44 Option<bool> convertFuncOps{*this, "convert-func-ops", 45 llvm::cl::desc("Enable conversion on func ops"), 46 llvm::cl::init(false)}; 47 48 Option<bool> convertSCFOps{*this, "convert-scf-ops", 49 llvm::cl::desc("Enable conversion on scf ops"), 50 llvm::cl::init(false)}; 51 52 Option<bool> convertTupleOps{*this, "convert-tuple-ops", 53 llvm::cl::desc("Enable conversion on tuple ops"), 54 llvm::cl::init(false)}; 55 56 void runOnOperation() override; 57 }; 58 59 } // namespace 60 61 namespace mlir { 62 namespace test { 63 void registerTestOneToNTypeConversionPass() { 64 PassRegistration<TestOneToNTypeConversionPass>(); 65 } 66 } // namespace test 67 } // namespace mlir 68 69 namespace { 70 71 /// Test pattern on for the `make_tuple` op from the test dialect that converts 72 /// this kind of op into it's "decomposed" form, i.e., the elements of the tuple 73 /// that is being produced by `test.make_tuple`, which are really just the 74 /// operands of this op. 75 class ConvertMakeTupleOp 76 : public OneToNOpConversionPattern<::test::MakeTupleOp> { 77 public: 78 using OneToNOpConversionPattern< 79 ::test::MakeTupleOp>::OneToNOpConversionPattern; 80 81 LogicalResult 82 matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor, 83 OneToNPatternRewriter &rewriter) const override { 84 // Simply replace the current op with the converted operands. 85 rewriter.replaceOp(op, adaptor.getFlatOperands(), 86 adaptor.getResultMapping()); 87 return success(); 88 } 89 }; 90 91 /// Test pattern on for the `get_tuple_element` op from the test dialect that 92 /// converts this kind of op into it's "decomposed" form, i.e., instead of 93 /// "physically" extracting one element from the tuple, we forward the one 94 /// element of the decomposed form that is being extracted (or the several 95 /// elements in case that element is a nested tuple). 96 class ConvertGetTupleElementOp 97 : public OneToNOpConversionPattern<::test::GetTupleElementOp> { 98 public: 99 using OneToNOpConversionPattern< 100 ::test::GetTupleElementOp>::OneToNOpConversionPattern; 101 102 LogicalResult 103 matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor, 104 OneToNPatternRewriter &rewriter) const override { 105 // Construct mapping for tuple element types. 106 auto stateType = cast<TupleType>(op->getOperand(0).getType()); 107 TypeRange originalElementTypes = stateType.getTypes(); 108 OneToNTypeMapping elementMapping(originalElementTypes); 109 if (failed(typeConverter->convertSignatureArgs(originalElementTypes, 110 elementMapping))) 111 return failure(); 112 113 // Compute converted operands corresponding to original input tuple. 114 assert(adaptor.getOperands().size() == 1 && 115 "expected 'get_tuple_element' to have one operand"); 116 ValueRange convertedTuple = adaptor.getOperands()[0]; 117 118 // Got those converted operands that correspond to the index-th element ofq 119 // the original input tuple. 120 size_t index = op.getIndex(); 121 ValueRange extractedElement = 122 elementMapping.getConvertedValues(convertedTuple, index); 123 124 rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping()); 125 126 return success(); 127 } 128 }; 129 130 } // namespace 131 132 static void 133 populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter, 134 RewritePatternSet &patterns) { 135 patterns.add< 136 // clang-format off 137 ConvertMakeTupleOp, 138 ConvertGetTupleElementOp 139 // clang-format on 140 >(typeConverter, patterns.getContext()); 141 } 142 143 /// Creates a sequence of `test.get_tuple_element` ops for all elements of a 144 /// given tuple value. If some tuple elements are, in turn, tuples, the elements 145 /// of those are extracted recursively such that the returned values have the 146 /// same types as `resultTypes.getFlattenedTypes()`. 147 /// 148 /// This function has been copied (with small adaptions) from 149 /// TestDecomposeCallGraphTypes.cpp. 150 static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder, 151 TypeRange resultTypes, 152 ValueRange inputs, 153 Location loc) { 154 if (inputs.size() != 1) 155 return {}; 156 Value input = inputs.front(); 157 158 TupleType inputType = dyn_cast<TupleType>(input.getType()); 159 if (!inputType) 160 return {}; 161 162 SmallVector<Value> values; 163 for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) { 164 Value element = builder.create<::test::GetTupleElementOp>( 165 loc, elementType, input, builder.getI32IntegerAttr(idx)); 166 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { 167 // Recurse if the current element is also a tuple. 168 SmallVector<Type> flatRecursiveTypes; 169 nestedTupleType.getFlattenedTypes(flatRecursiveTypes); 170 std::optional<SmallVector<Value>> resursiveValues = 171 buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc); 172 if (!resursiveValues.has_value()) 173 return {}; 174 values.append(resursiveValues.value()); 175 } else { 176 values.push_back(element); 177 } 178 } 179 return values; 180 } 181 182 /// Creates a `test.make_tuple` op out of the given inputs building a tuple of 183 /// type `resultType`. If that type is nested, each nested tuple is built 184 /// recursively with another `test.make_tuple` op. 185 /// 186 /// This function has been copied (with small adaptions) from 187 /// TestDecomposeCallGraphTypes.cpp. 188 static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType, 189 ValueRange inputs, Location loc) { 190 // Build one value for each element at this nesting level. 191 SmallVector<Value> elements; 192 elements.reserve(resultType.getTypes().size()); 193 ValueRange::iterator inputIt = inputs.begin(); 194 for (Type elementType : resultType.getTypes()) { 195 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { 196 // Determine how many input values are needed for the nested elements of 197 // the nested TupleType and advance inputIt by that number. 198 // TODO: We only need the *number* of nested types, not the types itself. 199 // Maybe it's worth adding a more efficient overload? 200 SmallVector<Type> nestedFlattenedTypes; 201 nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); 202 size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); 203 ValueRange nestedFlattenedelements(inputIt, 204 inputIt + numNestedFlattenedTypes); 205 inputIt += numNestedFlattenedTypes; 206 207 // Recurse on the values for the nested TupleType. 208 Value res = buildMakeTupleOp(builder, nestedTupleType, 209 nestedFlattenedelements, loc); 210 if (!res) 211 return Value(); 212 213 // The tuple constructed by the conversion is the element value. 214 elements.push_back(res); 215 } else { 216 // Base case: take one input as is. 217 elements.push_back(*inputIt++); 218 } 219 } 220 221 // Assemble the tuple from the elements. 222 return builder.create<::test::MakeTupleOp>(loc, resultType, elements); 223 } 224 225 void TestOneToNTypeConversionPass::runOnOperation() { 226 ModuleOp module = getOperation(); 227 auto *context = &getContext(); 228 229 // Assemble type converter. 230 TypeConverter typeConverter; 231 232 typeConverter.addConversion([](Type type) { return type; }); 233 typeConverter.addConversion( 234 [](TupleType tupleType, SmallVectorImpl<Type> &types) { 235 tupleType.getFlattenedTypes(types); 236 return success(); 237 }); 238 239 typeConverter.addArgumentMaterialization(buildMakeTupleOp); 240 typeConverter.addSourceMaterialization(buildMakeTupleOp); 241 typeConverter.addTargetMaterialization(buildGetTupleElementOps); 242 // Test the other target materialization variant that takes the original type 243 // as additional argument. This materialization function always fails. 244 typeConverter.addTargetMaterialization( 245 [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, 246 Location loc, Type originalType) -> SmallVector<Value> { return {}; }); 247 248 // Assemble patterns. 249 RewritePatternSet patterns(context); 250 if (convertTupleOps) 251 populateDecomposeTuplesTestPatterns(typeConverter, patterns); 252 if (convertFuncOps) 253 populateFuncTypeConversionPatterns(typeConverter, patterns); 254 if (convertSCFOps) 255 scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); 256 257 // Run conversion. 258 if (failed(applyPartialOneToNConversion(module, typeConverter, 259 std::move(patterns)))) 260 return signalPassFailure(); 261 } 262