//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by *all* but one rank at back.
// E.g. a rank-3 vector type keeps only its innermost dimension (element type
// and the trailing scalable-dim flag are preserved). Requires rank > 1.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         tp.getScalableDims().take_back());
}

// Helper that picks the proper sequence for inserting.
39 static Value insertOne(ConversionPatternRewriter &rewriter, 40 LLVMTypeConverter &typeConverter, Location loc, 41 Value val1, Value val2, Type llvmType, int64_t rank, 42 int64_t pos) { 43 assert(rank > 0 && "0-D vector corner case should have been handled already"); 44 if (rank == 1) { 45 auto idxType = rewriter.getIndexType(); 46 auto constant = rewriter.create<LLVM::ConstantOp>( 47 loc, typeConverter.convertType(idxType), 48 rewriter.getIntegerAttr(idxType, pos)); 49 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 50 constant); 51 } 52 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 53 } 54 55 // Helper that picks the proper sequence for extracting. 56 static Value extractOne(ConversionPatternRewriter &rewriter, 57 LLVMTypeConverter &typeConverter, Location loc, 58 Value val, Type llvmType, int64_t rank, int64_t pos) { 59 if (rank <= 1) { 60 auto idxType = rewriter.getIndexType(); 61 auto constant = rewriter.create<LLVM::ConstantOp>( 62 loc, typeConverter.convertType(idxType), 63 rewriter.getIntegerAttr(idxType, pos)); 64 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 65 constant); 66 } 67 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 68 } 69 70 // Helper that returns data layout alignment of a memref. 71 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 72 MemRefType memrefType, unsigned &align) { 73 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 74 if (!elementTy) 75 return failure(); 76 77 // TODO: this should use the MLIR data layout when it becomes available and 78 // stop depending on translation. 79 llvm::LLVMContext llvmContext; 80 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 81 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 82 return success(); 83 } 84 85 // Check if the last stride is non-unit or the memory space is not zero. 
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           LLVMTypeConverter &converter) {
  // Innermost dimension must have unit stride (contiguous elements).
  if (!isLastMemrefDimUnitStride(memRefType))
    return failure();
  // Only the default (0) address space is supported.
  FailureOr<unsigned> addressSpace =
      converter.getMemRefAddressSpace(memRefType);
  if (failed(addressSpace) || *addressSpace != 0)
    return failure();
  return success();
}

// Add an index vector component to a base pointer.
// Produces a vector of `vLen` pointers via a single GEP whose index operand is
// the (vector) `index`; the element type drives the GEP scaling.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, uint64_t vLen) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
  return rewriter.create<LLVM::GEPOp>(
      loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
      base, index);
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
// With opaque pointers enabled there is nothing to cast, so `ptr` is returned
// unchanged.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt,
                         LLVMTypeConverter &converter) {
  if (converter.useOpaquePointers())
    return ptr;

  unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
  auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The row/column shape comes from the op's attributes; the operands are
    // the already-converted flat LLVM vectors.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: compute the strided element pointer, then cast it to a
    // vector pointer (no-op under opaque pointers).
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    // Dispatch to the matching replaceLoadOrStoreOp overload.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address: one pointer per gathered element.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // N-D case: unroll into 1-D gathers. The callback is invoked once per
    // innermost 1-D slice with the corresponding (index, mask, passThru)
    // operand slices.
    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: one pointer per scattered element.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    // Lower directly to the llvm.masked.expandload intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    // Lower directly to the llvm.masked.compressstore intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading.
400 class ReductionNeutralZero {}; 401 class ReductionNeutralIntOne {}; 402 class ReductionNeutralFPOne {}; 403 class ReductionNeutralAllOnes {}; 404 class ReductionNeutralSIntMin {}; 405 class ReductionNeutralUIntMin {}; 406 class ReductionNeutralSIntMax {}; 407 class ReductionNeutralUIntMax {}; 408 class ReductionNeutralFPMin {}; 409 class ReductionNeutralFPMax {}; 410 411 /// Create the reduction neutral zero value. 412 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 413 ConversionPatternRewriter &rewriter, 414 Location loc, Type llvmType) { 415 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 416 rewriter.getZeroAttr(llvmType)); 417 } 418 419 /// Create the reduction neutral integer one value. 420 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 421 ConversionPatternRewriter &rewriter, 422 Location loc, Type llvmType) { 423 return rewriter.create<LLVM::ConstantOp>( 424 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 425 } 426 427 /// Create the reduction neutral fp one value. 428 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 429 ConversionPatternRewriter &rewriter, 430 Location loc, Type llvmType) { 431 return rewriter.create<LLVM::ConstantOp>( 432 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 433 } 434 435 /// Create the reduction neutral all-ones value. 436 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 437 ConversionPatternRewriter &rewriter, 438 Location loc, Type llvmType) { 439 return rewriter.create<LLVM::ConstantOp>( 440 loc, llvmType, 441 rewriter.getIntegerAttr( 442 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 443 } 444 445 /// Create the reduction neutral signed int minimum value. 
446 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 447 ConversionPatternRewriter &rewriter, 448 Location loc, Type llvmType) { 449 return rewriter.create<LLVM::ConstantOp>( 450 loc, llvmType, 451 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 452 llvmType.getIntOrFloatBitWidth()))); 453 } 454 455 /// Create the reduction neutral unsigned int minimum value. 456 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 457 ConversionPatternRewriter &rewriter, 458 Location loc, Type llvmType) { 459 return rewriter.create<LLVM::ConstantOp>( 460 loc, llvmType, 461 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 462 llvmType.getIntOrFloatBitWidth()))); 463 } 464 465 /// Create the reduction neutral signed int maximum value. 466 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 467 ConversionPatternRewriter &rewriter, 468 Location loc, Type llvmType) { 469 return rewriter.create<LLVM::ConstantOp>( 470 loc, llvmType, 471 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 472 llvmType.getIntOrFloatBitWidth()))); 473 } 474 475 /// Create the reduction neutral unsigned int maximum value. 476 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 477 ConversionPatternRewriter &rewriter, 478 Location loc, Type llvmType) { 479 return rewriter.create<LLVM::ConstantOp>( 480 loc, llvmType, 481 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 482 llvmType.getIntOrFloatBitWidth()))); 483 } 484 485 /// Create the reduction neutral fp minimum value. 
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  // Positive quiet NaN is used as the start value.
  auto floatType = cast<FloatType>(llvmType);
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/false)));
}

/// Create the reduction neutral fp maximum value.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  // Negative quiet NaN is used as the start value.
  auto floatType = cast<FloatType>(llvmType);
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}

/// Returns `accumulator` if it has a valid value. Otherwise, creates and
/// returns a new accumulator value using `ReductionNeutral`.
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Value accumulator) {
  if (accumulator)
    return accumulator;

  // No accumulator was provided: materialize the neutral element selected by
  // the ReductionNeutral tag type.
  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
                                     llvmType);
}

/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  // The effective vector length is an i32 constant equal to the vector size.
  return rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
/// and `ScalarOp` is the scalar operation used to add the accumulation value if
/// non-null.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  // Fold the optional accumulator into the reduced scalar.
  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
/// intrinsic to use and `predicate` is the predicate to use to compare+combine
/// the accumulator value if non-null.
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    // Combine with the accumulator via compare + select (icmp/select pair).
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}

namespace {
// Maps an fmaximum/fminimum vector reduction intrinsic to the scalar LLVM op
// used to fold in the accumulator.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
} // namespace

template <class LLVMRedIntrinOp>
static Value
createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
                                      Location loc, Type llvmType,
                                      Value vectorOperand, Value accumulator) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  if (accumulator) {
    // Fold the accumulator in with the matching scalar maximum/minimum op.
    result =
        rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
            loc, result, accumulator);
  }

  return result;
}

/// Overloaded methods to lower a reduction to an llvm instrinsic that requires
/// a start value. This start value format spans across fp reductions without
/// mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

// Overload for intrinsics that take a reassociation flag (fp add/mul).
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, reassociateFPReds);
}

// Overload for masked (VP) intrinsics: also passes the effective vector
// length derived from the operand's 1-D vector type.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

// Overload for masked (VP) intrinsics that also take a reassociation flag.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask,
                                          bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());

  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength,
                                            reassociateFPReds);
}

// Dispatches between the integer and fp masked intrinsics based on the
// element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                        IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds (fp-only) are not valid on integer element types.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MINF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc);
    } else if (kind == vector::CombiningKind::MAXF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  // Whether fp add/mul reductions may be reassociated (pattern option).
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes. This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Match against the maskable operation kind.
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return failure();
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  }

protected:
  /// Hook for concrete conversions. Implementations must replace `maskOp`
  /// itself (not just the nested maskable op) to complete the rewrite.
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

/// Lowers a masked `vector.reduction` to the corresponding LLVM
/// vector-predication (VP) reduction intrinsic, passing the mask through as
/// the predicate operand. `lowerReductionWithStartValue` materializes the
/// listed neutral element as the start value when no accumulator is given.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      // ADD/MUL dispatch on the element type: the first <op, neutral> pair
      // handles integers, the second pair handles floats.
      result = lowerReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
                                            ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
                                            ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                            ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                            ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
                                            ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vp.reduce.fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vp.reduce.fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

/// Lowers `vector.shuffle`. Rank-0/1 shuffles with identical operand types
/// map directly onto LLVM's shufflevector; everything else is decomposed into
/// per-position extract/insert pairs.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
      eltType = arrayType.getElementType();
    else
      eltType = cast<VectorType>(llvmType).getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      // Mask entries >= v1Dim index into the second operand (shifted down).
      int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Lowers `vector.extractelement` to LLVM's extractelement, mapping the 0-D
/// case (lowered to a 1-element LLVM vector) to extraction at index 0.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D source: extract the single element at constant index 0.
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.extract`. Depending on position/result shape this becomes:
/// a no-op (empty position), a single extractvalue (vector result from the
/// n-D array representation), or extractvalue + extractelement (scalar from
/// an inner 1-D vector).
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    ArrayRef<int64_t> positionArray = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArray.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (isa<VectorType>(resultType)) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), positionArray);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    if (positionArray.size() > 1) {
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, positionArray.drop_back());
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant =
        rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Higher ranks are peeled off by VectorFMAOpNDRewritePattern first.
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Lowers `vector.insertelement` to LLVM's insertelement, mapping the 0-D
/// case (lowered to a 1-element LLVM vector) to insertion at index 0.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D destination: insert the single element at constant index 0.
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.insert`, mirroring VectorExtractOpConversion: a no-op
/// (empty position), a single insertvalue (vector source), or
/// extractvalue + insertelement + insertvalue (scalar into an inner 1-D
/// vector of the n-D array representation).
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    ArrayRef<int64_t> positionArray = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArray.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (isa<VectorType>(sourceType)) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(), positionArray);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto oneDVectorType = destVectorType;
    if (positionArray.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, positionArray.drop_back());
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant =
        rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionArray.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted, positionArray.drop_back());
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // Rank 0/1 is handled directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    // Seed the result with a zero splat; every leading-dim slice is
    // overwritten by the loop below.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  // Innermost dimension must be unit-strided to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Contiguity invariant: stride[i] == stride[i+1] * size[i+1].
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Lowers `vector.type_cast` by building a target memref descriptor that
/// reuses the source's allocated/aligned pointers and fills in the target's
/// static sizes/strides with offset 0.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    // Bitcasts are only needed for typed (non-opaque) LLVM pointers.
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    // Only 1-D scalable masks are handled here; the rest is lowered in
    // Vector Transforms (see class comment).
    if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    // Mask = step_vector(0..n-1) <s splat(bound), computed elementwise.
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  // When true, use i32 instead of i64 for the step/bound vectors.
  const bool force32BitVectorIndices;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Lowering implementation that relies on a small runtime support library,
  // which only needs to provide a few printing methods (single value for all
  // data types, opening/closing bracket, comma, newline). The lowering splits
  // the vector into elementary printing operations. The advantage of this
  // approach is that the library can remain unaware of all low-level
  // implementation details of vectors while still supporting output of any
  // shaped and dimensioned vector.
  //
  // Note: This lowering only handles scalars, n-D vectors are broken into
  // printing scalars in loops in VectorToSCF.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto parent = printOp->getParentOfType<ModuleOp>();
    auto loc = printOp->getLoc();

    if (auto value = adaptor.getSource()) {
      Type printType = printOp.getPrintType();
      if (isa<VectorType>(printType)) {
        // Vectors should be broken into elementary print ops in VectorToSCF.
        return failure();
      }
      if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
        return failure();
    }

    // Emit the punctuation call (if any) after the value print.
    auto punct = printOp.getPunctuation();
    if (punct != PrintPunctuation::NoPunctuation) {
      emitCall(rewriter, printOp->getLoc(), [&] {
        switch (punct) {
        case PrintPunctuation::Close:
          return LLVM::lookupOrCreatePrintCloseFn(parent);
        case PrintPunctuation::Open:
          return LLVM::lookupOrCreatePrintOpenFn(parent);
        case PrintPunctuation::Comma:
          return LLVM::lookupOrCreatePrintCommaFn(parent);
        case PrintPunctuation::NewLine:
          return LLVM::lookupOrCreatePrintNewlineFn(parent);
        default:
          llvm_unreachable("unexpected punctuation");
        }
      }());
    }

    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // How the operand must be adapted before calling the runtime print function.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  /// Emits a call to the runtime print function matching `printType`, adapting
  /// `value` (extension/bitcast) as needed. Fails for unsupported types
  /// (vectors, integers wider than 64 bits, unconvertible types).
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    Operation *printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (printType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    // Higher ranks are handled by VectorSplatNdOpLowering.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index, convert to coords space and insert the
    // splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
1646 void mlir::populateVectorToLLVMConversionPatterns( 1647 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1648 bool reassociateFPReductions, bool force32BitVectorIndices) { 1649 MLIRContext *ctx = converter.getDialect()->getContext(); 1650 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1651 populateVectorInsertExtractStridedSliceTransforms(patterns); 1652 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1653 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1654 patterns 1655 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1656 VectorExtractElementOpConversion, VectorExtractOpConversion, 1657 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1658 VectorInsertOpConversion, VectorPrintOpConversion, 1659 VectorTypeCastOpConversion, VectorScaleOpConversion, 1660 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1661 VectorLoadStoreConversion<vector::MaskedLoadOp, 1662 vector::MaskedLoadOpAdaptor>, 1663 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1664 VectorLoadStoreConversion<vector::MaskedStoreOp, 1665 vector::MaskedStoreOpAdaptor>, 1666 VectorGatherOpConversion, VectorScatterOpConversion, 1667 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1668 VectorSplatOpLowering, VectorSplatNdOpLowering, 1669 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1670 MaskedReductionOpConversion>(converter); 1671 // Transfer ops with rank > 1 are handled by VectorToSCF. 1672 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1673 } 1674 1675 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1676 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1677 patterns.add<VectorMatmulOpConversion>(converter); 1678 patterns.add<VectorFlatTransposeOpConversion>(converter); 1679 } 1680