1fae258e6SJakub Kuderski //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// 2fae258e6SJakub Kuderski // 3fae258e6SJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4fae258e6SJakub Kuderski // See https://llvm.org/LICENSE.txt for license information. 5fae258e6SJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6fae258e6SJakub Kuderski // 7fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 8fae258e6SJakub Kuderski 9fae258e6SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 10fae258e6SJakub Kuderski #include "mlir/Dialect/Arith/Transforms/Passes.h" 11fae258e6SJakub Kuderski #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" 12fae258e6SJakub Kuderski #include "mlir/Dialect/MemRef/IR/MemRef.h" 13fae258e6SJakub Kuderski #include "mlir/Dialect/MemRef/Transforms/Passes.h" 14faafd26cSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 15fae258e6SJakub Kuderski #include "mlir/Dialect/Vector/IR/VectorOps.h" 16fae258e6SJakub Kuderski #include "mlir/Transforms/DialectConversion.h" 17fae258e6SJakub Kuderski #include "llvm/Support/FormatVariadic.h" 18fae258e6SJakub Kuderski #include "llvm/Support/MathExtras.h" 19fae258e6SJakub Kuderski #include <cassert> 20fae258e6SJakub Kuderski 21fae258e6SJakub Kuderski namespace mlir::memref { 22fae258e6SJakub Kuderski #define GEN_PASS_DEF_MEMREFEMULATEWIDEINT 23fae258e6SJakub Kuderski #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 24fae258e6SJakub Kuderski } // namespace mlir::memref 25fae258e6SJakub Kuderski 26fae258e6SJakub Kuderski using namespace mlir; 27fae258e6SJakub Kuderski 28fae258e6SJakub Kuderski namespace { 29fae258e6SJakub Kuderski 30fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 31fae258e6SJakub Kuderski // ConvertMemRefAlloc 32fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 33fae258e6SJakub Kuderski 34fae258e6SJakub Kuderski struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { 35fae258e6SJakub Kuderski using OpConversionPattern::OpConversionPattern; 36fae258e6SJakub Kuderski 37fae258e6SJakub Kuderski LogicalResult 38fae258e6SJakub Kuderski matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, 39fae258e6SJakub Kuderski ConversionPatternRewriter &rewriter) const override { 40fae258e6SJakub Kuderski Type newTy = getTypeConverter()->convertType(op.getType()); 41fae258e6SJakub Kuderski if (!newTy) 42fae258e6SJakub Kuderski return rewriter.notifyMatchFailure( 43fae258e6SJakub Kuderski op->getLoc(), 44fae258e6SJakub Kuderski llvm::formatv("failed to convert memref type: {0}", op.getType())); 45fae258e6SJakub Kuderski 46fae258e6SJakub Kuderski rewriter.replaceOpWithNewOp<memref::AllocOp>( 47fae258e6SJakub Kuderski op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), 48fae258e6SJakub Kuderski adaptor.getAlignmentAttr()); 49fae258e6SJakub Kuderski return success(); 50fae258e6SJakub Kuderski } 51fae258e6SJakub Kuderski }; 52fae258e6SJakub Kuderski 53fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 54fae258e6SJakub Kuderski // ConvertMemRefLoad 55fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 56fae258e6SJakub Kuderski 57fae258e6SJakub Kuderski struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { 58fae258e6SJakub Kuderski using OpConversionPattern::OpConversionPattern; 59fae258e6SJakub Kuderski 60fae258e6SJakub Kuderski LogicalResult 61fae258e6SJakub Kuderski matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, 62fae258e6SJakub Kuderski ConversionPatternRewriter &rewriter) const override { 63fae258e6SJakub Kuderski Type newResTy = getTypeConverter()->convertType(op.getType()); 64fae258e6SJakub Kuderski if (!newResTy) 65fae258e6SJakub Kuderski return rewriter.notifyMatchFailure( 66fae258e6SJakub Kuderski op->getLoc(), llvm::formatv("failed to convert memref type: {0}", 67fae258e6SJakub Kuderski op.getMemRefType())); 68fae258e6SJakub Kuderski 69fae258e6SJakub Kuderski rewriter.replaceOpWithNewOp<memref::LoadOp>( 701cb91b42SGuray Ozen op, newResTy, adaptor.getMemref(), adaptor.getIndices(), 711cb91b42SGuray Ozen op.getNontemporal()); 72fae258e6SJakub Kuderski return success(); 73fae258e6SJakub Kuderski } 74fae258e6SJakub Kuderski }; 75fae258e6SJakub Kuderski 76fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 77fae258e6SJakub Kuderski // ConvertMemRefStore 78fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 79fae258e6SJakub Kuderski 80fae258e6SJakub Kuderski struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> { 81fae258e6SJakub Kuderski using OpConversionPattern::OpConversionPattern; 82fae258e6SJakub Kuderski 83fae258e6SJakub Kuderski LogicalResult 84fae258e6SJakub Kuderski matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, 85fae258e6SJakub Kuderski ConversionPatternRewriter &rewriter) const override { 86fae258e6SJakub Kuderski Type newTy = getTypeConverter()->convertType(op.getMemRefType()); 87fae258e6SJakub Kuderski if (!newTy) 88fae258e6SJakub Kuderski return rewriter.notifyMatchFailure( 89fae258e6SJakub Kuderski op->getLoc(), llvm::formatv("failed to convert memref type: {0}", 90fae258e6SJakub Kuderski op.getMemRefType())); 91fae258e6SJakub Kuderski 92fae258e6SJakub Kuderski rewriter.replaceOpWithNewOp<memref::StoreOp>( 931cb91b42SGuray Ozen op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), 941cb91b42SGuray Ozen op.getNontemporal()); 95fae258e6SJakub Kuderski return success(); 96fae258e6SJakub Kuderski } 97fae258e6SJakub Kuderski }; 98fae258e6SJakub Kuderski 99fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 100fae258e6SJakub Kuderski // Pass Definition 101fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 102fae258e6SJakub Kuderski 103fae258e6SJakub Kuderski struct EmulateWideIntPass final 104fae258e6SJakub Kuderski : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> { 105fae258e6SJakub Kuderski using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase; 106fae258e6SJakub Kuderski 107fae258e6SJakub Kuderski void runOnOperation() override { 108fae258e6SJakub Kuderski if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { 109fae258e6SJakub Kuderski signalPassFailure(); 110fae258e6SJakub Kuderski return; 111fae258e6SJakub Kuderski } 112fae258e6SJakub Kuderski 113fae258e6SJakub Kuderski Operation *op = getOperation(); 114fae258e6SJakub Kuderski MLIRContext *ctx = op->getContext(); 115fae258e6SJakub Kuderski 116fae258e6SJakub Kuderski arith::WideIntEmulationConverter typeConverter(widestIntSupported); 117fae258e6SJakub Kuderski memref::populateMemRefWideIntEmulationConversions(typeConverter); 118fae258e6SJakub Kuderski ConversionTarget target(*ctx); 119fae258e6SJakub Kuderski target.addDynamicallyLegalDialect< 120fae258e6SJakub Kuderski arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>( 121fae258e6SJakub Kuderski [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); 122fae258e6SJakub Kuderski 123fae258e6SJakub Kuderski RewritePatternSet patterns(ctx); 124fae258e6SJakub Kuderski // Add common pattenrs to support contants, functions, etc. 125fae258e6SJakub Kuderski arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); 126fae258e6SJakub Kuderski 127fae258e6SJakub Kuderski memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); 128fae258e6SJakub Kuderski 129fae258e6SJakub Kuderski if (failed(applyPartialConversion(op, target, std::move(patterns)))) 130fae258e6SJakub Kuderski signalPassFailure(); 131fae258e6SJakub Kuderski } 132fae258e6SJakub Kuderski }; 133fae258e6SJakub Kuderski 134fae258e6SJakub Kuderski } // end anonymous namespace 135fae258e6SJakub Kuderski 136fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 137fae258e6SJakub Kuderski // Public Interface Definition 138fae258e6SJakub Kuderski //===----------------------------------------------------------------------===// 139fae258e6SJakub Kuderski 140fae258e6SJakub Kuderski void memref::populateMemRefWideIntEmulationPatterns( 141206fad0eSMatthias Springer const arith::WideIntEmulationConverter &typeConverter, 142fae258e6SJakub Kuderski RewritePatternSet &patterns) { 143fae258e6SJakub Kuderski // Populate `memref.*` conversion patterns. 144fae258e6SJakub Kuderski patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>( 145fae258e6SJakub Kuderski typeConverter, patterns.getContext()); 146fae258e6SJakub Kuderski } 147fae258e6SJakub Kuderski 148fae258e6SJakub Kuderski void memref::populateMemRefWideIntEmulationConversions( 149fae258e6SJakub Kuderski arith::WideIntEmulationConverter &typeConverter) { 150fae258e6SJakub Kuderski typeConverter.addConversion( 1510de16fafSRamkumar Ramachandra [&typeConverter](MemRefType ty) -> std::optional<Type> { 1525550c821STres Popp auto intTy = dyn_cast<IntegerType>(ty.getElementType()); 153fae258e6SJakub Kuderski if (!intTy) 154fae258e6SJakub Kuderski return ty; 155fae258e6SJakub Kuderski 156fae258e6SJakub Kuderski if (intTy.getIntOrFloatBitWidth() <= 157fae258e6SJakub Kuderski typeConverter.getMaxTargetIntBitWidth()) 158fae258e6SJakub Kuderski return ty; 159fae258e6SJakub Kuderski 160fae258e6SJakub Kuderski Type newElemTy = typeConverter.convertType(intTy); 161fae258e6SJakub Kuderski if (!newElemTy) 162*f5aee1f1SLongsheng Mou return nullptr; 163fae258e6SJakub Kuderski 1641a36588eSKazu Hirata return ty.cloneWith(std::nullopt, newElemTy); 165fae258e6SJakub Kuderski }); 166fae258e6SJakub Kuderski } 167