xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp (revision f5aee1f18bdbc5694330a5e86eb46cf60e653d0c)
1fae258e6SJakub Kuderski //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2fae258e6SJakub Kuderski //
3fae258e6SJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fae258e6SJakub Kuderski // See https://llvm.org/LICENSE.txt for license information.
5fae258e6SJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fae258e6SJakub Kuderski //
7fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
8fae258e6SJakub Kuderski 
9fae258e6SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
10fae258e6SJakub Kuderski #include "mlir/Dialect/Arith/Transforms/Passes.h"
11fae258e6SJakub Kuderski #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
12fae258e6SJakub Kuderski #include "mlir/Dialect/MemRef/IR/MemRef.h"
13fae258e6SJakub Kuderski #include "mlir/Dialect/MemRef/Transforms/Passes.h"
14faafd26cSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
15fae258e6SJakub Kuderski #include "mlir/Dialect/Vector/IR/VectorOps.h"
16fae258e6SJakub Kuderski #include "mlir/Transforms/DialectConversion.h"
17fae258e6SJakub Kuderski #include "llvm/Support/FormatVariadic.h"
18fae258e6SJakub Kuderski #include "llvm/Support/MathExtras.h"
19fae258e6SJakub Kuderski #include <cassert>
20fae258e6SJakub Kuderski 
21fae258e6SJakub Kuderski namespace mlir::memref {
22fae258e6SJakub Kuderski #define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
23fae258e6SJakub Kuderski #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
24fae258e6SJakub Kuderski } // namespace mlir::memref
25fae258e6SJakub Kuderski 
26fae258e6SJakub Kuderski using namespace mlir;
27fae258e6SJakub Kuderski 
28fae258e6SJakub Kuderski namespace {
29fae258e6SJakub Kuderski 
30fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
31fae258e6SJakub Kuderski // ConvertMemRefAlloc
32fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
33fae258e6SJakub Kuderski 
34fae258e6SJakub Kuderski struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
35fae258e6SJakub Kuderski   using OpConversionPattern::OpConversionPattern;
36fae258e6SJakub Kuderski 
37fae258e6SJakub Kuderski   LogicalResult
38fae258e6SJakub Kuderski   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
39fae258e6SJakub Kuderski                   ConversionPatternRewriter &rewriter) const override {
40fae258e6SJakub Kuderski     Type newTy = getTypeConverter()->convertType(op.getType());
41fae258e6SJakub Kuderski     if (!newTy)
42fae258e6SJakub Kuderski       return rewriter.notifyMatchFailure(
43fae258e6SJakub Kuderski           op->getLoc(),
44fae258e6SJakub Kuderski           llvm::formatv("failed to convert memref type: {0}", op.getType()));
45fae258e6SJakub Kuderski 
46fae258e6SJakub Kuderski     rewriter.replaceOpWithNewOp<memref::AllocOp>(
47fae258e6SJakub Kuderski         op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
48fae258e6SJakub Kuderski         adaptor.getAlignmentAttr());
49fae258e6SJakub Kuderski     return success();
50fae258e6SJakub Kuderski   }
51fae258e6SJakub Kuderski };
52fae258e6SJakub Kuderski 
53fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
54fae258e6SJakub Kuderski // ConvertMemRefLoad
55fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
56fae258e6SJakub Kuderski 
57fae258e6SJakub Kuderski struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
58fae258e6SJakub Kuderski   using OpConversionPattern::OpConversionPattern;
59fae258e6SJakub Kuderski 
60fae258e6SJakub Kuderski   LogicalResult
61fae258e6SJakub Kuderski   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
62fae258e6SJakub Kuderski                   ConversionPatternRewriter &rewriter) const override {
63fae258e6SJakub Kuderski     Type newResTy = getTypeConverter()->convertType(op.getType());
64fae258e6SJakub Kuderski     if (!newResTy)
65fae258e6SJakub Kuderski       return rewriter.notifyMatchFailure(
66fae258e6SJakub Kuderski           op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
67fae258e6SJakub Kuderski                                       op.getMemRefType()));
68fae258e6SJakub Kuderski 
69fae258e6SJakub Kuderski     rewriter.replaceOpWithNewOp<memref::LoadOp>(
701cb91b42SGuray Ozen         op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
711cb91b42SGuray Ozen         op.getNontemporal());
72fae258e6SJakub Kuderski     return success();
73fae258e6SJakub Kuderski   }
74fae258e6SJakub Kuderski };
75fae258e6SJakub Kuderski 
76fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
77fae258e6SJakub Kuderski // ConvertMemRefStore
78fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
79fae258e6SJakub Kuderski 
80fae258e6SJakub Kuderski struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
81fae258e6SJakub Kuderski   using OpConversionPattern::OpConversionPattern;
82fae258e6SJakub Kuderski 
83fae258e6SJakub Kuderski   LogicalResult
84fae258e6SJakub Kuderski   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
85fae258e6SJakub Kuderski                   ConversionPatternRewriter &rewriter) const override {
86fae258e6SJakub Kuderski     Type newTy = getTypeConverter()->convertType(op.getMemRefType());
87fae258e6SJakub Kuderski     if (!newTy)
88fae258e6SJakub Kuderski       return rewriter.notifyMatchFailure(
89fae258e6SJakub Kuderski           op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
90fae258e6SJakub Kuderski                                       op.getMemRefType()));
91fae258e6SJakub Kuderski 
92fae258e6SJakub Kuderski     rewriter.replaceOpWithNewOp<memref::StoreOp>(
931cb91b42SGuray Ozen         op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
941cb91b42SGuray Ozen         op.getNontemporal());
95fae258e6SJakub Kuderski     return success();
96fae258e6SJakub Kuderski   }
97fae258e6SJakub Kuderski };
98fae258e6SJakub Kuderski 
99fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
100fae258e6SJakub Kuderski // Pass Definition
101fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
102fae258e6SJakub Kuderski 
103fae258e6SJakub Kuderski struct EmulateWideIntPass final
104fae258e6SJakub Kuderski     : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
105fae258e6SJakub Kuderski   using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
106fae258e6SJakub Kuderski 
107fae258e6SJakub Kuderski   void runOnOperation() override {
108fae258e6SJakub Kuderski     if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
109fae258e6SJakub Kuderski       signalPassFailure();
110fae258e6SJakub Kuderski       return;
111fae258e6SJakub Kuderski     }
112fae258e6SJakub Kuderski 
113fae258e6SJakub Kuderski     Operation *op = getOperation();
114fae258e6SJakub Kuderski     MLIRContext *ctx = op->getContext();
115fae258e6SJakub Kuderski 
116fae258e6SJakub Kuderski     arith::WideIntEmulationConverter typeConverter(widestIntSupported);
117fae258e6SJakub Kuderski     memref::populateMemRefWideIntEmulationConversions(typeConverter);
118fae258e6SJakub Kuderski     ConversionTarget target(*ctx);
119fae258e6SJakub Kuderski     target.addDynamicallyLegalDialect<
120fae258e6SJakub Kuderski         arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
121fae258e6SJakub Kuderski         [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
122fae258e6SJakub Kuderski 
123fae258e6SJakub Kuderski     RewritePatternSet patterns(ctx);
124fae258e6SJakub Kuderski     // Add common pattenrs to support contants, functions, etc.
125fae258e6SJakub Kuderski     arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
126fae258e6SJakub Kuderski 
127fae258e6SJakub Kuderski     memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
128fae258e6SJakub Kuderski 
129fae258e6SJakub Kuderski     if (failed(applyPartialConversion(op, target, std::move(patterns))))
130fae258e6SJakub Kuderski       signalPassFailure();
131fae258e6SJakub Kuderski   }
132fae258e6SJakub Kuderski };
133fae258e6SJakub Kuderski 
134fae258e6SJakub Kuderski } // end anonymous namespace
135fae258e6SJakub Kuderski 
136fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
137fae258e6SJakub Kuderski // Public Interface Definition
138fae258e6SJakub Kuderski //===----------------------------------------------------------------------===//
139fae258e6SJakub Kuderski 
140fae258e6SJakub Kuderski void memref::populateMemRefWideIntEmulationPatterns(
141206fad0eSMatthias Springer     const arith::WideIntEmulationConverter &typeConverter,
142fae258e6SJakub Kuderski     RewritePatternSet &patterns) {
143fae258e6SJakub Kuderski   // Populate `memref.*` conversion patterns.
144fae258e6SJakub Kuderski   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
145fae258e6SJakub Kuderski       typeConverter, patterns.getContext());
146fae258e6SJakub Kuderski }
147fae258e6SJakub Kuderski 
148fae258e6SJakub Kuderski void memref::populateMemRefWideIntEmulationConversions(
149fae258e6SJakub Kuderski     arith::WideIntEmulationConverter &typeConverter) {
150fae258e6SJakub Kuderski   typeConverter.addConversion(
1510de16fafSRamkumar Ramachandra       [&typeConverter](MemRefType ty) -> std::optional<Type> {
1525550c821STres Popp         auto intTy = dyn_cast<IntegerType>(ty.getElementType());
153fae258e6SJakub Kuderski         if (!intTy)
154fae258e6SJakub Kuderski           return ty;
155fae258e6SJakub Kuderski 
156fae258e6SJakub Kuderski         if (intTy.getIntOrFloatBitWidth() <=
157fae258e6SJakub Kuderski             typeConverter.getMaxTargetIntBitWidth())
158fae258e6SJakub Kuderski           return ty;
159fae258e6SJakub Kuderski 
160fae258e6SJakub Kuderski         Type newElemTy = typeConverter.convertType(intTy);
161fae258e6SJakub Kuderski         if (!newElemTy)
162*f5aee1f1SLongsheng Mou           return nullptr;
163fae258e6SJakub Kuderski 
1641a36588eSKazu Hirata         return ty.cloneWith(std::nullopt, newElemTy);
165fae258e6SJakub Kuderski       });
166fae258e6SJakub Kuderski }
167