xref: /llvm-project/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp (revision e458434ebe87f890db0d4a03bbc3de30f3d052b9)
//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation -*- C++ -*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
13 #include "mlir/Dialect/Arith/Transforms/Passes.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 struct TestEmulateNarrowTypePass
27     : public PassWrapper<TestEmulateNarrowTypePass,
28                          OperationPass<func::FuncOp>> {
29   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
30 
31   TestEmulateNarrowTypePass() = default;
32   TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
33       : PassWrapper(pass) {}
34 
35   void getDependentDialects(DialectRegistry &registry) const override {
36     registry
37         .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
38                 vector::VectorDialect, affine::AffineDialect>();
39   }
40   StringRef getArgument() const final { return "test-emulate-narrow-int"; }
41   StringRef getDescription() const final {
42     return "Function pass to test Narrow Integer Emulation";
43   }
44 
45   void runOnOperation() override {
46     if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
47         loadStoreEmulateBitwidth < 8) {
48       signalPassFailure();
49       return;
50     }
51 
52     Operation *op = getOperation();
53     MLIRContext *ctx = op->getContext();
54 
55     arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth);
56 
57     // Convert scalar type.
58     typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
59       unsigned width = ty.getWidth();
60       if (width >= arithComputeBitwidth)
61         return ty;
62 
63       return IntegerType::get(ty.getContext(), arithComputeBitwidth);
64     });
65 
66     // Convert vector type.
67     typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
68       auto intTy = dyn_cast<IntegerType>(ty.getElementType());
69       if (!intTy)
70         return ty;
71 
72       unsigned width = intTy.getWidth();
73       if (width >= arithComputeBitwidth)
74         return ty;
75 
76       return VectorType::get(
77           to_vector(ty.getShape()),
78           IntegerType::get(ty.getContext(), arithComputeBitwidth));
79     });
80 
81     // With the type converter enabled, we are effectively unable to write
82     // negative tests. This is a workaround specifically for negative tests.
83     if (!disableMemrefTypeConversion)
84       memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
85 
86     ConversionTarget target(*ctx);
87     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
88       return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
89     });
90     auto opLegalCallback = [&typeConverter](Operation *op) {
91       return typeConverter.isLegal(op);
92     };
93     target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
94     target.addDynamicallyLegalDialect<
95         arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
96         affine::AffineDialect>(opLegalCallback);
97 
98     RewritePatternSet patterns(ctx);
99 
100     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
101     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
102     vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
103 
104     if (failed(applyPartialConversion(op, target, std::move(patterns))))
105       signalPassFailure();
106   }
107 
108   Option<unsigned> loadStoreEmulateBitwidth{
109       *this, "memref-load-bitwidth",
110       llvm::cl::desc("memref load/store emulation bit width"),
111       llvm::cl::init(8)};
112 
113   Option<unsigned> arithComputeBitwidth{
114       *this, "arith-compute-bitwidth",
115       llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
116 
117   Option<bool> disableMemrefTypeConversion{
118       *this, "skip-memref-type-conversion",
119       llvm::cl::desc("disable memref type conversion (to test failures)"),
120       llvm::cl::init(false)};
121 };
122 } // namespace
123 
124 namespace mlir::test {
125 void registerTestEmulateNarrowTypePass() {
126   PassRegistration<TestEmulateNarrowTypePass>();
127 }
128 } // namespace mlir::test
129