1 //===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 12 #include "mlir/Pass/Pass.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 15 using namespace mlir; 16 17 namespace { 18 19 /// Replace this op (which is expected to have 1 result) with the operands. 20 struct TestDirectReplacementOp : public ConversionPattern { 21 TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter) 22 : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {} 23 LogicalResult 24 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 25 ConversionPatternRewriter &rewriter) const final { 26 if (op->getNumResults() != 1) 27 return failure(); 28 rewriter.replaceOpWithMultiple(op, {operands}); 29 return success(); 30 } 31 }; 32 33 struct TestLLVMLegalizePatternsPass 34 : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> { 35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass) 36 37 StringRef getArgument() const final { return "test-llvm-legalize-patterns"; } 38 StringRef getDescription() const final { 39 return "Run LLVM dialect legalization patterns"; 40 } 41 42 void getDependentDialects(DialectRegistry ®istry) const override { 43 registry.insert<LLVM::LLVMDialect>(); 44 } 45 46 void runOnOperation() override { 47 MLIRContext *ctx = &getContext(); 48 LLVMTypeConverter converter(ctx); 49 mlir::RewritePatternSet patterns(ctx); 50 patterns.add<TestDirectReplacementOp>(ctx, converter); 51 52 // Define the conversion target used for the test. 53 ConversionTarget target(*ctx); 54 target.addLegalOp(OperationName("test.legal_op", ctx)); 55 56 // Handle a partial conversion. 57 DenseSet<Operation *> unlegalizedOps; 58 ConversionConfig config; 59 config.unlegalizedOps = &unlegalizedOps; 60 if (failed(applyPartialConversion(getOperation(), target, 61 std::move(patterns), config))) 62 getOperation()->emitError() << "applyPartialConversion failed"; 63 } 64 }; 65 } // namespace 66 67 //===----------------------------------------------------------------------===// 68 // PassRegistration 69 //===----------------------------------------------------------------------===// 70 71 namespace mlir { 72 namespace test { 73 void registerTestLLVMLegalizePatternsPass() { 74 PassRegistration<TestLLVMLegalizePatternsPass>(); 75 } 76 } // namespace test 77 } // namespace mlir 78