1 //===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation ------*- c++ 2 //-*-===// 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" 13 #include "mlir/Dialect/Arith/Transforms/Passes.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 26 struct TestEmulateNarrowTypePass 27 : public PassWrapper<TestEmulateNarrowTypePass, 28 OperationPass<func::FuncOp>> { 29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass) 30 31 TestEmulateNarrowTypePass() = default; 32 TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass) 33 : PassWrapper(pass) {} 34 35 void getDependentDialects(DialectRegistry ®istry) const override { 36 registry 37 .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect, 38 vector::VectorDialect, affine::AffineDialect>(); 39 } 40 StringRef getArgument() const final { return "test-emulate-narrow-int"; } 41 StringRef getDescription() const final { 42 return "Function pass to test Narrow Integer Emulation"; 43 } 44 45 void runOnOperation() override { 46 if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) || 47 loadStoreEmulateBitwidth < 8) { 48 signalPassFailure(); 49 return; 50 } 51 52 Operation *op = getOperation(); 53 MLIRContext *ctx = op->getContext(); 54 55 arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth); 56 57 // Convert scalar type. 58 typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> { 59 unsigned width = ty.getWidth(); 60 if (width >= arithComputeBitwidth) 61 return ty; 62 63 return IntegerType::get(ty.getContext(), arithComputeBitwidth); 64 }); 65 66 // Convert vector type. 67 typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> { 68 auto intTy = dyn_cast<IntegerType>(ty.getElementType()); 69 if (!intTy) 70 return ty; 71 72 unsigned width = intTy.getWidth(); 73 if (width >= arithComputeBitwidth) 74 return ty; 75 76 return VectorType::get( 77 to_vector(ty.getShape()), 78 IntegerType::get(ty.getContext(), arithComputeBitwidth)); 79 }); 80 81 // With the type converter enabled, we are effectively unable to write 82 // negative tests. This is a workaround specifically for negative tests. 83 if (!disableMemrefTypeConversion) 84 memref::populateMemRefNarrowTypeEmulationConversions(typeConverter); 85 86 ConversionTarget target(*ctx); 87 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) { 88 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType()); 89 }); 90 auto opLegalCallback = [&typeConverter](Operation *op) { 91 return typeConverter.isLegal(op); 92 }; 93 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback); 94 target.addDynamicallyLegalDialect< 95 arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect, 96 affine::AffineDialect>(opLegalCallback); 97 98 RewritePatternSet patterns(ctx); 99 100 arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns); 101 memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns); 102 vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns); 103 104 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 105 signalPassFailure(); 106 } 107 108 Option<unsigned> loadStoreEmulateBitwidth{ 109 *this, "memref-load-bitwidth", 110 llvm::cl::desc("memref load/store emulation bit width"), 111 llvm::cl::init(8)}; 112 113 Option<unsigned> arithComputeBitwidth{ 114 *this, "arith-compute-bitwidth", 115 llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)}; 116 117 Option<bool> disableMemrefTypeConversion{ 118 *this, "skip-memref-type-conversion", 119 llvm::cl::desc("disable memref type conversion (to test failures)"), 120 llvm::cl::init(false)}; 121 }; 122 } // namespace 123 124 namespace mlir::test { 125 void registerTestEmulateNarrowTypePass() { 126 PassRegistration<TestEmulateNarrowTypePass>(); 127 } 128 } // namespace mlir::test 129