1 //===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/Index/IR/IndexAttrs.h" 14 #include "mlir/Dialect/Index/IR/IndexDialect.h" 15 #include "mlir/Dialect/Index/IR/IndexOps.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Pass/Pass.h" 18 19 using namespace mlir; 20 using namespace index; 21 22 namespace { 23 24 //===----------------------------------------------------------------------===// 25 // ConvertIndexCeilDivS 26 //===----------------------------------------------------------------------===// 27 28 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then 29 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. 30 struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> { 31 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 32 33 LogicalResult 34 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor, 35 ConversionPatternRewriter &rewriter) const override { 36 Location loc = op.getLoc(); 37 Value n = adaptor.getLhs(); 38 Value m = adaptor.getRhs(); 39 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); 40 Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); 41 Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1); 42 43 // Compute `x`. 44 Value mPos = 45 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero); 46 Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne); 47 48 // Compute the positive result. 49 Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x); 50 Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m); 51 Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne); 52 53 // Compute the negative result. 54 Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n); 55 Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m); 56 Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM); 57 58 // Pick the positive result if `n` and `m` have the same sign and `n` is 59 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. 60 Value nPos = 61 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero); 62 Value sameSign = 63 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos); 64 Value nNonZero = 65 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero); 66 Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero); 67 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes); 68 return success(); 69 } 70 }; 71 72 //===----------------------------------------------------------------------===// 73 // ConvertIndexCeilDivU 74 //===----------------------------------------------------------------------===// 75 76 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. 77 struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> { 78 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 79 80 LogicalResult 81 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor, 82 ConversionPatternRewriter &rewriter) const override { 83 Location loc = op.getLoc(); 84 Value n = adaptor.getLhs(); 85 Value m = adaptor.getRhs(); 86 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); 87 Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); 88 89 // Compute the non-zero result. 90 Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one); 91 Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m); 92 Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one); 93 94 // Pick the result. 95 Value cmp = 96 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero); 97 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne); 98 return success(); 99 } 100 }; 101 102 //===----------------------------------------------------------------------===// 103 // ConvertIndexFloorDivS 104 //===----------------------------------------------------------------------===// 105 106 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then 107 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. 108 struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> { 109 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 110 111 LogicalResult 112 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor, 113 ConversionPatternRewriter &rewriter) const override { 114 Location loc = op.getLoc(); 115 Value n = adaptor.getLhs(); 116 Value m = adaptor.getRhs(); 117 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); 118 Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); 119 Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1); 120 121 // Compute `x`. 122 Value mNeg = 123 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero); 124 Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne); 125 126 // Compute the negative result. 127 Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n); 128 Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m); 129 Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM); 130 131 // Compute the positive result. 132 Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m); 133 134 // Pick the negative result if `n` and `m` have different signs and `n` is 135 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. 136 Value nNeg = 137 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero); 138 Value diffSign = 139 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg); 140 Value nNonZero = 141 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero); 142 Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero); 143 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes); 144 return success(); 145 } 146 }; 147 148 //===----------------------------------------------------------------------===// 149 // CovnertIndexCast 150 //===----------------------------------------------------------------------===// 151 152 /// Convert a cast op. If the materialized index type is the same as the other 153 /// type, fold away the op. Otherwise, truncate or extend the op as appropriate. 154 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts 155 /// zero extend when the result bitwidth is larger. 156 template <typename CastOp, typename ExtOp> 157 struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> { 158 using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern; 159 160 LogicalResult 161 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, 162 ConversionPatternRewriter &rewriter) const override { 163 Type in = adaptor.getInput().getType(); 164 Type out = this->getTypeConverter()->convertType(op.getType()); 165 if (in == out) 166 rewriter.replaceOp(op, adaptor.getInput()); 167 else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth()) 168 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput()); 169 else 170 rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput()); 171 return success(); 172 } 173 }; 174 175 using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>; 176 using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>; 177 178 //===----------------------------------------------------------------------===// 179 // ConvertIndexCmp 180 //===----------------------------------------------------------------------===// 181 182 /// Assert that the LLVM comparison enum lines up with index's enum. 183 static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs, 184 IndexCmpPredicate rhs) { 185 return static_cast<int>(lhs) == static_cast<int>(rhs); 186 } 187 188 static_assert( 189 LLVM::getMaxEnumValForICmpPredicate() == 190 getMaxEnumValForIndexCmpPredicate() && 191 checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) && 192 checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) && 193 checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) && 194 checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) && 195 checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) && 196 checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) && 197 checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) && 198 checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) && 199 checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) && 200 checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT), 201 "LLVM ICmpPredicate mismatches IndexCmpPredicate"); 202 203 struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> { 204 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 205 206 LogicalResult 207 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor, 208 ConversionPatternRewriter &rewriter) const override { 209 // The LLVM enum has the same values as the index predicate enums. 210 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 211 op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())), 212 adaptor.getLhs(), adaptor.getRhs()); 213 return success(); 214 } 215 }; 216 217 //===----------------------------------------------------------------------===// 218 // ConvertIndexSizeOf 219 //===----------------------------------------------------------------------===// 220 221 /// Lower `index.sizeof` to a constant with the value of the index bitwidth. 222 struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> { 223 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 224 225 LogicalResult 226 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor, 227 ConversionPatternRewriter &rewriter) const override { 228 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 229 op, getTypeConverter()->getIndexType(), 230 getTypeConverter()->getIndexTypeBitwidth()); 231 return success(); 232 } 233 }; 234 235 //===----------------------------------------------------------------------===// 236 // ConvertIndexConstant 237 //===----------------------------------------------------------------------===// 238 239 /// Convert an index constant. Truncate the value as appropriate. 240 struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> { 241 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 242 243 LogicalResult 244 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor, 245 ConversionPatternRewriter &rewriter) const override { 246 Type type = getTypeConverter()->getIndexType(); 247 APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth()); 248 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 249 op, type, IntegerAttr::get(type, value)); 250 return success(); 251 } 252 }; 253 254 //===----------------------------------------------------------------------===// 255 // Trivial Conversions 256 //===----------------------------------------------------------------------===// 257 258 using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern<AddOp, LLVM::AddOp>; 259 using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern<SubOp, LLVM::SubOp>; 260 using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern<MulOp, LLVM::MulOp>; 261 using ConvertIndexDivS = 262 mlir::OneToOneConvertToLLVMPattern<DivSOp, LLVM::SDivOp>; 263 using ConvertIndexDivU = 264 mlir::OneToOneConvertToLLVMPattern<DivUOp, LLVM::UDivOp>; 265 using ConvertIndexRemS = 266 mlir::OneToOneConvertToLLVMPattern<RemSOp, LLVM::SRemOp>; 267 using ConvertIndexRemU = 268 mlir::OneToOneConvertToLLVMPattern<RemUOp, LLVM::URemOp>; 269 using ConvertIndexMaxS = 270 mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>; 271 using ConvertIndexMaxU = 272 mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>; 273 using ConvertIndexMinS = 274 mlir::OneToOneConvertToLLVMPattern<MinSOp, LLVM::SMinOp>; 275 using ConvertIndexMinU = 276 mlir::OneToOneConvertToLLVMPattern<MinUOp, LLVM::UMinOp>; 277 using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern<ShlOp, LLVM::ShlOp>; 278 using ConvertIndexShrS = 279 mlir::OneToOneConvertToLLVMPattern<ShrSOp, LLVM::AShrOp>; 280 using ConvertIndexShrU = 281 mlir::OneToOneConvertToLLVMPattern<ShrUOp, LLVM::LShrOp>; 282 using ConvertIndexAnd = mlir::OneToOneConvertToLLVMPattern<AndOp, LLVM::AndOp>; 283 using ConvertIndexOr = mlir::OneToOneConvertToLLVMPattern<OrOp, LLVM::OrOp>; 284 using ConvertIndexXor = mlir::OneToOneConvertToLLVMPattern<XOrOp, LLVM::XOrOp>; 285 using ConvertIndexBoolConstant = 286 mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>; 287 288 } // namespace 289 290 //===----------------------------------------------------------------------===// 291 // Pattern Population 292 //===----------------------------------------------------------------------===// 293 294 void index::populateIndexToLLVMConversionPatterns( 295 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 296 patterns.insert< 297 // clang-format off 298 ConvertIndexAdd, 299 ConvertIndexSub, 300 ConvertIndexMul, 301 ConvertIndexDivS, 302 ConvertIndexDivU, 303 ConvertIndexRemS, 304 ConvertIndexRemU, 305 ConvertIndexMaxS, 306 ConvertIndexMaxU, 307 ConvertIndexMinS, 308 ConvertIndexMinU, 309 ConvertIndexShl, 310 ConvertIndexShrS, 311 ConvertIndexShrU, 312 ConvertIndexAnd, 313 ConvertIndexOr, 314 ConvertIndexXor, 315 ConvertIndexCeilDivS, 316 ConvertIndexCeilDivU, 317 ConvertIndexFloorDivS, 318 ConvertIndexCastS, 319 ConvertIndexCastU, 320 ConvertIndexCmp, 321 ConvertIndexSizeOf, 322 ConvertIndexConstant, 323 ConvertIndexBoolConstant 324 // clang-format on 325 >(typeConverter); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // ODS-Generated Definitions 330 //===----------------------------------------------------------------------===// 331 332 namespace mlir { 333 #define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS 334 #include "mlir/Conversion/Passes.h.inc" 335 } // namespace mlir 336 337 //===----------------------------------------------------------------------===// 338 // Pass Definition 339 //===----------------------------------------------------------------------===// 340 341 namespace { 342 struct ConvertIndexToLLVMPass 343 : public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> { 344 using Base::Base; 345 346 void runOnOperation() override; 347 }; 348 } // namespace 349 350 void ConvertIndexToLLVMPass::runOnOperation() { 351 // Configure dialect conversion. 352 ConversionTarget target(getContext()); 353 target.addIllegalDialect<IndexDialect>(); 354 target.addLegalDialect<LLVM::LLVMDialect>(); 355 356 // Set LLVM lowering options. 357 LowerToLLVMOptions options(&getContext()); 358 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 359 options.overrideIndexBitwidth(indexBitwidth); 360 LLVMTypeConverter typeConverter(&getContext(), options); 361 362 // Populate patterns and run the conversion. 363 RewritePatternSet patterns(&getContext()); 364 populateIndexToLLVMConversionPatterns(typeConverter, patterns); 365 366 if (failed( 367 applyPartialConversion(getOperation(), target, std::move(patterns)))) 368 return signalPassFailure(); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // ConvertToLLVMPatternInterface implementation 373 //===----------------------------------------------------------------------===// 374 375 namespace { 376 /// Implement the interface to convert Index to LLVM. 377 struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 378 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 379 void loadDependentDialects(MLIRContext *context) const final { 380 context->loadDialect<LLVM::LLVMDialect>(); 381 } 382 383 /// Hook for derived dialect interface to provide conversion patterns 384 /// and mark dialect legal for the conversion target. 385 void populateConvertToLLVMConversionPatterns( 386 ConversionTarget &target, LLVMTypeConverter &typeConverter, 387 RewritePatternSet &patterns) const final { 388 populateIndexToLLVMConversionPatterns(typeConverter, patterns); 389 } 390 }; 391 } // namespace 392 393 void mlir::index::registerConvertIndexToLLVMInterface( 394 DialectRegistry ®istry) { 395 registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) { 396 dialect->addInterfaces<IndexToLLVMDialectInterface>(); 397 }); 398 } 399