1 //===- TestWideIntEmulation.cpp - Test Wide Int Emulation ------*- c++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for integration testing of wide integer 10 // emulation patterns. Applies conversion patterns only to functions whose 11 // names start with a specified prefix. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Arith/Transforms/Passes.h" 17 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 24 using namespace mlir; 25 26 namespace { 27 struct TestEmulateWideIntPass 28 : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> { 29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass) 30 31 TestEmulateWideIntPass() = default; 32 TestEmulateWideIntPass(const TestEmulateWideIntPass &pass) 33 : PassWrapper(pass) {} 34 35 void getDependentDialects(DialectRegistry ®istry) const override { 36 registry.insert<arith::ArithDialect, func::FuncDialect, LLVM::LLVMDialect, 37 vector::VectorDialect>(); 38 } 39 StringRef getArgument() const final { return "test-arith-emulate-wide-int"; } 40 StringRef getDescription() const final { 41 return "Function pass to test Wide Integer Emulation"; 42 } 43 44 void runOnOperation() override { 45 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { 46 signalPassFailure(); 47 return; 48 } 49 50 func::FuncOp op = getOperation(); 51 if (!op.getSymName().starts_with(testFunctionPrefix)) 52 return; 53 54 MLIRContext *ctx = op.getContext(); 55 arith::WideIntEmulationConverter typeConverter(widestIntSupported); 56 57 // Use `llvm.bitcast` as the bridge so that we can use preserve the 58 // function argument and return types of the processed function. 59 // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector 60 // casts (and vice versa) and using it insted of `llvm.bitcast`. 61 auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs, 62 Location loc) -> Value { 63 auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs); 64 return cast->getResult(0); 65 }; 66 typeConverter.addSourceMaterialization(addBitcast); 67 typeConverter.addTargetMaterialization(addBitcast); 68 69 ConversionTarget target(*ctx); 70 target 71 .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>( 72 [&typeConverter](Operation *op) { 73 return typeConverter.isLegal(op); 74 }); 75 76 RewritePatternSet patterns(ctx); 77 arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); 78 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 79 signalPassFailure(); 80 } 81 82 Option<std::string> testFunctionPrefix{ 83 *this, "function-prefix", 84 llvm::cl::desc("Prefix of functions to run the emulation pass on"), 85 llvm::cl::init("emulate_")}; 86 Option<unsigned> widestIntSupported{ 87 *this, "widest-int-supported", 88 llvm::cl::desc("Maximum integer bit width supported by the target"), 89 llvm::cl::init(32)}; 90 }; 91 } // namespace 92 93 namespace mlir::test { 94 void registerTestArithEmulateWideIntPass() { 95 PassRegistration<TestEmulateWideIntPass>(); 96 } 97 } // namespace mlir::test 98