xref: /llvm-project/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
14 #include "mlir/Dialect/Index/IR/IndexDialect.h"
15 #include "mlir/Dialect/Index/IR/IndexOps.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Pass/Pass.h"
18 
19 using namespace mlir;
20 using namespace index;
21 
22 namespace {
23 
24 //===----------------------------------------------------------------------===//
25 // ConvertIndexCeilDivS
26 //===----------------------------------------------------------------------===//
27 
28 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
29 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
30 struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
31   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
32 
33   LogicalResult
34   matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
35                   ConversionPatternRewriter &rewriter) const override {
36     Location loc = op.getLoc();
37     Value n = adaptor.getLhs();
38     Value m = adaptor.getRhs();
39     Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
40     Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
41     Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
42 
43     // Compute `x`.
44     Value mPos =
45         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
46     Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
47 
48     // Compute the positive result.
49     Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
50     Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
51     Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
52 
53     // Compute the negative result.
54     Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
55     Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
56     Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
57 
58     // Pick the positive result if `n` and `m` have the same sign and `n` is
59     // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
60     Value nPos =
61         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
62     Value sameSign =
63         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
64     Value nNonZero =
65         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
66     Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
67     rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
68     return success();
69   }
70 };
71 
72 //===----------------------------------------------------------------------===//
73 // ConvertIndexCeilDivU
74 //===----------------------------------------------------------------------===//
75 
76 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
77 struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
78   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
79 
80   LogicalResult
81   matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
82                   ConversionPatternRewriter &rewriter) const override {
83     Location loc = op.getLoc();
84     Value n = adaptor.getLhs();
85     Value m = adaptor.getRhs();
86     Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
87     Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
88 
89     // Compute the non-zero result.
90     Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
91     Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
92     Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
93 
94     // Pick the result.
95     Value cmp =
96         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
97     rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
98     return success();
99   }
100 };
101 
102 //===----------------------------------------------------------------------===//
103 // ConvertIndexFloorDivS
104 //===----------------------------------------------------------------------===//
105 
106 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
107 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
108 struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
109   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
110 
111   LogicalResult
112   matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
113                   ConversionPatternRewriter &rewriter) const override {
114     Location loc = op.getLoc();
115     Value n = adaptor.getLhs();
116     Value m = adaptor.getRhs();
117     Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
118     Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
119     Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
120 
121     // Compute `x`.
122     Value mNeg =
123         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
124     Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
125 
126     // Compute the negative result.
127     Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
128     Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
129     Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
130 
131     // Compute the positive result.
132     Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
133 
134     // Pick the negative result if `n` and `m` have different signs and `n` is
135     // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
136     Value nNeg =
137         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
138     Value diffSign =
139         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
140     Value nNonZero =
141         rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
142     Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
143     rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
144     return success();
145   }
146 };
147 
148 //===----------------------------------------------------------------------===//
// ConvertIndexCast
150 //===----------------------------------------------------------------------===//
151 
152 /// Convert a cast op. If the materialized index type is the same as the other
153 /// type, fold away the op. Otherwise, truncate or extend the op as appropriate.
154 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
155 /// zero extend when the result bitwidth is larger.
156 template <typename CastOp, typename ExtOp>
157 struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> {
158   using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern;
159 
160   LogicalResult
161   matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
162                   ConversionPatternRewriter &rewriter) const override {
163     Type in = adaptor.getInput().getType();
164     Type out = this->getTypeConverter()->convertType(op.getType());
165     if (in == out)
166       rewriter.replaceOp(op, adaptor.getInput());
167     else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth())
168       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput());
169     else
170       rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput());
171     return success();
172   }
173 };
174 
175 using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
176 using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;
177 
178 //===----------------------------------------------------------------------===//
179 // ConvertIndexCmp
180 //===----------------------------------------------------------------------===//
181 
182 /// Assert that the LLVM comparison enum lines up with index's enum.
183 static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs,
184                                       IndexCmpPredicate rhs) {
185   return static_cast<int>(lhs) == static_cast<int>(rhs);
186 }
187 
188 static_assert(
189     LLVM::getMaxEnumValForICmpPredicate() ==
190             getMaxEnumValForIndexCmpPredicate() &&
191         checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
192         checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
193         checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
194         checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
195         checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
196         checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
197         checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
198         checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
199         checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
200         checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
201     "LLVM ICmpPredicate mismatches IndexCmpPredicate");
202 
203 struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> {
204   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
205 
206   LogicalResult
207   matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
208                   ConversionPatternRewriter &rewriter) const override {
209     // The LLVM enum has the same values as the index predicate enums.
210     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
211         op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())),
212         adaptor.getLhs(), adaptor.getRhs());
213     return success();
214   }
215 };
216 
217 //===----------------------------------------------------------------------===//
218 // ConvertIndexSizeOf
219 //===----------------------------------------------------------------------===//
220 
221 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
222 struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> {
223   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
224 
225   LogicalResult
226   matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
227                   ConversionPatternRewriter &rewriter) const override {
228     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
229         op, getTypeConverter()->getIndexType(),
230         getTypeConverter()->getIndexTypeBitwidth());
231     return success();
232   }
233 };
234 
235 //===----------------------------------------------------------------------===//
236 // ConvertIndexConstant
237 //===----------------------------------------------------------------------===//
238 
239 /// Convert an index constant. Truncate the value as appropriate.
240 struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> {
241   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
242 
243   LogicalResult
244   matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
245                   ConversionPatternRewriter &rewriter) const override {
246     Type type = getTypeConverter()->getIndexType();
247     APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth());
248     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
249         op, type, IntegerAttr::get(type, value));
250     return success();
251   }
252 };
253 
254 //===----------------------------------------------------------------------===//
255 // Trivial Conversions
256 //===----------------------------------------------------------------------===//
257 
258 using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern<AddOp, LLVM::AddOp>;
259 using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern<SubOp, LLVM::SubOp>;
260 using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern<MulOp, LLVM::MulOp>;
261 using ConvertIndexDivS =
262     mlir::OneToOneConvertToLLVMPattern<DivSOp, LLVM::SDivOp>;
263 using ConvertIndexDivU =
264     mlir::OneToOneConvertToLLVMPattern<DivUOp, LLVM::UDivOp>;
265 using ConvertIndexRemS =
266     mlir::OneToOneConvertToLLVMPattern<RemSOp, LLVM::SRemOp>;
267 using ConvertIndexRemU =
268     mlir::OneToOneConvertToLLVMPattern<RemUOp, LLVM::URemOp>;
269 using ConvertIndexMaxS =
270     mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>;
271 using ConvertIndexMaxU =
272     mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>;
273 using ConvertIndexMinS =
274     mlir::OneToOneConvertToLLVMPattern<MinSOp, LLVM::SMinOp>;
275 using ConvertIndexMinU =
276     mlir::OneToOneConvertToLLVMPattern<MinUOp, LLVM::UMinOp>;
277 using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern<ShlOp, LLVM::ShlOp>;
278 using ConvertIndexShrS =
279     mlir::OneToOneConvertToLLVMPattern<ShrSOp, LLVM::AShrOp>;
280 using ConvertIndexShrU =
281     mlir::OneToOneConvertToLLVMPattern<ShrUOp, LLVM::LShrOp>;
282 using ConvertIndexAnd = mlir::OneToOneConvertToLLVMPattern<AndOp, LLVM::AndOp>;
283 using ConvertIndexOr = mlir::OneToOneConvertToLLVMPattern<OrOp, LLVM::OrOp>;
284 using ConvertIndexXor = mlir::OneToOneConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
285 using ConvertIndexBoolConstant =
286     mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>;
287 
288 } // namespace
289 
290 //===----------------------------------------------------------------------===//
291 // Pattern Population
292 //===----------------------------------------------------------------------===//
293 
294 void index::populateIndexToLLVMConversionPatterns(
295     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
296   patterns.insert<
297       // clang-format off
298       ConvertIndexAdd,
299       ConvertIndexSub,
300       ConvertIndexMul,
301       ConvertIndexDivS,
302       ConvertIndexDivU,
303       ConvertIndexRemS,
304       ConvertIndexRemU,
305       ConvertIndexMaxS,
306       ConvertIndexMaxU,
307       ConvertIndexMinS,
308       ConvertIndexMinU,
309       ConvertIndexShl,
310       ConvertIndexShrS,
311       ConvertIndexShrU,
312       ConvertIndexAnd,
313       ConvertIndexOr,
314       ConvertIndexXor,
315       ConvertIndexCeilDivS,
316       ConvertIndexCeilDivU,
317       ConvertIndexFloorDivS,
318       ConvertIndexCastS,
319       ConvertIndexCastU,
320       ConvertIndexCmp,
321       ConvertIndexSizeOf,
322       ConvertIndexConstant,
323       ConvertIndexBoolConstant
324       // clang-format on
325       >(typeConverter);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // ODS-Generated Definitions
330 //===----------------------------------------------------------------------===//
331 
332 namespace mlir {
333 #define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
334 #include "mlir/Conversion/Passes.h.inc"
335 } // namespace mlir
336 
337 //===----------------------------------------------------------------------===//
338 // Pass Definition
339 //===----------------------------------------------------------------------===//
340 
341 namespace {
/// Pass that converts Index dialect operations to the LLVM dialect. It derives
/// from the ODS-generated `impl::ConvertIndexToLLVMPassBase`.
struct ConvertIndexToLLVMPass
    : public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> {
  // Inherit the constructors of `Base` (presumably an alias for the generated
  // base class above -- confirm against the generated header).
  using Base::Base;

  /// Entry point of the pass; defined out-of-line later in this file.
  void runOnOperation() override;
};
348 } // namespace
349 
350 void ConvertIndexToLLVMPass::runOnOperation() {
351   // Configure dialect conversion.
352   ConversionTarget target(getContext());
353   target.addIllegalDialect<IndexDialect>();
354   target.addLegalDialect<LLVM::LLVMDialect>();
355 
356   // Set LLVM lowering options.
357   LowerToLLVMOptions options(&getContext());
358   if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
359     options.overrideIndexBitwidth(indexBitwidth);
360   LLVMTypeConverter typeConverter(&getContext(), options);
361 
362   // Populate patterns and run the conversion.
363   RewritePatternSet patterns(&getContext());
364   populateIndexToLLVMConversionPatterns(typeConverter, patterns);
365 
366   if (failed(
367           applyPartialConversion(getOperation(), target, std::move(patterns))))
368     return signalPassFailure();
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // ConvertToLLVMPatternInterface implementation
373 //===----------------------------------------------------------------------===//
374 
375 namespace {
/// Implement the interface to convert Index to LLVM.
struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  // Reuse the constructors of the base interface.
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  /// Load the LLVM dialect into `context`; the conversion patterns in this
  /// file create LLVM dialect operations.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  ///
  /// This implementation only adds the Index-to-LLVM patterns to `patterns`
  /// using `typeConverter`; it does not modify `target`.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateIndexToLLVMConversionPatterns(typeConverter, patterns);
  }
};
391 } // namespace
392 
393 void mlir::index::registerConvertIndexToLLVMInterface(
394     DialectRegistry &registry) {
395   registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) {
396     dialect->addInterfaces<IndexToLLVMDialectInterface>();
397   });
398 }
399