xref: /llvm-project/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp (revision f18c3e4e7335df282c468b6dff3d29be1822a96d)
//===- TestEmulateWideInt.cpp - Test Wide Int Emulation ---------*- c++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass for integration testing of wide integer
10 // emulation patterns. Applies conversion patterns only to functions whose
11 // names start with a specified prefix.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Transforms/Passes.h"
17 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 struct TestEmulateWideIntPass
28     : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
29   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)
30 
31   TestEmulateWideIntPass() = default;
32   TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
33       : PassWrapper(pass) {}
34 
35   void getDependentDialects(DialectRegistry &registry) const override {
36     registry.insert<arith::ArithDialect, func::FuncDialect, LLVM::LLVMDialect,
37                     vector::VectorDialect>();
38   }
39   StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }
40   StringRef getDescription() const final {
41     return "Function pass to test Wide Integer Emulation";
42   }
43 
44   void runOnOperation() override {
45     if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
46       signalPassFailure();
47       return;
48     }
49 
50     func::FuncOp op = getOperation();
51     if (!op.getSymName().starts_with(testFunctionPrefix))
52       return;
53 
54     MLIRContext *ctx = op.getContext();
55     arith::WideIntEmulationConverter typeConverter(widestIntSupported);
56 
57     // Use `llvm.bitcast` as the bridge so that we can use preserve the
58     // function argument and return types of the processed function.
59     // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
60     // casts (and vice versa) and using it insted of `llvm.bitcast`.
61     auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
62                          Location loc) -> Value {
63       auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
64       return cast->getResult(0);
65     };
66     typeConverter.addSourceMaterialization(addBitcast);
67     typeConverter.addTargetMaterialization(addBitcast);
68 
69     ConversionTarget target(*ctx);
70     target
71         .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
72             [&typeConverter](Operation *op) {
73               return typeConverter.isLegal(op);
74             });
75 
76     RewritePatternSet patterns(ctx);
77     arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
78     if (failed(applyPartialConversion(op, target, std::move(patterns))))
79       signalPassFailure();
80   }
81 
82   Option<std::string> testFunctionPrefix{
83       *this, "function-prefix",
84       llvm::cl::desc("Prefix of functions to run the emulation pass on"),
85       llvm::cl::init("emulate_")};
86   Option<unsigned> widestIntSupported{
87       *this, "widest-int-supported",
88       llvm::cl::desc("Maximum integer bit width supported by the target"),
89       llvm::cl::init(32)};
90 };
91 } // namespace
92 
93 namespace mlir::test {
94 void registerTestArithEmulateWideIntPass() {
95   PassRegistration<TestEmulateWideIntPass>();
96 }
97 } // namespace mlir::test
98