1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Support/MathExtras.h" 21 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 24 using namespace mlir; 25 using namespace mlir::vector; 26 27 // Helper to reduce vector type by one rank at front. 28 static VectorType reducedVectorTypeFront(VectorType tp) { 29 assert((tp.getRank() > 1) && "unlowerable vector type"); 30 unsigned numScalableDims = tp.getNumScalableDims(); 31 if (tp.getShape().size() == numScalableDims) 32 --numScalableDims; 33 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 34 numScalableDims); 35 } 36 37 // Helper to reduce vector type by *all* but one rank at back. 38 static VectorType reducedVectorTypeBack(VectorType tp) { 39 assert((tp.getRank() > 1) && "unlowerable vector type"); 40 unsigned numScalableDims = tp.getNumScalableDims(); 41 if (numScalableDims > 0) 42 --numScalableDims; 43 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 44 numScalableDims); 45 } 46 47 // Helper that picks the proper sequence for inserting. 48 static Value insertOne(ConversionPatternRewriter &rewriter, 49 LLVMTypeConverter &typeConverter, Location loc, 50 Value val1, Value val2, Type llvmType, int64_t rank, 51 int64_t pos) { 52 assert(rank > 0 && "0-D vector corner case should have been handled already"); 53 if (rank == 1) { 54 auto idxType = rewriter.getIndexType(); 55 auto constant = rewriter.create<LLVM::ConstantOp>( 56 loc, typeConverter.convertType(idxType), 57 rewriter.getIntegerAttr(idxType, pos)); 58 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 59 constant); 60 } 61 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 62 } 63 64 // Helper that picks the proper sequence for extracting. 65 static Value extractOne(ConversionPatternRewriter &rewriter, 66 LLVMTypeConverter &typeConverter, Location loc, 67 Value val, Type llvmType, int64_t rank, int64_t pos) { 68 if (rank <= 1) { 69 auto idxType = rewriter.getIndexType(); 70 auto constant = rewriter.create<LLVM::ConstantOp>( 71 loc, typeConverter.convertType(idxType), 72 rewriter.getIntegerAttr(idxType, pos)); 73 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 74 constant); 75 } 76 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 77 } 78 79 // Helper that returns data layout alignment of a memref. 80 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 81 MemRefType memrefType, unsigned &align) { 82 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 83 if (!elementTy) 84 return failure(); 85 86 // TODO: this should use the MLIR data layout when it becomes available and 87 // stop depending on translation. 88 llvm::LLVMContext llvmContext; 89 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 90 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 91 return success(); 92 } 93 94 // Add an index vector component to a base pointer. This almost always succeeds 95 // unless the last stride is non-unit or the memory space is not zero. 96 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 97 Location loc, Value memref, Value base, 98 Value index, MemRefType memRefType, 99 VectorType vType, Value &ptrs) { 100 int64_t offset; 101 SmallVector<int64_t, 4> strides; 102 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 103 if (failed(successStrides) || strides.back() != 1 || 104 memRefType.getMemorySpaceAsInt() != 0) 105 return failure(); 106 auto pType = MemRefDescriptor(memref).getElementPtrType(); 107 auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); 108 ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index); 109 return success(); 110 } 111 112 // Casts a strided element pointer to a vector pointer. The vector pointer 113 // will be in the same address space as the incoming memref type. 114 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 115 Value ptr, MemRefType memRefType, Type vt) { 116 auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); 117 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 118 } 119 120 namespace { 121 122 /// Trivial Vector to LLVM conversions 123 using VectorScaleOpConversion = 124 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 125 126 /// Conversion pattern for a vector.bitcast. 127 class VectorBitCastOpConversion 128 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 129 public: 130 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 131 132 LogicalResult 133 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 134 ConversionPatternRewriter &rewriter) const override { 135 // Only 0-D and 1-D vectors can be lowered to LLVM. 136 VectorType resultTy = bitCastOp.getResultVectorType(); 137 if (resultTy.getRank() > 1) 138 return failure(); 139 Type newResultTy = typeConverter->convertType(resultTy); 140 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 141 adaptor.getOperands()[0]); 142 return success(); 143 } 144 }; 145 146 /// Conversion pattern for a vector.matrix_multiply. 147 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 148 class VectorMatmulOpConversion 149 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 150 public: 151 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 152 153 LogicalResult 154 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 155 ConversionPatternRewriter &rewriter) const override { 156 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 157 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 158 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 159 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 160 return success(); 161 } 162 }; 163 164 /// Conversion pattern for a vector.flat_transpose. 165 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 166 class VectorFlatTransposeOpConversion 167 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 168 public: 169 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 170 171 LogicalResult 172 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 173 ConversionPatternRewriter &rewriter) const override { 174 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 175 transOp, typeConverter->convertType(transOp.getRes().getType()), 176 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 177 return success(); 178 } 179 }; 180 181 /// Overloaded utility that replaces a vector.load, vector.store, 182 /// vector.maskedload and vector.maskedstore with their respective LLVM 183 /// couterparts. 184 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 185 vector::LoadOpAdaptor adaptor, 186 VectorType vectorTy, Value ptr, unsigned align, 187 ConversionPatternRewriter &rewriter) { 188 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align); 189 } 190 191 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 192 vector::MaskedLoadOpAdaptor adaptor, 193 VectorType vectorTy, Value ptr, unsigned align, 194 ConversionPatternRewriter &rewriter) { 195 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 196 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 197 } 198 199 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 200 vector::StoreOpAdaptor adaptor, 201 VectorType vectorTy, Value ptr, unsigned align, 202 ConversionPatternRewriter &rewriter) { 203 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 204 ptr, align); 205 } 206 207 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 208 vector::MaskedStoreOpAdaptor adaptor, 209 VectorType vectorTy, Value ptr, unsigned align, 210 ConversionPatternRewriter &rewriter) { 211 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 212 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); 213 } 214 215 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 216 /// vector.maskedstore. 217 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 218 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 219 public: 220 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 221 222 LogicalResult 223 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 224 typename LoadOrStoreOp::Adaptor adaptor, 225 ConversionPatternRewriter &rewriter) const override { 226 // Only 1-D vectors can be lowered to LLVM. 227 VectorType vectorTy = loadOrStoreOp.getVectorType(); 228 if (vectorTy.getRank() > 1) 229 return failure(); 230 231 auto loc = loadOrStoreOp->getLoc(); 232 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 233 234 // Resolve alignment. 235 unsigned align; 236 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 237 return failure(); 238 239 // Resolve address. 240 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) 241 .template cast<VectorType>(); 242 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 243 adaptor.getIndices(), rewriter); 244 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); 245 246 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 247 return success(); 248 } 249 }; 250 251 /// Conversion pattern for a vector.gather. 252 class VectorGatherOpConversion 253 : public ConvertOpToLLVMPattern<vector::GatherOp> { 254 public: 255 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 256 257 LogicalResult 258 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 259 ConversionPatternRewriter &rewriter) const override { 260 auto loc = gather->getLoc(); 261 MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>(); 262 assert(memRefType && "The base should be bufferized"); 263 264 // Resolve alignment. 265 unsigned align; 266 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 267 return failure(); 268 269 // Resolve address. 270 Value ptrs; 271 VectorType vType = gather.getVectorType(); 272 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 273 adaptor.getIndices(), rewriter); 274 if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, 275 adaptor.getIndexVec(), memRefType, vType, ptrs))) 276 return failure(); 277 278 // Replace with the gather intrinsic. 279 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 280 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 281 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 282 return success(); 283 } 284 }; 285 286 /// Conversion pattern for a vector.scatter. 287 class VectorScatterOpConversion 288 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 289 public: 290 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 291 292 LogicalResult 293 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 294 ConversionPatternRewriter &rewriter) const override { 295 auto loc = scatter->getLoc(); 296 MemRefType memRefType = scatter.getMemRefType(); 297 298 // Resolve alignment. 299 unsigned align; 300 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 301 return failure(); 302 303 // Resolve address. 304 Value ptrs; 305 VectorType vType = scatter.getVectorType(); 306 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 307 adaptor.getIndices(), rewriter); 308 if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, 309 adaptor.getIndexVec(), memRefType, vType, ptrs))) 310 return failure(); 311 312 // Replace with the scatter intrinsic. 313 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 314 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 315 rewriter.getI32IntegerAttr(align)); 316 return success(); 317 } 318 }; 319 320 /// Conversion pattern for a vector.expandload. 321 class VectorExpandLoadOpConversion 322 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 323 public: 324 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 325 326 LogicalResult 327 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 328 ConversionPatternRewriter &rewriter) const override { 329 auto loc = expand->getLoc(); 330 MemRefType memRefType = expand.getMemRefType(); 331 332 // Resolve address. 333 auto vtype = typeConverter->convertType(expand.getVectorType()); 334 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 335 adaptor.getIndices(), rewriter); 336 337 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 338 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 339 return success(); 340 } 341 }; 342 343 /// Conversion pattern for a vector.compressstore. 344 class VectorCompressStoreOpConversion 345 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 346 public: 347 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 348 349 LogicalResult 350 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 351 ConversionPatternRewriter &rewriter) const override { 352 auto loc = compress->getLoc(); 353 MemRefType memRefType = compress.getMemRefType(); 354 355 // Resolve address. 356 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 357 adaptor.getIndices(), rewriter); 358 359 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 360 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 361 return success(); 362 } 363 }; 364 365 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 366 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 367 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 368 /// non-null. 369 template <class VectorOp, class ScalarOp> 370 static Value createIntegerReductionArithmeticOpLowering( 371 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 372 Value vectorOperand, Value accumulator) { 373 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 374 if (accumulator) 375 result = rewriter.create<ScalarOp>(loc, accumulator, result); 376 return result; 377 } 378 379 /// Helper method to lower a `vector.reduction` operation that performs 380 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 381 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 382 /// the accumulator value if non-null. 383 template <class VectorOp> 384 static Value createIntegerReductionComparisonOpLowering( 385 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 386 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 387 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 388 if (accumulator) { 389 Value cmp = 390 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 391 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 392 } 393 return result; 394 } 395 396 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum 397 /// with vector types. 398 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, 399 Value rhs, bool isMin) { 400 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); 401 Type i1Type = builder.getI1Type(); 402 if (auto vecType = lhs.getType().dyn_cast<VectorType>()) 403 i1Type = VectorType::get(vecType.getShape(), i1Type); 404 Value cmp = builder.create<LLVM::FCmpOp>( 405 loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, 406 lhs, rhs); 407 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); 408 Value isNan = builder.create<LLVM::FCmpOp>( 409 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); 410 Value nan = builder.create<LLVM::ConstantOp>( 411 loc, lhs.getType(), 412 builder.getFloatAttr(floatType, 413 APFloat::getQNaN(floatType.getFloatSemantics()))); 414 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); 415 } 416 417 /// Conversion pattern for all vector reductions. 418 class VectorReductionOpConversion 419 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 420 public: 421 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 422 bool reassociateFPRed) 423 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 424 reassociateFPReductions(reassociateFPRed) {} 425 426 LogicalResult 427 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 428 ConversionPatternRewriter &rewriter) const override { 429 auto kind = reductionOp.getKind(); 430 Type eltType = reductionOp.getDest().getType(); 431 Type llvmType = typeConverter->convertType(eltType); 432 Value operand = adaptor.getVector(); 433 Value acc = adaptor.getAcc(); 434 Location loc = reductionOp.getLoc(); 435 if (eltType.isIntOrIndex()) { 436 // Integer reductions: add/mul/min/max/and/or/xor. 437 Value result; 438 switch (kind) { 439 case vector::CombiningKind::ADD: 440 result = 441 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 442 LLVM::AddOp>( 443 rewriter, loc, llvmType, operand, acc); 444 break; 445 case vector::CombiningKind::MUL: 446 result = 447 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 448 LLVM::MulOp>( 449 rewriter, loc, llvmType, operand, acc); 450 break; 451 case vector::CombiningKind::MINUI: 452 result = createIntegerReductionComparisonOpLowering< 453 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 454 LLVM::ICmpPredicate::ule); 455 break; 456 case vector::CombiningKind::MINSI: 457 result = createIntegerReductionComparisonOpLowering< 458 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 459 LLVM::ICmpPredicate::sle); 460 break; 461 case vector::CombiningKind::MAXUI: 462 result = createIntegerReductionComparisonOpLowering< 463 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 464 LLVM::ICmpPredicate::uge); 465 break; 466 case vector::CombiningKind::MAXSI: 467 result = createIntegerReductionComparisonOpLowering< 468 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 469 LLVM::ICmpPredicate::sge); 470 break; 471 case vector::CombiningKind::AND: 472 result = 473 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 474 LLVM::AndOp>( 475 rewriter, loc, llvmType, operand, acc); 476 break; 477 case vector::CombiningKind::OR: 478 result = 479 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 480 LLVM::OrOp>( 481 rewriter, loc, llvmType, operand, acc); 482 break; 483 case vector::CombiningKind::XOR: 484 result = 485 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 486 LLVM::XOrOp>( 487 rewriter, loc, llvmType, operand, acc); 488 break; 489 default: 490 return failure(); 491 } 492 rewriter.replaceOp(reductionOp, result); 493 494 return success(); 495 } 496 497 if (!eltType.isa<FloatType>()) 498 return failure(); 499 500 // Floating-point reductions: add/mul/min/max 501 if (kind == vector::CombiningKind::ADD) { 502 // Optional accumulator (or zero). 503 Value acc = adaptor.getOperands().size() > 1 504 ? adaptor.getOperands()[1] 505 : rewriter.create<LLVM::ConstantOp>( 506 reductionOp->getLoc(), llvmType, 507 rewriter.getZeroAttr(eltType)); 508 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 509 reductionOp, llvmType, acc, operand, 510 rewriter.getBoolAttr(reassociateFPReductions)); 511 } else if (kind == vector::CombiningKind::MUL) { 512 // Optional accumulator (or one). 513 Value acc = adaptor.getOperands().size() > 1 514 ? adaptor.getOperands()[1] 515 : rewriter.create<LLVM::ConstantOp>( 516 reductionOp->getLoc(), llvmType, 517 rewriter.getFloatAttr(eltType, 1.0)); 518 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 519 reductionOp, llvmType, acc, operand, 520 rewriter.getBoolAttr(reassociateFPReductions)); 521 } else if (kind == vector::CombiningKind::MINF) { 522 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 523 // NaNs/-0.0/+0.0 in the same way. 524 Value result = 525 rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand); 526 if (acc) 527 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true); 528 rewriter.replaceOp(reductionOp, result); 529 } else if (kind == vector::CombiningKind::MAXF) { 530 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle 531 // NaNs/-0.0/+0.0 in the same way. 532 Value result = 533 rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand); 534 if (acc) 535 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false); 536 rewriter.replaceOp(reductionOp, result); 537 } else 538 return failure(); 539 540 return success(); 541 } 542 543 private: 544 const bool reassociateFPReductions; 545 }; 546 547 class VectorShuffleOpConversion 548 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 549 public: 550 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 551 552 LogicalResult 553 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 554 ConversionPatternRewriter &rewriter) const override { 555 auto loc = shuffleOp->getLoc(); 556 auto v1Type = shuffleOp.getV1VectorType(); 557 auto v2Type = shuffleOp.getV2VectorType(); 558 auto vectorType = shuffleOp.getVectorType(); 559 Type llvmType = typeConverter->convertType(vectorType); 560 auto maskArrayAttr = shuffleOp.getMask(); 561 562 // Bail if result type cannot be lowered. 563 if (!llvmType) 564 return failure(); 565 566 // Get rank and dimension sizes. 567 int64_t rank = vectorType.getRank(); 568 assert(v1Type.getRank() == rank); 569 assert(v2Type.getRank() == rank); 570 int64_t v1Dim = v1Type.getDimSize(0); 571 572 // For rank 1, where both operands have *exactly* the same vector type, 573 // there is direct shuffle support in LLVM. Use it! 574 if (rank == 1 && v1Type == v2Type) { 575 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 576 loc, adaptor.getV1(), adaptor.getV2(), 577 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 578 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 579 return success(); 580 } 581 582 // For all other cases, insert the individual values individually. 583 Type eltType; 584 if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>()) 585 eltType = arrayType.getElementType(); 586 else 587 eltType = llvmType.cast<VectorType>().getElementType(); 588 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 589 int64_t insPos = 0; 590 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 591 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 592 Value value = adaptor.getV1(); 593 if (extPos >= v1Dim) { 594 extPos -= v1Dim; 595 value = adaptor.getV2(); 596 } 597 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 598 eltType, rank, extPos); 599 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 600 llvmType, rank, insPos++); 601 } 602 rewriter.replaceOp(shuffleOp, insert); 603 return success(); 604 } 605 }; 606 607 class VectorExtractElementOpConversion 608 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 609 public: 610 using ConvertOpToLLVMPattern< 611 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 612 613 LogicalResult 614 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 615 ConversionPatternRewriter &rewriter) const override { 616 auto vectorType = extractEltOp.getVectorType(); 617 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 618 619 // Bail if result type cannot be lowered. 620 if (!llvmType) 621 return failure(); 622 623 if (vectorType.getRank() == 0) { 624 Location loc = extractEltOp.getLoc(); 625 auto idxType = rewriter.getIndexType(); 626 auto zero = rewriter.create<LLVM::ConstantOp>( 627 loc, typeConverter->convertType(idxType), 628 rewriter.getIntegerAttr(idxType, 0)); 629 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 630 extractEltOp, llvmType, adaptor.getVector(), zero); 631 return success(); 632 } 633 634 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 635 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 636 return success(); 637 } 638 }; 639 640 class VectorExtractOpConversion 641 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 642 public: 643 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 644 645 LogicalResult 646 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 647 ConversionPatternRewriter &rewriter) const override { 648 auto loc = extractOp->getLoc(); 649 auto resultType = extractOp.getResult().getType(); 650 auto llvmResultType = typeConverter->convertType(resultType); 651 auto positionArrayAttr = extractOp.getPosition(); 652 653 // Bail if result type cannot be lowered. 654 if (!llvmResultType) 655 return failure(); 656 657 // Extract entire vector. Should be handled by folder, but just to be safe. 658 if (positionArrayAttr.empty()) { 659 rewriter.replaceOp(extractOp, adaptor.getVector()); 660 return success(); 661 } 662 663 // One-shot extraction of vector from array (only requires extractvalue). 664 if (resultType.isa<VectorType>()) { 665 SmallVector<int64_t> indices; 666 for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>()) 667 indices.push_back(idx.getInt()); 668 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 669 loc, adaptor.getVector(), indices); 670 rewriter.replaceOp(extractOp, extracted); 671 return success(); 672 } 673 674 // Potential extraction of 1-D vector from array. 675 Value extracted = adaptor.getVector(); 676 auto positionAttrs = positionArrayAttr.getValue(); 677 if (positionAttrs.size() > 1) { 678 SmallVector<int64_t> nMinusOnePosition; 679 for (auto idx : positionAttrs.drop_back()) 680 nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt()); 681 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted, 682 nMinusOnePosition); 683 } 684 685 // Remaining extraction of element from 1-D LLVM vector 686 auto position = positionAttrs.back().cast<IntegerAttr>(); 687 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 688 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 689 extracted = 690 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 691 rewriter.replaceOp(extractOp, extracted); 692 693 return success(); 694 } 695 }; 696 697 /// Conversion pattern that turns a vector.fma on a 1-D vector 698 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 699 /// This does not match vectors of n >= 2 rank. 700 /// 701 /// Example: 702 /// ``` 703 /// vector.fma %a, %a, %a : vector<8xf32> 704 /// ``` 705 /// is converted to: 706 /// ``` 707 /// llvm.intr.fmuladd %va, %va, %va: 708 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 709 /// -> !llvm."<8 x f32>"> 710 /// ``` 711 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 712 public: 713 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 714 715 LogicalResult 716 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 717 ConversionPatternRewriter &rewriter) const override { 718 VectorType vType = fmaOp.getVectorType(); 719 if (vType.getRank() != 1) 720 return failure(); 721 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( 722 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); 723 return success(); 724 } 725 }; 726 727 class VectorInsertElementOpConversion 728 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 729 public: 730 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 731 732 LogicalResult 733 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 734 ConversionPatternRewriter &rewriter) const override { 735 auto vectorType = insertEltOp.getDestVectorType(); 736 auto llvmType = typeConverter->convertType(vectorType); 737 738 // Bail if result type cannot be lowered. 739 if (!llvmType) 740 return failure(); 741 742 if (vectorType.getRank() == 0) { 743 Location loc = insertEltOp.getLoc(); 744 auto idxType = rewriter.getIndexType(); 745 auto zero = rewriter.create<LLVM::ConstantOp>( 746 loc, typeConverter->convertType(idxType), 747 rewriter.getIntegerAttr(idxType, 0)); 748 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 749 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); 750 return success(); 751 } 752 753 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 754 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), 755 adaptor.getPosition()); 756 return success(); 757 } 758 }; 759 760 class VectorInsertOpConversion 761 : public ConvertOpToLLVMPattern<vector::InsertOp> { 762 public: 763 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 764 765 LogicalResult 766 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 767 ConversionPatternRewriter &rewriter) const override { 768 auto loc = insertOp->getLoc(); 769 auto sourceType = insertOp.getSourceType(); 770 auto destVectorType = insertOp.getDestVectorType(); 771 auto llvmResultType = typeConverter->convertType(destVectorType); 772 auto positionArrayAttr = insertOp.getPosition(); 773 774 // Bail if result type cannot be lowered. 775 if (!llvmResultType) 776 return failure(); 777 778 // Overwrite entire vector with value. Should be handled by folder, but 779 // just to be safe. 780 if (positionArrayAttr.empty()) { 781 rewriter.replaceOp(insertOp, adaptor.getSource()); 782 return success(); 783 } 784 785 // One-shot insertion of a vector into an array (only requires insertvalue). 786 if (sourceType.isa<VectorType>()) { 787 Value inserted = rewriter.create<LLVM::InsertValueOp>( 788 loc, adaptor.getDest(), adaptor.getSource(), 789 LLVM::convertArrayToIndices(positionArrayAttr)); 790 rewriter.replaceOp(insertOp, inserted); 791 return success(); 792 } 793 794 // Potential extraction of 1-D vector from array. 795 Value extracted = adaptor.getDest(); 796 auto positionAttrs = positionArrayAttr.getValue(); 797 auto position = positionAttrs.back().cast<IntegerAttr>(); 798 auto oneDVectorType = destVectorType; 799 if (positionAttrs.size() > 1) { 800 oneDVectorType = reducedVectorTypeBack(destVectorType); 801 extracted = rewriter.create<LLVM::ExtractValueOp>( 802 loc, extracted, 803 LLVM::convertArrayToIndices(positionAttrs.drop_back())); 804 } 805 806 // Insertion of an element into a 1-D LLVM vector. 807 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 808 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 809 Value inserted = rewriter.create<LLVM::InsertElementOp>( 810 loc, typeConverter->convertType(oneDVectorType), extracted, 811 adaptor.getSource(), constant); 812 813 // Potential insertion of resulting 1-D vector into array. 814 if (positionAttrs.size() > 1) { 815 inserted = rewriter.create<LLVM::InsertValueOp>( 816 loc, adaptor.getDest(), inserted, 817 LLVM::convertArrayToIndices(positionAttrs.drop_back())); 818 } 819 820 rewriter.replaceOp(insertOp, inserted); 821 return success(); 822 } 823 }; 824 825 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 826 /// 827 /// Example: 828 /// ``` 829 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 830 /// ``` 831 /// is rewritten into: 832 /// ``` 833 /// %r = splat %f0: vector<2x4xf32> 834 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 835 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 836 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 837 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 838 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 839 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 840 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 841 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 842 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 843 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 844 /// // %r3 holds the final value. 845 /// ``` 846 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 847 public: 848 using OpRewritePattern<FMAOp>::OpRewritePattern; 849 850 void initialize() { 851 // This pattern recursively unpacks one dimension at a time. The recursion 852 // bounded as the rank is strictly decreasing. 853 setHasBoundedRewriteRecursion(); 854 } 855 856 LogicalResult matchAndRewrite(FMAOp op, 857 PatternRewriter &rewriter) const override { 858 auto vType = op.getVectorType(); 859 if (vType.getRank() < 2) 860 return failure(); 861 862 auto loc = op.getLoc(); 863 auto elemType = vType.getElementType(); 864 Value zero = rewriter.create<arith::ConstantOp>( 865 loc, elemType, rewriter.getZeroAttr(elemType)); 866 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 867 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 868 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 869 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 870 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 871 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 872 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 873 } 874 rewriter.replaceOp(op, desc); 875 return success(); 876 } 877 }; 878 879 /// Returns the strides if the memory underlying `memRefType` has a contiguous 880 /// static layout. 881 static llvm::Optional<SmallVector<int64_t, 4>> 882 computeContiguousStrides(MemRefType memRefType) { 883 int64_t offset; 884 SmallVector<int64_t, 4> strides; 885 if (failed(getStridesAndOffset(memRefType, strides, offset))) 886 return None; 887 if (!strides.empty() && strides.back() != 1) 888 return None; 889 // If no layout or identity layout, this is contiguous by definition. 890 if (memRefType.getLayout().isIdentity()) 891 return strides; 892 893 // Otherwise, we must determine contiguity form shapes. This can only ever 894 // work in static cases because MemRefType is underspecified to represent 895 // contiguous dynamic shapes in other ways than with just empty/identity 896 // layout. 897 auto sizes = memRefType.getShape(); 898 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 899 if (ShapedType::isDynamic(sizes[index + 1]) || 900 ShapedType::isDynamicStrideOrOffset(strides[index]) || 901 ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 902 return None; 903 if (strides[index] != strides[index + 1] * sizes[index + 1]) 904 return None; 905 } 906 return strides; 907 } 908 909 class VectorTypeCastOpConversion 910 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 911 public: 912 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 913 914 LogicalResult 915 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 916 ConversionPatternRewriter &rewriter) const override { 917 auto loc = castOp->getLoc(); 918 MemRefType sourceMemRefType = 919 castOp.getOperand().getType().cast<MemRefType>(); 920 MemRefType targetMemRefType = castOp.getType(); 921 922 // Only static shape casts supported atm. 923 if (!sourceMemRefType.hasStaticShape() || 924 !targetMemRefType.hasStaticShape()) 925 return failure(); 926 927 auto llvmSourceDescriptorTy = 928 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 929 if (!llvmSourceDescriptorTy) 930 return failure(); 931 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 932 933 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 934 .dyn_cast_or_null<LLVM::LLVMStructType>(); 935 if (!llvmTargetDescriptorTy) 936 return failure(); 937 938 // Only contiguous source buffers supported atm. 939 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 940 if (!sourceStrides) 941 return failure(); 942 auto targetStrides = computeContiguousStrides(targetMemRefType); 943 if (!targetStrides) 944 return failure(); 945 // Only support static strides for now, regardless of contiguity. 946 if (llvm::any_of(*targetStrides, ShapedType::isDynamicStrideOrOffset)) 947 return failure(); 948 949 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 950 951 // Create descriptor. 952 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 953 Type llvmTargetElementTy = desc.getElementPtrType(); 954 // Set allocated ptr. 955 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 956 allocated = 957 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 958 desc.setAllocatedPtr(rewriter, loc, allocated); 959 // Set aligned ptr. 960 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 961 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 962 desc.setAlignedPtr(rewriter, loc, ptr); 963 // Fill offset 0. 964 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 965 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 966 desc.setOffset(rewriter, loc, zero); 967 968 // Fill size and stride descriptors in memref. 969 for (const auto &indexedSize : 970 llvm::enumerate(targetMemRefType.getShape())) { 971 int64_t index = indexedSize.index(); 972 auto sizeAttr = 973 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 974 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 975 desc.setSize(rewriter, loc, index, size); 976 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 977 (*targetStrides)[index]); 978 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 979 desc.setStride(rewriter, loc, index, stride); 980 } 981 982 rewriter.replaceOp(castOp, {desc}); 983 return success(); 984 } 985 }; 986 987 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 988 /// Non-scalable versions of this operation are handled in Vector Transforms. 989 class VectorCreateMaskOpRewritePattern 990 : public OpRewritePattern<vector::CreateMaskOp> { 991 public: 992 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 993 bool enableIndexOpt) 994 : OpRewritePattern<vector::CreateMaskOp>(context), 995 force32BitVectorIndices(enableIndexOpt) {} 996 997 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 998 PatternRewriter &rewriter) const override { 999 auto dstType = op.getType(); 1000 if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable()) 1001 return failure(); 1002 IntegerType idxType = 1003 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1004 auto loc = op->getLoc(); 1005 Value indices = rewriter.create<LLVM::StepVectorOp>( 1006 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1007 /*isScalable=*/true)); 1008 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1009 op.getOperand(0)); 1010 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1011 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1012 indices, bounds); 1013 rewriter.replaceOp(op, comp); 1014 return success(); 1015 } 1016 1017 private: 1018 const bool force32BitVectorIndices; 1019 }; 1020 1021 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1022 public: 1023 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1024 1025 // Proof-of-concept lowering implementation that relies on a small 1026 // runtime support library, which only needs to provide a few 1027 // printing methods (single value for all data types, opening/closing 1028 // bracket, comma, newline). The lowering fully unrolls a vector 1029 // in terms of these elementary printing operations. The advantage 1030 // of this approach is that the library can remain unaware of all 1031 // low-level implementation details of vectors while still supporting 1032 // output of any shaped and dimensioned vector. Due to full unrolling, 1033 // this approach is less suited for very large vectors though. 1034 // 1035 // TODO: rely solely on libc in future? something else? 1036 // 1037 LogicalResult 1038 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1039 ConversionPatternRewriter &rewriter) const override { 1040 Type printType = printOp.getPrintType(); 1041 1042 if (typeConverter->convertType(printType) == nullptr) 1043 return failure(); 1044 1045 // Make sure element type has runtime support. 1046 PrintConversion conversion = PrintConversion::None; 1047 VectorType vectorType = printType.dyn_cast<VectorType>(); 1048 Type eltType = vectorType ? vectorType.getElementType() : printType; 1049 Operation *printer; 1050 if (eltType.isF32()) { 1051 printer = 1052 LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>()); 1053 } else if (eltType.isF64()) { 1054 printer = 1055 LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>()); 1056 } else if (eltType.isIndex()) { 1057 printer = 1058 LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>()); 1059 } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1060 // Integers need a zero or sign extension on the operand 1061 // (depending on the source type) as well as a signed or 1062 // unsigned print method. Up to 64-bit is supported. 1063 unsigned width = intTy.getWidth(); 1064 if (intTy.isUnsigned()) { 1065 if (width <= 64) { 1066 if (width < 64) 1067 conversion = PrintConversion::ZeroExt64; 1068 printer = LLVM::lookupOrCreatePrintU64Fn( 1069 printOp->getParentOfType<ModuleOp>()); 1070 } else { 1071 return failure(); 1072 } 1073 } else { 1074 assert(intTy.isSignless() || intTy.isSigned()); 1075 if (width <= 64) { 1076 // Note that we *always* zero extend booleans (1-bit integers), 1077 // so that true/false is printed as 1/0 rather than -1/0. 1078 if (width == 1) 1079 conversion = PrintConversion::ZeroExt64; 1080 else if (width < 64) 1081 conversion = PrintConversion::SignExt64; 1082 printer = LLVM::lookupOrCreatePrintI64Fn( 1083 printOp->getParentOfType<ModuleOp>()); 1084 } else { 1085 return failure(); 1086 } 1087 } 1088 } else { 1089 return failure(); 1090 } 1091 1092 // Unroll vector into elementary print calls. 1093 int64_t rank = vectorType ? vectorType.getRank() : 0; 1094 Type type = vectorType ? vectorType : eltType; 1095 emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, 1096 conversion); 1097 emitCall(rewriter, printOp->getLoc(), 1098 LLVM::lookupOrCreatePrintNewlineFn( 1099 printOp->getParentOfType<ModuleOp>())); 1100 rewriter.eraseOp(printOp); 1101 return success(); 1102 } 1103 1104 private: 1105 enum class PrintConversion { 1106 // clang-format off 1107 None, 1108 ZeroExt64, 1109 SignExt64 1110 // clang-format on 1111 }; 1112 1113 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1114 Value value, Type type, Operation *printer, int64_t rank, 1115 PrintConversion conversion) const { 1116 VectorType vectorType = type.dyn_cast<VectorType>(); 1117 Location loc = op->getLoc(); 1118 if (!vectorType) { 1119 assert(rank == 0 && "The scalar case expects rank == 0"); 1120 switch (conversion) { 1121 case PrintConversion::ZeroExt64: 1122 value = rewriter.create<arith::ExtUIOp>( 1123 loc, IntegerType::get(rewriter.getContext(), 64), value); 1124 break; 1125 case PrintConversion::SignExt64: 1126 value = rewriter.create<arith::ExtSIOp>( 1127 loc, IntegerType::get(rewriter.getContext(), 64), value); 1128 break; 1129 case PrintConversion::None: 1130 break; 1131 } 1132 emitCall(rewriter, loc, printer, value); 1133 return; 1134 } 1135 1136 emitCall(rewriter, loc, 1137 LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>())); 1138 Operation *printComma = 1139 LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>()); 1140 1141 if (rank <= 1) { 1142 auto reducedType = vectorType.getElementType(); 1143 auto llvmType = typeConverter->convertType(reducedType); 1144 int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0); 1145 for (int64_t d = 0; d < dim; ++d) { 1146 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1147 llvmType, /*rank=*/0, /*pos=*/d); 1148 emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0, 1149 conversion); 1150 if (d != dim - 1) 1151 emitCall(rewriter, loc, printComma); 1152 } 1153 emitCall( 1154 rewriter, loc, 1155 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1156 return; 1157 } 1158 1159 int64_t dim = vectorType.getDimSize(0); 1160 for (int64_t d = 0; d < dim; ++d) { 1161 auto reducedType = reducedVectorTypeFront(vectorType); 1162 auto llvmType = typeConverter->convertType(reducedType); 1163 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1164 llvmType, rank, d); 1165 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1166 conversion); 1167 if (d != dim - 1) 1168 emitCall(rewriter, loc, printComma); 1169 } 1170 emitCall(rewriter, loc, 1171 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1172 } 1173 1174 // Helper to emit a call. 1175 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1176 Operation *ref, ValueRange params = ValueRange()) { 1177 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1178 params); 1179 } 1180 }; 1181 1182 /// The Splat operation is lowered to an insertelement + a shufflevector 1183 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 1184 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1185 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1186 1187 LogicalResult 1188 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1189 ConversionPatternRewriter &rewriter) const override { 1190 VectorType resultType = splatOp.getType().cast<VectorType>(); 1191 if (resultType.getRank() > 1) 1192 return failure(); 1193 1194 // First insert it into an undef vector so we can shuffle it. 1195 auto vectorType = typeConverter->convertType(splatOp.getType()); 1196 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1197 auto zero = rewriter.create<LLVM::ConstantOp>( 1198 splatOp.getLoc(), 1199 typeConverter->convertType(rewriter.getIntegerType(32)), 1200 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1201 1202 // For 0-d vector, we simply do `insertelement`. 1203 if (resultType.getRank() == 0) { 1204 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1205 splatOp, vectorType, undef, adaptor.getInput(), zero); 1206 return success(); 1207 } 1208 1209 // For 1-d vector, we additionally do a `vectorshuffle`. 1210 auto v = rewriter.create<LLVM::InsertElementOp>( 1211 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1212 1213 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); 1214 SmallVector<int32_t> zeroValues(width, 0); 1215 1216 // Shuffle the value across the desired number of elements. 1217 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1218 zeroValues); 1219 return success(); 1220 } 1221 }; 1222 1223 /// The Splat operation is lowered to an insertelement + a shufflevector 1224 /// operation. Splat to only 2+-d vector result types are lowered by the 1225 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 1226 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { 1227 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; 1228 1229 LogicalResult 1230 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, 1231 ConversionPatternRewriter &rewriter) const override { 1232 VectorType resultType = splatOp.getType(); 1233 if (resultType.getRank() <= 1) 1234 return failure(); 1235 1236 // First insert it into an undef vector so we can shuffle it. 1237 auto loc = splatOp.getLoc(); 1238 auto vectorTypeInfo = 1239 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); 1240 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; 1241 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; 1242 if (!llvmNDVectorTy || !llvm1DVectorTy) 1243 return failure(); 1244 1245 // Construct returned value. 1246 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); 1247 1248 // Construct a 1-D vector with the splatted value that we insert in all the 1249 // places within the returned descriptor. 1250 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); 1251 auto zero = rewriter.create<LLVM::ConstantOp>( 1252 loc, typeConverter->convertType(rewriter.getIntegerType(32)), 1253 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1254 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, 1255 adaptor.getInput(), zero); 1256 1257 // Shuffle the value across the desired number of elements. 1258 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1259 SmallVector<int32_t> zeroValues(width, 0); 1260 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1261 1262 // Iterate of linear index, convert to coords space and insert splatted 1-D 1263 // vector in each position. 1264 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1265 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1266 }); 1267 rewriter.replaceOp(splatOp, desc); 1268 return success(); 1269 } 1270 }; 1271 1272 } // namespace 1273 1274 /// Populate the given list with patterns that convert from Vector to LLVM. 1275 void mlir::populateVectorToLLVMConversionPatterns( 1276 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1277 bool reassociateFPReductions, bool force32BitVectorIndices) { 1278 MLIRContext *ctx = converter.getDialect()->getContext(); 1279 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1280 populateVectorInsertExtractStridedSliceTransforms(patterns); 1281 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1282 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1283 patterns 1284 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1285 VectorExtractElementOpConversion, VectorExtractOpConversion, 1286 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1287 VectorInsertOpConversion, VectorPrintOpConversion, 1288 VectorTypeCastOpConversion, VectorScaleOpConversion, 1289 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1290 VectorLoadStoreConversion<vector::MaskedLoadOp, 1291 vector::MaskedLoadOpAdaptor>, 1292 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1293 VectorLoadStoreConversion<vector::MaskedStoreOp, 1294 vector::MaskedStoreOpAdaptor>, 1295 VectorGatherOpConversion, VectorScatterOpConversion, 1296 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1297 VectorSplatOpLowering, VectorSplatNdOpLowering>(converter); 1298 // Transfer ops with rank > 1 are handled by VectorToSCF. 1299 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1300 } 1301 1302 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1303 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1304 patterns.add<VectorMatmulOpConversion>(converter); 1305 patterns.add<VectorFlatTransposeOpConversion>(converter); 1306 } 1307