1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/Arith/Transforms/Passes.h" 11 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 14 #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 15 #include "mlir/Dialect/Vector/IR/VectorOps.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/Support/MathExtras.h" 19 #include <cassert> 20 21 namespace mlir::memref { 22 #define GEN_PASS_DEF_MEMREFEMULATEWIDEINT 23 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 24 } // namespace mlir::memref 25 26 using namespace mlir; 27 28 namespace { 29 30 //===----------------------------------------------------------------------===// 31 // ConvertMemRefAlloc 32 //===----------------------------------------------------------------------===// 33 34 struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { 35 using OpConversionPattern::OpConversionPattern; 36 37 LogicalResult 38 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, 39 ConversionPatternRewriter &rewriter) const override { 40 Type newTy = getTypeConverter()->convertType(op.getType()); 41 if (!newTy) 42 return rewriter.notifyMatchFailure( 43 op->getLoc(), 44 llvm::formatv("failed to convert memref type: {0}", op.getType())); 45 46 rewriter.replaceOpWithNewOp<memref::AllocOp>( 47 op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), 48 adaptor.getAlignmentAttr()); 49 return success(); 50 } 51 }; 52 53 //===----------------------------------------------------------------------===// 54 // ConvertMemRefLoad 55 //===----------------------------------------------------------------------===// 56 57 struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { 58 using OpConversionPattern::OpConversionPattern; 59 60 LogicalResult 61 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, 62 ConversionPatternRewriter &rewriter) const override { 63 Type newResTy = getTypeConverter()->convertType(op.getType()); 64 if (!newResTy) 65 return rewriter.notifyMatchFailure( 66 op->getLoc(), llvm::formatv("failed to convert memref type: {0}", 67 op.getMemRefType())); 68 69 rewriter.replaceOpWithNewOp<memref::LoadOp>( 70 op, newResTy, adaptor.getMemref(), adaptor.getIndices(), 71 op.getNontemporal()); 72 return success(); 73 } 74 }; 75 76 //===----------------------------------------------------------------------===// 77 // ConvertMemRefStore 78 //===----------------------------------------------------------------------===// 79 80 struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> { 81 using OpConversionPattern::OpConversionPattern; 82 83 LogicalResult 84 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, 85 ConversionPatternRewriter &rewriter) const override { 86 Type newTy = getTypeConverter()->convertType(op.getMemRefType()); 87 if (!newTy) 88 return rewriter.notifyMatchFailure( 89 op->getLoc(), llvm::formatv("failed to convert memref type: {0}", 90 op.getMemRefType())); 91 92 rewriter.replaceOpWithNewOp<memref::StoreOp>( 93 op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), 94 op.getNontemporal()); 95 return success(); 96 } 97 }; 98 99 //===----------------------------------------------------------------------===// 100 // Pass Definition 101 //===----------------------------------------------------------------------===// 102 103 struct EmulateWideIntPass final 104 : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> { 105 using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase; 106 107 void runOnOperation() override { 108 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { 109 signalPassFailure(); 110 return; 111 } 112 113 Operation *op = getOperation(); 114 MLIRContext *ctx = op->getContext(); 115 116 arith::WideIntEmulationConverter typeConverter(widestIntSupported); 117 memref::populateMemRefWideIntEmulationConversions(typeConverter); 118 ConversionTarget target(*ctx); 119 target.addDynamicallyLegalDialect< 120 arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>( 121 [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); 122 123 RewritePatternSet patterns(ctx); 124 // Add common pattenrs to support contants, functions, etc. 125 arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); 126 127 memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); 128 129 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 130 signalPassFailure(); 131 } 132 }; 133 134 } // end anonymous namespace 135 136 //===----------------------------------------------------------------------===// 137 // Public Interface Definition 138 //===----------------------------------------------------------------------===// 139 140 void memref::populateMemRefWideIntEmulationPatterns( 141 const arith::WideIntEmulationConverter &typeConverter, 142 RewritePatternSet &patterns) { 143 // Populate `memref.*` conversion patterns. 144 patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>( 145 typeConverter, patterns.getContext()); 146 } 147 148 void memref::populateMemRefWideIntEmulationConversions( 149 arith::WideIntEmulationConverter &typeConverter) { 150 typeConverter.addConversion( 151 [&typeConverter](MemRefType ty) -> std::optional<Type> { 152 auto intTy = dyn_cast<IntegerType>(ty.getElementType()); 153 if (!intTy) 154 return ty; 155 156 if (intTy.getIntOrFloatBitWidth() <= 157 typeConverter.getMaxTargetIntBitWidth()) 158 return ty; 159 160 Type newElemTy = typeConverter.convertType(intTy); 161 if (!newElemTy) 162 return nullptr; 163 164 return ty.cloneWith(std::nullopt, newElemTy); 165 }); 166 } 167