xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp (revision f5aee1f18bdbc5694330a5e86eb46cf60e653d0c)
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Arith/Transforms/Passes.h"
11 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
14 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
15 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/MathExtras.h"
19 #include <cassert>
20 
21 namespace mlir::memref {
22 #define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
23 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
24 } // namespace mlir::memref
25 
26 using namespace mlir;
27 
28 namespace {
29 
30 //===----------------------------------------------------------------------===//
31 // ConvertMemRefAlloc
32 //===----------------------------------------------------------------------===//
33 
34 struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
35   using OpConversionPattern::OpConversionPattern;
36 
37   LogicalResult
38   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
39                   ConversionPatternRewriter &rewriter) const override {
40     Type newTy = getTypeConverter()->convertType(op.getType());
41     if (!newTy)
42       return rewriter.notifyMatchFailure(
43           op->getLoc(),
44           llvm::formatv("failed to convert memref type: {0}", op.getType()));
45 
46     rewriter.replaceOpWithNewOp<memref::AllocOp>(
47         op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
48         adaptor.getAlignmentAttr());
49     return success();
50   }
51 };
52 
53 //===----------------------------------------------------------------------===//
54 // ConvertMemRefLoad
55 //===----------------------------------------------------------------------===//
56 
57 struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
58   using OpConversionPattern::OpConversionPattern;
59 
60   LogicalResult
61   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
62                   ConversionPatternRewriter &rewriter) const override {
63     Type newResTy = getTypeConverter()->convertType(op.getType());
64     if (!newResTy)
65       return rewriter.notifyMatchFailure(
66           op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
67                                       op.getMemRefType()));
68 
69     rewriter.replaceOpWithNewOp<memref::LoadOp>(
70         op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
71         op.getNontemporal());
72     return success();
73   }
74 };
75 
76 //===----------------------------------------------------------------------===//
77 // ConvertMemRefStore
78 //===----------------------------------------------------------------------===//
79 
80 struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
81   using OpConversionPattern::OpConversionPattern;
82 
83   LogicalResult
84   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
85                   ConversionPatternRewriter &rewriter) const override {
86     Type newTy = getTypeConverter()->convertType(op.getMemRefType());
87     if (!newTy)
88       return rewriter.notifyMatchFailure(
89           op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
90                                       op.getMemRefType()));
91 
92     rewriter.replaceOpWithNewOp<memref::StoreOp>(
93         op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
94         op.getNontemporal());
95     return success();
96   }
97 };
98 
99 //===----------------------------------------------------------------------===//
100 // Pass Definition
101 //===----------------------------------------------------------------------===//
102 
103 struct EmulateWideIntPass final
104     : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
105   using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
106 
107   void runOnOperation() override {
108     if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
109       signalPassFailure();
110       return;
111     }
112 
113     Operation *op = getOperation();
114     MLIRContext *ctx = op->getContext();
115 
116     arith::WideIntEmulationConverter typeConverter(widestIntSupported);
117     memref::populateMemRefWideIntEmulationConversions(typeConverter);
118     ConversionTarget target(*ctx);
119     target.addDynamicallyLegalDialect<
120         arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
121         [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
122 
123     RewritePatternSet patterns(ctx);
124     // Add common pattenrs to support contants, functions, etc.
125     arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
126 
127     memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
128 
129     if (failed(applyPartialConversion(op, target, std::move(patterns))))
130       signalPassFailure();
131   }
132 };
133 
134 } // end anonymous namespace
135 
136 //===----------------------------------------------------------------------===//
137 // Public Interface Definition
138 //===----------------------------------------------------------------------===//
139 
140 void memref::populateMemRefWideIntEmulationPatterns(
141     const arith::WideIntEmulationConverter &typeConverter,
142     RewritePatternSet &patterns) {
143   // Populate `memref.*` conversion patterns.
144   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
145       typeConverter, patterns.getContext());
146 }
147 
148 void memref::populateMemRefWideIntEmulationConversions(
149     arith::WideIntEmulationConverter &typeConverter) {
150   typeConverter.addConversion(
151       [&typeConverter](MemRefType ty) -> std::optional<Type> {
152         auto intTy = dyn_cast<IntegerType>(ty.getElementType());
153         if (!intTy)
154           return ty;
155 
156         if (intTy.getIntOrFloatBitWidth() <=
157             typeConverter.getMaxTargetIntBitWidth())
158           return ty;
159 
160         Type newElemTy = typeConverter.convertType(intTy);
161         if (!newElemTy)
162           return nullptr;
163 
164         return ty.cloneWith(std::nullopt, newElemTy);
165       });
166 }
167