1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/Support/Casting.h" 26 #include <optional> 27 28 using namespace mlir; 29 using namespace mlir::vector; 30 31 // Helper to reduce vector type by one rank at front. 32 static VectorType reducedVectorTypeFront(VectorType tp) { 33 assert((tp.getRank() > 1) && "unlowerable vector type"); 34 unsigned numScalableDims = tp.getNumScalableDims(); 35 if (tp.getShape().size() == numScalableDims) 36 --numScalableDims; 37 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 38 numScalableDims); 39 } 40 41 // Helper to reduce vector type by *all* but one rank at back. 
42 static VectorType reducedVectorTypeBack(VectorType tp) { 43 assert((tp.getRank() > 1) && "unlowerable vector type"); 44 unsigned numScalableDims = tp.getNumScalableDims(); 45 if (numScalableDims > 0) 46 --numScalableDims; 47 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 48 numScalableDims); 49 } 50 51 // Helper that picks the proper sequence for inserting. 52 static Value insertOne(ConversionPatternRewriter &rewriter, 53 LLVMTypeConverter &typeConverter, Location loc, 54 Value val1, Value val2, Type llvmType, int64_t rank, 55 int64_t pos) { 56 assert(rank > 0 && "0-D vector corner case should have been handled already"); 57 if (rank == 1) { 58 auto idxType = rewriter.getIndexType(); 59 auto constant = rewriter.create<LLVM::ConstantOp>( 60 loc, typeConverter.convertType(idxType), 61 rewriter.getIntegerAttr(idxType, pos)); 62 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 63 constant); 64 } 65 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 66 } 67 68 // Helper that picks the proper sequence for extracting. 69 static Value extractOne(ConversionPatternRewriter &rewriter, 70 LLVMTypeConverter &typeConverter, Location loc, 71 Value val, Type llvmType, int64_t rank, int64_t pos) { 72 if (rank <= 1) { 73 auto idxType = rewriter.getIndexType(); 74 auto constant = rewriter.create<LLVM::ConstantOp>( 75 loc, typeConverter.convertType(idxType), 76 rewriter.getIntegerAttr(idxType, pos)); 77 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 78 constant); 79 } 80 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 81 } 82 83 // Helper that returns data layout alignment of a memref. 
84 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 85 MemRefType memrefType, unsigned &align) { 86 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 87 if (!elementTy) 88 return failure(); 89 90 // TODO: this should use the MLIR data layout when it becomes available and 91 // stop depending on translation. 92 llvm::LLVMContext llvmContext; 93 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 94 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 95 return success(); 96 } 97 98 // Check if the last stride is non-unit or the memory space is not zero. 99 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 100 LLVMTypeConverter &converter) { 101 int64_t offset; 102 SmallVector<int64_t, 4> strides; 103 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 104 FailureOr<unsigned> addressSpace = 105 converter.getMemRefAddressSpace(memRefType); 106 if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) || 107 *addressSpace != 0) 108 return failure(); 109 return success(); 110 } 111 112 // Add an index vector component to a base pointer. 113 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 114 LLVMTypeConverter &typeConverter, 115 MemRefType memRefType, Value llvmMemref, Value base, 116 Value index, uint64_t vLen) { 117 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 118 "unsupported memref type"); 119 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 120 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 121 return rewriter.create<LLVM::GEPOp>( 122 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 123 base, index); 124 } 125 126 // Casts a strided element pointer to a vector pointer. The vector pointer 127 // will be in the same address space as the incoming memref type. 
128 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 129 Value ptr, MemRefType memRefType, Type vt, 130 LLVMTypeConverter &converter) { 131 if (converter.useOpaquePointers()) 132 return ptr; 133 134 unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); 135 auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); 136 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 137 } 138 139 namespace { 140 141 /// Trivial Vector to LLVM conversions 142 using VectorScaleOpConversion = 143 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 144 145 /// Conversion pattern for a vector.bitcast. 146 class VectorBitCastOpConversion 147 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 148 public: 149 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 150 151 LogicalResult 152 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 153 ConversionPatternRewriter &rewriter) const override { 154 // Only 0-D and 1-D vectors can be lowered to LLVM. 155 VectorType resultTy = bitCastOp.getResultVectorType(); 156 if (resultTy.getRank() > 1) 157 return failure(); 158 Type newResultTy = typeConverter->convertType(resultTy); 159 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 160 adaptor.getOperands()[0]); 161 return success(); 162 } 163 }; 164 165 /// Conversion pattern for a vector.matrix_multiply. 166 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
167 class VectorMatmulOpConversion 168 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 169 public: 170 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 171 172 LogicalResult 173 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 174 ConversionPatternRewriter &rewriter) const override { 175 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 176 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 177 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 178 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 179 return success(); 180 } 181 }; 182 183 /// Conversion pattern for a vector.flat_transpose. 184 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 185 class VectorFlatTransposeOpConversion 186 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 187 public: 188 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 189 190 LogicalResult 191 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 192 ConversionPatternRewriter &rewriter) const override { 193 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 194 transOp, typeConverter->convertType(transOp.getRes().getType()), 195 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 196 return success(); 197 } 198 }; 199 200 /// Overloaded utility that replaces a vector.load, vector.store, 201 /// vector.maskedload and vector.maskedstore with their respective LLVM 202 /// couterparts. 
203 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 204 vector::LoadOpAdaptor adaptor, 205 VectorType vectorTy, Value ptr, unsigned align, 206 ConversionPatternRewriter &rewriter) { 207 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align); 208 } 209 210 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 211 vector::MaskedLoadOpAdaptor adaptor, 212 VectorType vectorTy, Value ptr, unsigned align, 213 ConversionPatternRewriter &rewriter) { 214 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 215 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 216 } 217 218 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 219 vector::StoreOpAdaptor adaptor, 220 VectorType vectorTy, Value ptr, unsigned align, 221 ConversionPatternRewriter &rewriter) { 222 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 223 ptr, align); 224 } 225 226 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 227 vector::MaskedStoreOpAdaptor adaptor, 228 VectorType vectorTy, Value ptr, unsigned align, 229 ConversionPatternRewriter &rewriter) { 230 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 231 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); 232 } 233 234 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 235 /// vector.maskedstore. 236 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 237 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 238 public: 239 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 240 241 LogicalResult 242 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 243 typename LoadOrStoreOp::Adaptor adaptor, 244 ConversionPatternRewriter &rewriter) const override { 245 // Only 1-D vectors can be lowered to LLVM. 
246 VectorType vectorTy = loadOrStoreOp.getVectorType(); 247 if (vectorTy.getRank() > 1) 248 return failure(); 249 250 auto loc = loadOrStoreOp->getLoc(); 251 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 252 253 // Resolve alignment. 254 unsigned align; 255 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 256 return failure(); 257 258 // Resolve address. 259 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) 260 .template cast<VectorType>(); 261 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 262 adaptor.getIndices(), rewriter); 263 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, 264 *this->getTypeConverter()); 265 266 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 267 return success(); 268 } 269 }; 270 271 /// Conversion pattern for a vector.gather. 272 class VectorGatherOpConversion 273 : public ConvertOpToLLVMPattern<vector::GatherOp> { 274 public: 275 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 276 277 LogicalResult 278 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 279 ConversionPatternRewriter &rewriter) const override { 280 MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>(); 281 assert(memRefType && "The base should be bufferized"); 282 283 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 284 return failure(); 285 286 auto loc = gather->getLoc(); 287 288 // Resolve alignment. 289 unsigned align; 290 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 291 return failure(); 292 293 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 294 adaptor.getIndices(), rewriter); 295 Value base = adaptor.getBase(); 296 297 auto llvmNDVectorTy = adaptor.getIndexVec().getType(); 298 // Handle the simple case of 1-D vector. 
299 if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) { 300 auto vType = gather.getVectorType(); 301 // Resolve address. 302 Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), 303 memRefType, base, ptr, adaptor.getIndexVec(), 304 /*vLen=*/vType.getDimSize(0)); 305 // Replace with the gather intrinsic. 306 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 307 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 308 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 309 return success(); 310 } 311 312 LLVMTypeConverter &typeConverter = *this->getTypeConverter(); 313 auto callback = [align, memRefType, base, ptr, loc, &rewriter, 314 &typeConverter](Type llvm1DVectorTy, 315 ValueRange vectorOperands) { 316 // Resolve address. 317 Value ptrs = getIndexedPtrs( 318 rewriter, loc, typeConverter, memRefType, base, ptr, 319 /*index=*/vectorOperands[0], 320 LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); 321 // Create the gather intrinsic. 322 return rewriter.create<LLVM::masked_gather>( 323 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1], 324 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align)); 325 }; 326 SmallVector<Value> vectorOperands = { 327 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()}; 328 return LLVM::detail::handleMultidimensionalVectors( 329 gather, vectorOperands, *getTypeConverter(), callback, rewriter); 330 } 331 }; 332 333 /// Conversion pattern for a vector.scatter. 
334 class VectorScatterOpConversion 335 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 336 public: 337 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 338 339 LogicalResult 340 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 341 ConversionPatternRewriter &rewriter) const override { 342 auto loc = scatter->getLoc(); 343 MemRefType memRefType = scatter.getMemRefType(); 344 345 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 346 return failure(); 347 348 // Resolve alignment. 349 unsigned align; 350 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 351 return failure(); 352 353 // Resolve address. 354 VectorType vType = scatter.getVectorType(); 355 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 356 adaptor.getIndices(), rewriter); 357 Value ptrs = getIndexedPtrs( 358 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(), 359 ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); 360 361 // Replace with the scatter intrinsic. 362 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 363 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 364 rewriter.getI32IntegerAttr(align)); 365 return success(); 366 } 367 }; 368 369 /// Conversion pattern for a vector.expandload. 370 class VectorExpandLoadOpConversion 371 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 372 public: 373 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 374 375 LogicalResult 376 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 377 ConversionPatternRewriter &rewriter) const override { 378 auto loc = expand->getLoc(); 379 MemRefType memRefType = expand.getMemRefType(); 380 381 // Resolve address. 
382 auto vtype = typeConverter->convertType(expand.getVectorType()); 383 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 384 adaptor.getIndices(), rewriter); 385 386 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 387 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 388 return success(); 389 } 390 }; 391 392 /// Conversion pattern for a vector.compressstore. 393 class VectorCompressStoreOpConversion 394 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 395 public: 396 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 397 398 LogicalResult 399 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 400 ConversionPatternRewriter &rewriter) const override { 401 auto loc = compress->getLoc(); 402 MemRefType memRefType = compress.getMemRefType(); 403 404 // Resolve address. 405 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 406 adaptor.getIndices(), rewriter); 407 408 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 409 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 410 return success(); 411 } 412 }; 413 414 /// Reduction neutral classes for overloading. 415 class ReductionNeutralZero {}; 416 class ReductionNeutralIntOne {}; 417 class ReductionNeutralFPOne {}; 418 class ReductionNeutralAllOnes {}; 419 class ReductionNeutralSIntMin {}; 420 class ReductionNeutralUIntMin {}; 421 class ReductionNeutralSIntMax {}; 422 class ReductionNeutralUIntMax {}; 423 class ReductionNeutralFPMin {}; 424 class ReductionNeutralFPMax {}; 425 426 /// Create the reduction neutral zero value. 427 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 428 ConversionPatternRewriter &rewriter, 429 Location loc, Type llvmType) { 430 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 431 rewriter.getZeroAttr(llvmType)); 432 } 433 434 /// Create the reduction neutral integer one value. 
435 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 436 ConversionPatternRewriter &rewriter, 437 Location loc, Type llvmType) { 438 return rewriter.create<LLVM::ConstantOp>( 439 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 440 } 441 442 /// Create the reduction neutral fp one value. 443 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 444 ConversionPatternRewriter &rewriter, 445 Location loc, Type llvmType) { 446 return rewriter.create<LLVM::ConstantOp>( 447 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 448 } 449 450 /// Create the reduction neutral all-ones value. 451 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 452 ConversionPatternRewriter &rewriter, 453 Location loc, Type llvmType) { 454 return rewriter.create<LLVM::ConstantOp>( 455 loc, llvmType, 456 rewriter.getIntegerAttr( 457 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 458 } 459 460 /// Create the reduction neutral signed int minimum value. 461 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 462 ConversionPatternRewriter &rewriter, 463 Location loc, Type llvmType) { 464 return rewriter.create<LLVM::ConstantOp>( 465 loc, llvmType, 466 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 467 llvmType.getIntOrFloatBitWidth()))); 468 } 469 470 /// Create the reduction neutral unsigned int minimum value. 471 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 472 ConversionPatternRewriter &rewriter, 473 Location loc, Type llvmType) { 474 return rewriter.create<LLVM::ConstantOp>( 475 loc, llvmType, 476 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 477 llvmType.getIntOrFloatBitWidth()))); 478 } 479 480 /// Create the reduction neutral signed int maximum value. 
481 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 482 ConversionPatternRewriter &rewriter, 483 Location loc, Type llvmType) { 484 return rewriter.create<LLVM::ConstantOp>( 485 loc, llvmType, 486 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 487 llvmType.getIntOrFloatBitWidth()))); 488 } 489 490 /// Create the reduction neutral unsigned int maximum value. 491 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 492 ConversionPatternRewriter &rewriter, 493 Location loc, Type llvmType) { 494 return rewriter.create<LLVM::ConstantOp>( 495 loc, llvmType, 496 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 497 llvmType.getIntOrFloatBitWidth()))); 498 } 499 500 /// Create the reduction neutral fp minimum value. 501 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 502 ConversionPatternRewriter &rewriter, 503 Location loc, Type llvmType) { 504 auto floatType = llvmType.cast<FloatType>(); 505 return rewriter.create<LLVM::ConstantOp>( 506 loc, llvmType, 507 rewriter.getFloatAttr( 508 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 509 /*Negative=*/false))); 510 } 511 512 /// Create the reduction neutral fp maximum value. 513 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 514 ConversionPatternRewriter &rewriter, 515 Location loc, Type llvmType) { 516 auto floatType = llvmType.cast<FloatType>(); 517 return rewriter.create<LLVM::ConstantOp>( 518 loc, llvmType, 519 rewriter.getFloatAttr( 520 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 521 /*Negative=*/true))); 522 } 523 524 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 525 /// returns a new accumulator value using `ReductionNeutral`. 
526 template <class ReductionNeutral> 527 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 528 Location loc, Type llvmType, 529 Value accumulator) { 530 if (accumulator) 531 return accumulator; 532 533 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 534 llvmType); 535 } 536 537 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 538 /// This is used as effective vector length by some intrinsics supporting 539 /// dynamic vector lengths at runtime. 540 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 541 Location loc, Type llvmType) { 542 VectorType vType = cast<VectorType>(llvmType); 543 auto vShape = vType.getShape(); 544 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 545 546 return rewriter.create<LLVM::ConstantOp>( 547 loc, rewriter.getI32Type(), 548 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 549 } 550 551 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 552 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 553 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 554 /// non-null. 555 template <class LLVMRedIntrinOp, class ScalarOp> 556 static Value createIntegerReductionArithmeticOpLowering( 557 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 558 Value vectorOperand, Value accumulator) { 559 560 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 561 562 if (accumulator) 563 result = rewriter.create<ScalarOp>(loc, accumulator, result); 564 return result; 565 } 566 567 /// Helper method to lower a `vector.reduction` operation that performs 568 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 569 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 570 /// the accumulator value if non-null. 
571 template <class LLVMRedIntrinOp> 572 static Value createIntegerReductionComparisonOpLowering( 573 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 574 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 575 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 576 if (accumulator) { 577 Value cmp = 578 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 579 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 580 } 581 return result; 582 } 583 584 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum 585 /// with vector types. 586 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, 587 Value rhs, bool isMin) { 588 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); 589 Type i1Type = builder.getI1Type(); 590 if (auto vecType = lhs.getType().dyn_cast<VectorType>()) 591 i1Type = VectorType::get(vecType.getShape(), i1Type); 592 Value cmp = builder.create<LLVM::FCmpOp>( 593 loc, i1Type, isMin ? 
LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, 594 lhs, rhs); 595 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); 596 Value isNan = builder.create<LLVM::FCmpOp>( 597 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); 598 Value nan = builder.create<LLVM::ConstantOp>( 599 loc, lhs.getType(), 600 builder.getFloatAttr(floatType, 601 APFloat::getQNaN(floatType.getFloatSemantics()))); 602 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); 603 } 604 605 template <class LLVMRedIntrinOp> 606 static Value createFPReductionComparisonOpLowering( 607 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 608 Value vectorOperand, Value accumulator, bool isMin) { 609 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 610 611 if (accumulator) 612 result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin); 613 614 return result; 615 } 616 617 /// Overloaded methods to lower a reduction to an llvm instrinsic that requires 618 /// a start value. This start value format spans across fp reductions without 619 /// mask and all the masked reduction intrinsics. 
620 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 621 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 622 Location loc, Type llvmType, 623 Value vectorOperand, 624 Value accumulator) { 625 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 626 llvmType, accumulator); 627 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 628 /*startValue=*/accumulator, 629 vectorOperand); 630 } 631 632 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 633 static Value 634 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, 635 Type llvmType, Value vectorOperand, 636 Value accumulator, bool reassociateFPReds) { 637 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 638 llvmType, accumulator); 639 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 640 /*startValue=*/accumulator, 641 vectorOperand, reassociateFPReds); 642 } 643 644 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 645 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 646 Location loc, Type llvmType, 647 Value vectorOperand, 648 Value accumulator, Value mask) { 649 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 650 llvmType, accumulator); 651 Value vectorLength = 652 createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 653 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 654 /*startValue=*/accumulator, 655 vectorOperand, mask, vectorLength); 656 } 657 658 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 659 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 660 Location loc, Type llvmType, 661 Value vectorOperand, 662 Value accumulator, Value mask, 663 bool reassociateFPReds) { 664 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 665 llvmType, accumulator); 666 Value vectorLength = 667 createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 668 
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 669 /*startValue=*/accumulator, 670 vectorOperand, mask, vectorLength, 671 reassociateFPReds); 672 } 673 674 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral, 675 class LLVMFPVPRedIntrinOp, class FPReductionNeutral> 676 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 677 Location loc, Type llvmType, 678 Value vectorOperand, 679 Value accumulator, Value mask) { 680 if (llvmType.isIntOrIndex()) 681 return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp, 682 IntReductionNeutral>( 683 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 684 685 // FP dispatch. 686 return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>( 687 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 688 } 689 690 /// Conversion pattern for all vector reductions. 691 class VectorReductionOpConversion 692 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 693 public: 694 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 695 bool reassociateFPRed) 696 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 697 reassociateFPReductions(reassociateFPRed) {} 698 699 LogicalResult 700 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 701 ConversionPatternRewriter &rewriter) const override { 702 auto kind = reductionOp.getKind(); 703 Type eltType = reductionOp.getDest().getType(); 704 Type llvmType = typeConverter->convertType(eltType); 705 Value operand = adaptor.getVector(); 706 Value acc = adaptor.getAcc(); 707 Location loc = reductionOp.getLoc(); 708 709 if (eltType.isIntOrIndex()) { 710 // Integer reductions: add/mul/min/max/and/or/xor. 
711 Value result; 712 switch (kind) { 713 case vector::CombiningKind::ADD: 714 result = 715 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 716 LLVM::AddOp>( 717 rewriter, loc, llvmType, operand, acc); 718 break; 719 case vector::CombiningKind::MUL: 720 result = 721 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 722 LLVM::MulOp>( 723 rewriter, loc, llvmType, operand, acc); 724 break; 725 case vector::CombiningKind::MINUI: 726 result = createIntegerReductionComparisonOpLowering< 727 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 728 LLVM::ICmpPredicate::ule); 729 break; 730 case vector::CombiningKind::MINSI: 731 result = createIntegerReductionComparisonOpLowering< 732 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 733 LLVM::ICmpPredicate::sle); 734 break; 735 case vector::CombiningKind::MAXUI: 736 result = createIntegerReductionComparisonOpLowering< 737 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 738 LLVM::ICmpPredicate::uge); 739 break; 740 case vector::CombiningKind::MAXSI: 741 result = createIntegerReductionComparisonOpLowering< 742 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 743 LLVM::ICmpPredicate::sge); 744 break; 745 case vector::CombiningKind::AND: 746 result = 747 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 748 LLVM::AndOp>( 749 rewriter, loc, llvmType, operand, acc); 750 break; 751 case vector::CombiningKind::OR: 752 result = 753 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 754 LLVM::OrOp>( 755 rewriter, loc, llvmType, operand, acc); 756 break; 757 case vector::CombiningKind::XOR: 758 result = 759 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 760 LLVM::XOrOp>( 761 rewriter, loc, llvmType, operand, acc); 762 break; 763 default: 764 return failure(); 765 } 766 rewriter.replaceOp(reductionOp, result); 767 768 return success(); 769 } 770 771 if (!eltType.isa<FloatType>()) 772 
return failure(); 773 774 // Floating-point reductions: add/mul/min/max 775 Value result; 776 if (kind == vector::CombiningKind::ADD) { 777 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd, 778 ReductionNeutralZero>( 779 rewriter, loc, llvmType, operand, acc, reassociateFPReductions); 780 } else if (kind == vector::CombiningKind::MUL) { 781 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul, 782 ReductionNeutralFPOne>( 783 rewriter, loc, llvmType, operand, acc, reassociateFPReductions); 784 } else if (kind == vector::CombiningKind::MINF) { 785 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 786 // NaNs/-0.0/+0.0 in the same way. 787 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>( 788 rewriter, loc, llvmType, operand, acc, 789 /*isMin=*/true); 790 } else if (kind == vector::CombiningKind::MAXF) { 791 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle 792 // NaNs/-0.0/+0.0 in the same way. 793 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>( 794 rewriter, loc, llvmType, operand, acc, 795 /*isMin=*/false); 796 } else 797 return failure(); 798 799 rewriter.replaceOp(reductionOp, result); 800 return success(); 801 } 802 803 private: 804 const bool reassociateFPReductions; 805 }; 806 807 /// Base class to convert a `vector.mask` operation while matching traits 808 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase` 809 /// instance matches against a `vector.mask` operation. The `matchAndRewrite` 810 /// method performs a second match against the maskable operation `MaskedOp`. 811 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be 812 /// implemented by the concrete conversion classes. This method can match 813 /// against specific traits of the `vector.mask` and the maskable operation. It 814 /// must replace the `vector.mask` operation. 
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Match against the maskable operation kind. Bail out if the op nested
    // under this `vector.mask` is not the `MaskedOp` this pattern handles.
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return failure();
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  }

protected:
  /// Hook for concrete subclasses: must replace `maskOp` (the `vector.mask`
  /// itself, not just the nested op) on success.
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

/// Lowers a `vector.reduction` nested under a `vector.mask` to the
/// corresponding LLVM vector-predicated (VP) reduction intrinsic, threading
/// the mask of the `vector.mask` op through as the predicate operand.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Dispatch on the combining kind. Each case picks the matching VP
    // reduction intrinsic and the neutral element used when no accumulator
    // is provided.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
                                            ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
                                            ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                            ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                            ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
                                            ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

/// Converts `vector.shuffle` to LLVM. Rank-0/1 shuffles with identical
/// operand types map directly onto `llvm.shufflevector`; all other shapes
/// are fully unrolled into per-element extract/insert sequences.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Mask entries >= v1Dim select from the second operand (LLVM-style
    // concatenated indexing).
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Converts `vector.extractelement` to `llvm.extractelement`. The 0-D case
/// uses a constant index of zero since LLVM has no 0-D vectors.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Converts `vector.extract` to LLVM. Vector results come from a single
/// `llvm.extractvalue` on the n-D array representation; scalar results need
/// an array extraction down to the innermost 1-D vector followed by an
/// `llvm.extractelement`.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array. All leading position
    // indices address the array nesting; only the last one indexes into the
    // 1-D vector itself.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
/// vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
/// llvm.intr.fmuladd %va, %va, %va:
///   (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///   -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Only rank <= 1 is handled here; higher ranks are first unrolled by
    // VectorFMAOpNDRewritePattern.
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Converts `vector.insertelement` to `llvm.insertelement`. The 0-D case
/// uses a constant index of zero since LLVM has no 0-D vectors.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Converts `vector.insert` to LLVM. Vector-valued sources use a single
/// `llvm.insertvalue`; scalar sources require extracting the innermost 1-D
/// vector, an `llvm.insertelement` into it, and re-inserting it into the
/// enclosing array.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array. The leading position
    // indices address the array nesting; the last one indexes the element
    // within the 1-D vector.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // Rank <= 1 is handled directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    // Peel off the leading dimension: emit one (n-1)-D FMA per slice and
    // reassemble the result via insert.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  // The innermost stride must be 1 for the buffer to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Contiguity invariant: stride[i] == stride[i+1] * size[i+1].
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Converts `vector.type_cast` by rebuilding the target memref descriptor
/// from the source descriptor: pointers are carried over (bitcast when not
/// using opaque pointers), offset is reset to 0, and sizes/strides are
/// re-emitted as constants from the static target type.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    // Only 1-D scalable vectors are handled here (see class comment).
    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    // Build mask as: step_vector < splat(bound), element-wise.
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  // When true, emit i32 mask indices instead of i64.
  const bool force32BitVectorIndices;
};

/// Converts `vector.print` into calls to a small runtime print-support
/// library, fully unrolling the vector into scalar print calls.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support. Pick the runtime print
    // function matching the element type, plus any conversion needed to
    // adapt the value to that function's argument type.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    auto parent = printOp->getParentOfType<ModuleOp>();
    Operation *printer;
    if (eltType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (eltType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (eltType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (eltType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (eltType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(parent));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  /// How a scalar must be adapted before being passed to the print function.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  /// Recursively prints `value` of type `type`: scalars are converted (per
  /// `conversion`) and passed to `printer`; vectors are printed element-wise
  /// between bracket/comma calls, recursing one rank at a time.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::Bitcast16:
        value = rewriter.create<LLVM::BitcastOp>(
            loc, IntegerType::get(rewriter.getContext(), 16), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    auto parent = op->getParentOfType<ModuleOp>();
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
    Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);

    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
      return;
    }

    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType().cast<VectorType>();
    // Rank >= 2 is handled by VectorSplatNdOpLowering.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
1681 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1682 SmallVector<int32_t> zeroValues(width, 0); 1683 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1684 1685 // Iterate of linear index, convert to coords space and insert splatted 1-D 1686 // vector in each position. 1687 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1688 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1689 }); 1690 rewriter.replaceOp(splatOp, desc); 1691 return success(); 1692 } 1693 }; 1694 1695 } // namespace 1696 1697 /// Populate the given list with patterns that convert from Vector to LLVM. 1698 void mlir::populateVectorToLLVMConversionPatterns( 1699 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1700 bool reassociateFPReductions, bool force32BitVectorIndices) { 1701 MLIRContext *ctx = converter.getDialect()->getContext(); 1702 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1703 populateVectorInsertExtractStridedSliceTransforms(patterns); 1704 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1705 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1706 patterns 1707 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1708 VectorExtractElementOpConversion, VectorExtractOpConversion, 1709 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1710 VectorInsertOpConversion, VectorPrintOpConversion, 1711 VectorTypeCastOpConversion, VectorScaleOpConversion, 1712 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1713 VectorLoadStoreConversion<vector::MaskedLoadOp, 1714 vector::MaskedLoadOpAdaptor>, 1715 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1716 VectorLoadStoreConversion<vector::MaskedStoreOp, 1717 vector::MaskedStoreOpAdaptor>, 1718 VectorGatherOpConversion, VectorScatterOpConversion, 1719 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1720 VectorSplatOpLowering, 
VectorSplatNdOpLowering, 1721 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1722 MaskedReductionOpConversion>(converter); 1723 // Transfer ops with rank > 1 are handled by VectorToSCF. 1724 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1725 } 1726 1727 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1728 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1729 patterns.add<VectorMatmulOpConversion>(converter); 1730 patterns.add<VectorFlatTransposeOpConversion>(converter); 1731 } 1732