1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/Support/Casting.h" 26 #include <optional> 27 28 using namespace mlir; 29 using namespace mlir::vector; 30 31 // Helper to reduce vector type by one rank at front. 32 static VectorType reducedVectorTypeFront(VectorType tp) { 33 assert((tp.getRank() > 1) && "unlowerable vector type"); 34 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 35 tp.getScalableDims().drop_front()); 36 } 37 38 // Helper to reduce vector type by *all* but one rank at back. 39 static VectorType reducedVectorTypeBack(VectorType tp) { 40 assert((tp.getRank() > 1) && "unlowerable vector type"); 41 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 42 tp.getScalableDims().take_back()); 43 } 44 45 // Helper that picks the proper sequence for inserting. 46 static Value insertOne(ConversionPatternRewriter &rewriter, 47 LLVMTypeConverter &typeConverter, Location loc, 48 Value val1, Value val2, Type llvmType, int64_t rank, 49 int64_t pos) { 50 assert(rank > 0 && "0-D vector corner case should have been handled already"); 51 if (rank == 1) { 52 auto idxType = rewriter.getIndexType(); 53 auto constant = rewriter.create<LLVM::ConstantOp>( 54 loc, typeConverter.convertType(idxType), 55 rewriter.getIntegerAttr(idxType, pos)); 56 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 57 constant); 58 } 59 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 60 } 61 62 // Helper that picks the proper sequence for extracting. 63 static Value extractOne(ConversionPatternRewriter &rewriter, 64 LLVMTypeConverter &typeConverter, Location loc, 65 Value val, Type llvmType, int64_t rank, int64_t pos) { 66 if (rank <= 1) { 67 auto idxType = rewriter.getIndexType(); 68 auto constant = rewriter.create<LLVM::ConstantOp>( 69 loc, typeConverter.convertType(idxType), 70 rewriter.getIntegerAttr(idxType, pos)); 71 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 72 constant); 73 } 74 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 75 } 76 77 // Helper that returns data layout alignment of a memref. 
78 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 79 MemRefType memrefType, unsigned &align) { 80 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 81 if (!elementTy) 82 return failure(); 83 84 // TODO: this should use the MLIR data layout when it becomes available and 85 // stop depending on translation. 86 llvm::LLVMContext llvmContext; 87 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 88 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 89 return success(); 90 } 91 92 // Check if the last stride is non-unit or the memory space is not zero. 93 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 94 LLVMTypeConverter &converter) { 95 if (!isLastMemrefDimUnitStride(memRefType)) 96 return failure(); 97 FailureOr<unsigned> addressSpace = 98 converter.getMemRefAddressSpace(memRefType); 99 if (failed(addressSpace) || *addressSpace != 0) 100 return failure(); 101 return success(); 102 } 103 104 // Add an index vector component to a base pointer. 105 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 106 LLVMTypeConverter &typeConverter, 107 MemRefType memRefType, Value llvmMemref, Value base, 108 Value index, uint64_t vLen) { 109 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 110 "unsupported memref type"); 111 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 112 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 113 return rewriter.create<LLVM::GEPOp>( 114 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 115 base, index); 116 } 117 118 // Casts a strided element pointer to a vector pointer. The vector pointer 119 // will be in the same address space as the incoming memref type. 120 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 121 Value ptr, MemRefType memRefType, Type vt, 122 LLVMTypeConverter &converter) { 123 if (converter.useOpaquePointers()) 124 return ptr; 125 126 unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); 127 auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); 128 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 129 } 130 131 namespace { 132 133 /// Trivial Vector to LLVM conversions 134 using VectorScaleOpConversion = 135 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 136 137 /// Conversion pattern for a vector.bitcast. 138 class VectorBitCastOpConversion 139 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 140 public: 141 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 142 143 LogicalResult 144 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 145 ConversionPatternRewriter &rewriter) const override { 146 // Only 0-D and 1-D vectors can be lowered to LLVM. 147 VectorType resultTy = bitCastOp.getResultVectorType(); 148 if (resultTy.getRank() > 1) 149 return failure(); 150 Type newResultTy = typeConverter->convertType(resultTy); 151 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 152 adaptor.getOperands()[0]); 153 return success(); 154 } 155 }; 156 157 /// Conversion pattern for a vector.matrix_multiply. 158 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
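///
/// For illustration, a sketch of the rewrite assuming a 2x3 * 3x4 multiply
/// (operand names and attribute values are made up for the example):
/// ```
///  %c = vector.matrix_multiply %a, %b
///         {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///         : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// becomes, modulo type conversion, the one-to-one replacement:
/// ```
///  %c = llvm.intr.matrix.multiply %a, %b
///         {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///         : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```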
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
238 VectorType vectorTy = loadOrStoreOp.getVectorType(); 239 if (vectorTy.getRank() > 1) 240 return failure(); 241 242 auto loc = loadOrStoreOp->getLoc(); 243 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 244 245 // Resolve alignment. 246 unsigned align; 247 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 248 return failure(); 249 250 // Resolve address. 251 auto vtype = cast<VectorType>( 252 this->typeConverter->convertType(loadOrStoreOp.getVectorType())); 253 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 254 adaptor.getIndices(), rewriter); 255 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, 256 *this->getTypeConverter()); 257 258 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 259 return success(); 260 } 261 }; 262 263 /// Conversion pattern for a vector.gather. 264 class VectorGatherOpConversion 265 : public ConvertOpToLLVMPattern<vector::GatherOp> { 266 public: 267 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 268 269 LogicalResult 270 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 271 ConversionPatternRewriter &rewriter) const override { 272 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType()); 273 assert(memRefType && "The base should be bufferized"); 274 275 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 276 return failure(); 277 278 auto loc = gather->getLoc(); 279 280 // Resolve alignment. 281 unsigned align; 282 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 283 return failure(); 284 285 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 286 adaptor.getIndices(), rewriter); 287 Value base = adaptor.getBase(); 288 289 auto llvmNDVectorTy = adaptor.getIndexVec().getType(); 290 // Handle the simple case of 1-D vector. 291 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) { 292 auto vType = gather.getVectorType(); 293 // Resolve address. 294 Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), 295 memRefType, base, ptr, adaptor.getIndexVec(), 296 /*vLen=*/vType.getDimSize(0)); 297 // Replace with the gather intrinsic. 298 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 299 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 300 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 301 return success(); 302 } 303 304 LLVMTypeConverter &typeConverter = *this->getTypeConverter(); 305 auto callback = [align, memRefType, base, ptr, loc, &rewriter, 306 &typeConverter](Type llvm1DVectorTy, 307 ValueRange vectorOperands) { 308 // Resolve address. 309 Value ptrs = getIndexedPtrs( 310 rewriter, loc, typeConverter, memRefType, base, ptr, 311 /*index=*/vectorOperands[0], 312 LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); 313 // Create the gather intrinsic. 314 return rewriter.create<LLVM::masked_gather>( 315 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1], 316 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align)); 317 }; 318 SmallVector<Value> vectorOperands = { 319 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()}; 320 return LLVM::detail::handleMultidimensionalVectors( 321 gather, vectorOperands, *getTypeConverter(), callback, rewriter); 322 } 323 }; 324 325 /// Conversion pattern for a vector.scatter. 
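///
/// A rough sketch of the lowering, assuming a supported 1-D case (SSA names
/// and types are illustrative only): the per-element addresses are computed
/// with a vector GEP off the strided base pointer, and the op is replaced by
/// the masked scatter intrinsic:
/// ```
///  vector.scatter %base[%i][%idxs], %mask, %val
///      : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
/// ```
/// becomes, roughly:
/// ```
///  %ptr  = llvm.getelementptr %aligned[%i] : ... -> !llvm.ptr
///  %ptrs = llvm.getelementptr %ptr[%idxs] : ... -> vector<4x!llvm.ptr>
///  llvm.intr.masked.scatter %val, %ptrs, %mask {alignment = 4 : i32}
///      : vector<4xf32>, vector<4xi1> into vector<4x!llvm.ptr>
/// ```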
326 class VectorScatterOpConversion 327 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 328 public: 329 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 330 331 LogicalResult 332 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 333 ConversionPatternRewriter &rewriter) const override { 334 auto loc = scatter->getLoc(); 335 MemRefType memRefType = scatter.getMemRefType(); 336 337 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 338 return failure(); 339 340 // Resolve alignment. 341 unsigned align; 342 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 343 return failure(); 344 345 // Resolve address. 346 VectorType vType = scatter.getVectorType(); 347 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 348 adaptor.getIndices(), rewriter); 349 Value ptrs = getIndexedPtrs( 350 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(), 351 ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); 352 353 // Replace with the scatter intrinsic. 354 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 355 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 356 rewriter.getI32IntegerAttr(align)); 357 return success(); 358 } 359 }; 360 361 /// Conversion pattern for a vector.expandload. 362 class VectorExpandLoadOpConversion 363 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 364 public: 365 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 366 367 LogicalResult 368 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 369 ConversionPatternRewriter &rewriter) const override { 370 auto loc = expand->getLoc(); 371 MemRefType memRefType = expand.getMemRefType(); 372 373 // Resolve address. 374 auto vtype = typeConverter->convertType(expand.getVectorType()); 375 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 376 adaptor.getIndices(), rewriter); 377 378 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 379 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 380 return success(); 381 } 382 }; 383 384 /// Conversion pattern for a vector.compressstore. 385 class VectorCompressStoreOpConversion 386 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 387 public: 388 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 389 390 LogicalResult 391 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 392 ConversionPatternRewriter &rewriter) const override { 393 auto loc = compress->getLoc(); 394 MemRefType memRefType = compress.getMemRefType(); 395 396 // Resolve address. 397 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 398 adaptor.getIndices(), rewriter); 399 400 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 401 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 402 return success(); 403 } 404 }; 405 406 /// Reduction neutral classes for overloading. 407 class ReductionNeutralZero {}; 408 class ReductionNeutralIntOne {}; 409 class ReductionNeutralFPOne {}; 410 class ReductionNeutralAllOnes {}; 411 class ReductionNeutralSIntMin {}; 412 class ReductionNeutralUIntMin {}; 413 class ReductionNeutralSIntMax {}; 414 class ReductionNeutralUIntMax {}; 415 class ReductionNeutralFPMin {}; 416 class ReductionNeutralFPMax {}; 417 418 /// Create the reduction neutral zero value. 
419 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 420 ConversionPatternRewriter &rewriter, 421 Location loc, Type llvmType) { 422 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 423 rewriter.getZeroAttr(llvmType)); 424 } 425 426 /// Create the reduction neutral integer one value. 427 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 428 ConversionPatternRewriter &rewriter, 429 Location loc, Type llvmType) { 430 return rewriter.create<LLVM::ConstantOp>( 431 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 432 } 433 434 /// Create the reduction neutral fp one value. 435 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 436 ConversionPatternRewriter &rewriter, 437 Location loc, Type llvmType) { 438 return rewriter.create<LLVM::ConstantOp>( 439 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 440 } 441 442 /// Create the reduction neutral all-ones value. 443 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 444 ConversionPatternRewriter &rewriter, 445 Location loc, Type llvmType) { 446 return rewriter.create<LLVM::ConstantOp>( 447 loc, llvmType, 448 rewriter.getIntegerAttr( 449 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 450 } 451 452 /// Create the reduction neutral signed int minimum value. 453 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 454 ConversionPatternRewriter &rewriter, 455 Location loc, Type llvmType) { 456 return rewriter.create<LLVM::ConstantOp>( 457 loc, llvmType, 458 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 459 llvmType.getIntOrFloatBitWidth()))); 460 } 461 462 /// Create the reduction neutral unsigned int minimum value. 463 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 464 ConversionPatternRewriter &rewriter, 465 Location loc, Type llvmType) { 466 return rewriter.create<LLVM::ConstantOp>( 467 loc, llvmType, 468 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 469 llvmType.getIntOrFloatBitWidth()))); 470 } 471 472 /// Create the reduction neutral signed int maximum value. 473 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 474 ConversionPatternRewriter &rewriter, 475 Location loc, Type llvmType) { 476 return rewriter.create<LLVM::ConstantOp>( 477 loc, llvmType, 478 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 479 llvmType.getIntOrFloatBitWidth()))); 480 } 481 482 /// Create the reduction neutral unsigned int maximum value. 483 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 484 ConversionPatternRewriter &rewriter, 485 Location loc, Type llvmType) { 486 return rewriter.create<LLVM::ConstantOp>( 487 loc, llvmType, 488 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 489 llvmType.getIntOrFloatBitWidth()))); 490 } 491 492 /// Create the reduction neutral fp minimum value. 493 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 494 ConversionPatternRewriter &rewriter, 495 Location loc, Type llvmType) { 496 auto floatType = cast<FloatType>(llvmType); 497 return rewriter.create<LLVM::ConstantOp>( 498 loc, llvmType, 499 rewriter.getFloatAttr( 500 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 501 /*Negative=*/false))); 502 } 503 504 /// Create the reduction neutral fp maximum value. 
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = cast<FloatType>(llvmType);
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}

/// Returns `accumulator` if it has a valid value. Otherwise, creates and
/// returns a new accumulator value using `ReductionNeutral`.
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Value accumulator) {
  if (accumulator)
    return accumulator;

  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
                                     llvmType);
}

/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
/// This is used as the effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  return rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector
/// reduction intrinsic to use and `ScalarOp` is the scalar operation used to
/// combine the result with the accumulator value, if present.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector reduction intrinsic to use and `predicate` is the predicate used to
/// compare and combine the result with the accumulator value, if present.
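///
/// For example, a signed-min reduction with an accumulator is expected to
/// expand, schematically (names are illustrative), into:
/// ```
///  %r0 = "llvm.intr.vector.reduce.smin"(%vec) : (vector<4xi32>) -> i32
///  %c  = llvm.icmp "sle" %acc, %r0 : i32
///  %r  = llvm.select %c, %acc, %r0 : i1, i32
/// ```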
563 template <class LLVMRedIntrinOp> 564 static Value createIntegerReductionComparisonOpLowering( 565 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 566 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 567 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 568 if (accumulator) { 569 Value cmp = 570 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 571 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 572 } 573 return result; 574 } 575 576 namespace { 577 template <typename Source> 578 struct VectorToScalarMapper; 579 template <> 580 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { 581 using Type = LLVM::MaximumOp; 582 }; 583 template <> 584 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { 585 using Type = LLVM::MinimumOp; 586 }; 587 } // namespace 588 589 template <class LLVMRedIntrinOp> 590 static Value 591 createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter, 592 Location loc, Type llvmType, 593 Value vectorOperand, Value accumulator) { 594 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 595 596 if (accumulator) { 597 result = 598 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( 599 loc, result, accumulator); 600 } 601 602 return result; 603 } 604 605 /// Overloaded methods to lower a reduction to an llvm instrinsic that requires 606 /// a start value. This start value format spans across fp reductions without 607 /// mask and all the masked reduction intrinsics. 608 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 609 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 610 Location loc, Type llvmType, 611 Value vectorOperand, 612 Value accumulator) { 613 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 614 llvmType, accumulator); 615 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 616 /*startValue=*/accumulator, 617 vectorOperand); 618 } 619 620 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 621 static Value 622 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, 623 Type llvmType, Value vectorOperand, 624 Value accumulator, bool reassociateFPReds) { 625 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 626 llvmType, accumulator); 627 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 628 /*startValue=*/accumulator, 629 vectorOperand, reassociateFPReds); 630 } 631 632 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 633 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 634 Location loc, Type llvmType, 635 Value vectorOperand, 636 Value accumulator, Value mask) { 637 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 638 llvmType, accumulator); 639 Value vectorLength = 640 createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 641 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 642 /*startValue=*/accumulator, 643 vectorOperand, mask, vectorLength); 644 } 645 646 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 647 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 648 Location loc, Type llvmType, 649 Value vectorOperand, 650 Value accumulator, Value mask, 651 bool reassociateFPReds) { 652 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 653 llvmType, accumulator); 654 Value vectorLength = 655 
createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 656 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 657 /*startValue=*/accumulator, 658 vectorOperand, mask, vectorLength, 659 reassociateFPReds); 660 } 661 662 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral, 663 class LLVMFPVPRedIntrinOp, class FPReductionNeutral> 664 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 665 Location loc, Type llvmType, 666 Value vectorOperand, 667 Value accumulator, Value mask) { 668 if (llvmType.isIntOrIndex()) 669 return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp, 670 IntReductionNeutral>( 671 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 672 673 // FP dispatch. 674 return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>( 675 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 676 } 677 678 /// Conversion pattern for all vector reductions. 679 class VectorReductionOpConversion 680 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 681 public: 682 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 683 bool reassociateFPRed) 684 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 685 reassociateFPReductions(reassociateFPRed) {} 686 687 LogicalResult 688 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 689 ConversionPatternRewriter &rewriter) const override { 690 auto kind = reductionOp.getKind(); 691 Type eltType = reductionOp.getDest().getType(); 692 Type llvmType = typeConverter->convertType(eltType); 693 Value operand = adaptor.getVector(); 694 Value acc = adaptor.getAcc(); 695 Location loc = reductionOp.getLoc(); 696 697 if (eltType.isIntOrIndex()) { 698 // Integer reductions: add/mul/min/max/and/or/xor. 
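      // For example, an `add` reduction with an accumulator is expected to
      // expand, roughly, into (names are illustrative):
      //   %r0 = "llvm.intr.vector.reduce.add"(%vec) : (vector<4xi32>) -> i32
      //   %r  = llvm.add %acc, %r0 : i32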
699 Value result; 700 switch (kind) { 701 case vector::CombiningKind::ADD: 702 result = 703 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 704 LLVM::AddOp>( 705 rewriter, loc, llvmType, operand, acc); 706 break; 707 case vector::CombiningKind::MUL: 708 result = 709 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 710 LLVM::MulOp>( 711 rewriter, loc, llvmType, operand, acc); 712 break; 713 case vector::CombiningKind::MINUI: 714 result = createIntegerReductionComparisonOpLowering< 715 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 716 LLVM::ICmpPredicate::ule); 717 break; 718 case vector::CombiningKind::MINSI: 719 result = createIntegerReductionComparisonOpLowering< 720 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 721 LLVM::ICmpPredicate::sle); 722 break; 723 case vector::CombiningKind::MAXUI: 724 result = createIntegerReductionComparisonOpLowering< 725 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 726 LLVM::ICmpPredicate::uge); 727 break; 728 case vector::CombiningKind::MAXSI: 729 result = createIntegerReductionComparisonOpLowering< 730 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 731 LLVM::ICmpPredicate::sge); 732 break; 733 case vector::CombiningKind::AND: 734 result = 735 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 736 LLVM::AndOp>( 737 rewriter, loc, llvmType, operand, acc); 738 break; 739 case vector::CombiningKind::OR: 740 result = 741 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 742 LLVM::OrOp>( 743 rewriter, loc, llvmType, operand, acc); 744 break; 745 case vector::CombiningKind::XOR: 746 result = 747 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 748 LLVM::XOrOp>( 749 rewriter, loc, llvmType, operand, acc); 750 break; 751 default: 752 return failure(); 753 } 754 rewriter.replaceOp(reductionOp, result); 755 756 return success(); 757 } 758 759 if (!isa<FloatType>(eltType)) 760 return failure(); 761 762 // Floating-point reductions: add/mul/min/max 763 Value result; 764 if (kind == vector::CombiningKind::ADD) { 765 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd, 766 ReductionNeutralZero>( 767 rewriter, loc, llvmType, operand, acc, reassociateFPReductions); 768 } else if (kind == vector::CombiningKind::MUL) { 769 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul, 770 ReductionNeutralFPOne>( 771 rewriter, loc, llvmType, operand, acc, reassociateFPReductions); 772 } else if (kind == vector::CombiningKind::MINF) { 773 result = 774 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>( 775 rewriter, loc, llvmType, operand, acc); 776 } else if (kind == vector::CombiningKind::MAXF) { 777 result = 778 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>( 779 rewriter, loc, llvmType, operand, acc); 780 } else 781 return failure(); 782 783 rewriter.replaceOp(reductionOp, result); 784 return success(); 785 } 786 787 private: 788 const bool reassociateFPReductions; 789 }; 790 791 /// Base class to convert a `vector.mask` operation while matching traits 792 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase` 793 /// instance matches against a `vector.mask` operation. The `matchAndRewrite` 794 /// method performs a second match against the maskable operation `MaskedOp`. 795 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be 796 /// implemented by the concrete conversion classes. 
This method can match 797 /// against specific traits of the `vector.mask` and the maskable operation. It 798 /// must replace the `vector.mask` operation. 799 template <class MaskedOp> 800 class VectorMaskOpConversionBase 801 : public ConvertOpToLLVMPattern<vector::MaskOp> { 802 public: 803 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern; 804 805 LogicalResult 806 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, 807 ConversionPatternRewriter &rewriter) const final { 808 // Match against the maskable operation kind. 809 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp()); 810 if (!maskedOp) 811 return failure(); 812 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter); 813 } 814 815 protected: 816 virtual LogicalResult 817 matchAndRewriteMaskableOp(vector::MaskOp maskOp, 818 vector::MaskableOpInterface maskableOp, 819 ConversionPatternRewriter &rewriter) const = 0; 820 }; 821 822 class MaskedReductionOpConversion 823 : public VectorMaskOpConversionBase<vector::ReductionOp> { 824 825 public: 826 using VectorMaskOpConversionBase< 827 vector::ReductionOp>::VectorMaskOpConversionBase; 828 829 LogicalResult matchAndRewriteMaskableOp( 830 vector::MaskOp maskOp, MaskableOpInterface maskableOp, 831 ConversionPatternRewriter &rewriter) const override { 832 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation()); 833 auto kind = reductionOp.getKind(); 834 Type eltType = reductionOp.getDest().getType(); 835 Type llvmType = typeConverter->convertType(eltType); 836 Value operand = reductionOp.getVector(); 837 Value acc = reductionOp.getAcc(); 838 Location loc = reductionOp.getLoc(); 839 840 Value result; 841 switch (kind) { 842 case vector::CombiningKind::ADD: 843 result = lowerReductionWithStartValue< 844 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, 845 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, 846 maskOp.getMask()); 847 break; 848 case vector::CombiningKind::MUL: 849 result = lowerReductionWithStartValue< 850 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, 851 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, 852 maskOp.getMask()); 853 break; 854 case vector::CombiningKind::MINUI: 855 result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp, 856 ReductionNeutralUIntMax>( 857 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 858 break; 859 case vector::CombiningKind::MINSI: 860 result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp, 861 ReductionNeutralSIntMax>( 862 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 863 break; 864 case vector::CombiningKind::MAXUI: 865 result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp, 866 ReductionNeutralUIntMin>( 867 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 868 break; 869 case vector::CombiningKind::MAXSI: 870 result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp, 871 ReductionNeutralSIntMin>( 872 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 873 break; 874 case vector::CombiningKind::AND: 875 result = lowerReductionWithStartValue<LLVM::VPReduceAndOp, 876 ReductionNeutralAllOnes>( 877 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 878 break; 879 case vector::CombiningKind::OR: 880 result = lowerReductionWithStartValue<LLVM::VPReduceOrOp, 881 ReductionNeutralZero>( 882 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 883 break; 884 case vector::CombiningKind::XOR: 885 result = lowerReductionWithStartValue<LLVM::VPReduceXorOp, 886 
                                           ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace the `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract and insert the values one at a time.
951 int64_t v1Dim = v1Type.getDimSize(0); 952 Type eltType; 953 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType)) 954 eltType = arrayType.getElementType(); 955 else 956 eltType = cast<VectorType>(llvmType).getElementType(); 957 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 958 int64_t insPos = 0; 959 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 960 int64_t extPos = cast<IntegerAttr>(en.value()).getInt(); 961 Value value = adaptor.getV1(); 962 if (extPos >= v1Dim) { 963 extPos -= v1Dim; 964 value = adaptor.getV2(); 965 } 966 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 967 eltType, rank, extPos); 968 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 969 llvmType, rank, insPos++); 970 } 971 rewriter.replaceOp(shuffleOp, insert); 972 return success(); 973 } 974 }; 975 976 class VectorExtractElementOpConversion 977 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 978 public: 979 using ConvertOpToLLVMPattern< 980 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 981 982 LogicalResult 983 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 984 ConversionPatternRewriter &rewriter) const override { 985 auto vectorType = extractEltOp.getSourceVectorType(); 986 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 987 988 // Bail if result type cannot be lowered. 989 if (!llvmType) 990 return failure(); 991 992 if (vectorType.getRank() == 0) { 993 Location loc = extractEltOp.getLoc(); 994 auto idxType = rewriter.getIndexType(); 995 auto zero = rewriter.create<LLVM::ConstantOp>( 996 loc, typeConverter->convertType(idxType), 997 rewriter.getIntegerAttr(idxType, 0)); 998 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 999 extractEltOp, llvmType, adaptor.getVector(), zero); 1000 return success(); 1001 } 1002 1003 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1004 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 1005 return success(); 1006 } 1007 }; 1008 1009 class VectorExtractOpConversion 1010 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 1011 public: 1012 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 1013 1014 LogicalResult 1015 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 1016 ConversionPatternRewriter &rewriter) const override { 1017 auto loc = extractOp->getLoc(); 1018 auto resultType = extractOp.getResult().getType(); 1019 auto llvmResultType = typeConverter->convertType(resultType); 1020 ArrayRef<int64_t> positionArray = extractOp.getPosition(); 1021 1022 // Bail if result type cannot be lowered. 1023 if (!llvmResultType) 1024 return failure(); 1025 1026 // Extract entire vector. Should be handled by folder, but just to be safe. 1027 if (positionArray.empty()) { 1028 rewriter.replaceOp(extractOp, adaptor.getVector()); 1029 return success(); 1030 } 1031 1032 // One-shot extraction of vector from array (only requires extractvalue). 1033 if (isa<VectorType>(resultType)) { 1034 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 1035 loc, adaptor.getVector(), positionArray); 1036 rewriter.replaceOp(extractOp, extracted); 1037 return success(); 1038 } 1039 1040 // Potential extraction of 1-D vector from array. 
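    // E.g. for a scalar extract at position [1, 2] from an array-of-vectors
    // representation, the code below first extracts the inner 1-D vector with
    // an extractvalue at [1] and then the element with an extractelement at
    // index 2.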
1041 Value extracted = adaptor.getVector(); 1042 if (positionArray.size() > 1) { 1043 extracted = rewriter.create<LLVM::ExtractValueOp>( 1044 loc, extracted, positionArray.drop_back()); 1045 } 1046 1047 // Remaining extraction of element from 1-D LLVM vector 1048 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 1049 auto constant = 1050 rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back()); 1051 extracted = 1052 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 1053 rewriter.replaceOp(extractOp, extracted); 1054 1055 return success(); 1056 } 1057 }; 1058 1059 /// Conversion pattern that turns a vector.fma on a 1-D vector 1060 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 1061 /// This does not match vectors of n >= 2 rank. 1062 /// 1063 /// Example: 1064 /// ``` 1065 /// vector.fma %a, %a, %a : vector<8xf32> 1066 /// ``` 1067 /// is converted to: 1068 /// ``` 1069 /// llvm.intr.fmuladd %va, %va, %va: 1070 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 1071 /// -> !llvm."<8 x f32>"> 1072 /// ``` 1073 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 1074 public: 1075 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 1076 1077 LogicalResult 1078 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 1079 ConversionPatternRewriter &rewriter) const override { 1080 VectorType vType = fmaOp.getVectorType(); 1081 if (vType.getRank() > 1) 1082 return failure(); 1083 1084 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( 1085 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); 1086 return success(); 1087 } 1088 }; 1089 1090 class VectorInsertElementOpConversion 1091 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 1092 public: 1093 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 1094 1095 LogicalResult 1096 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 1097 ConversionPatternRewriter &rewriter) const override { 1098 auto vectorType = insertEltOp.getDestVectorType(); 1099 auto llvmType = typeConverter->convertType(vectorType); 1100 1101 // Bail if result type cannot be lowered. 1102 if (!llvmType) 1103 return failure(); 1104 1105 if (vectorType.getRank() == 0) { 1106 Location loc = insertEltOp.getLoc(); 1107 auto idxType = rewriter.getIndexType(); 1108 auto zero = rewriter.create<LLVM::ConstantOp>( 1109 loc, typeConverter->convertType(idxType), 1110 rewriter.getIntegerAttr(idxType, 0)); 1111 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1112 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); 1113 return success(); 1114 } 1115 1116 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1117 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), 1118 adaptor.getPosition()); 1119 return success(); 1120 } 1121 }; 1122 1123 class VectorInsertOpConversion 1124 : public ConvertOpToLLVMPattern<vector::InsertOp> { 1125 public: 1126 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 1127 1128 LogicalResult 1129 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 1130 ConversionPatternRewriter &rewriter) const override { 1131 auto loc = insertOp->getLoc(); 1132 auto sourceType = insertOp.getSourceType(); 1133 auto destVectorType = insertOp.getDestVectorType(); 1134 auto llvmResultType = typeConverter->convertType(destVectorType); 1135 ArrayRef<int64_t> positionArray = insertOp.getPosition(); 1136 1137 // Bail if result type cannot be lowered. 
1138 if (!llvmResultType) 1139 return failure(); 1140 1141 // Overwrite entire vector with value. Should be handled by folder, but 1142 // just to be safe. 1143 if (positionArray.empty()) { 1144 rewriter.replaceOp(insertOp, adaptor.getSource()); 1145 return success(); 1146 } 1147 1148 // One-shot insertion of a vector into an array (only requires insertvalue). 1149 if (isa<VectorType>(sourceType)) { 1150 Value inserted = rewriter.create<LLVM::InsertValueOp>( 1151 loc, adaptor.getDest(), adaptor.getSource(), positionArray); 1152 rewriter.replaceOp(insertOp, inserted); 1153 return success(); 1154 } 1155 1156 // Potential extraction of 1-D vector from array. 1157 Value extracted = adaptor.getDest(); 1158 auto oneDVectorType = destVectorType; 1159 if (positionArray.size() > 1) { 1160 oneDVectorType = reducedVectorTypeBack(destVectorType); 1161 extracted = rewriter.create<LLVM::ExtractValueOp>( 1162 loc, extracted, positionArray.drop_back()); 1163 } 1164 1165 // Insertion of an element into a 1-D LLVM vector. 1166 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 1167 auto constant = 1168 rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back()); 1169 Value inserted = rewriter.create<LLVM::InsertElementOp>( 1170 loc, typeConverter->convertType(oneDVectorType), extracted, 1171 adaptor.getSource(), constant); 1172 1173 // Potential insertion of resulting 1-D vector into array. 1174 if (positionArray.size() > 1) { 1175 inserted = rewriter.create<LLVM::InsertValueOp>( 1176 loc, adaptor.getDest(), inserted, positionArray.drop_back()); 1177 } 1178 1179 rewriter.replaceOp(insertOp, inserted); 1180 return success(); 1181 } 1182 }; 1183 1184 /// Lower vector.scalable.insert ops to LLVM vector.insert 1185 struct VectorScalableInsertOpLowering 1186 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> { 1187 using ConvertOpToLLVMPattern< 1188 vector::ScalableInsertOp>::ConvertOpToLLVMPattern; 1189 1190 LogicalResult 1191 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor, 1192 ConversionPatternRewriter &rewriter) const override { 1193 rewriter.replaceOpWithNewOp<LLVM::vector_insert>( 1194 insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos()); 1195 return success(); 1196 } 1197 }; 1198 1199 /// Lower vector.scalable.extract ops to LLVM vector.extract 1200 struct VectorScalableExtractOpLowering 1201 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> { 1202 using ConvertOpToLLVMPattern< 1203 vector::ScalableExtractOp>::ConvertOpToLLVMPattern; 1204 1205 LogicalResult 1206 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor, 1207 ConversionPatternRewriter &rewriter) const override { 1208 rewriter.replaceOpWithNewOp<LLVM::vector_extract>( 1209 extOp, typeConverter->convertType(extOp.getResultVectorType()), 1210 adaptor.getSource(), adaptor.getPos()); 1211 return success(); 1212 } 1213 }; 1214 1215 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 
///
/// Example:
/// ```
///  %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shapes. This can only
  // ever work in static cases because MemRefType is underspecified when it
  // comes to representing contiguous dynamic shapes other than with an
  // empty/identity layout.
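  // For example, a memref<2x3x4xf32> with strides [12, 4, 1] passes the check
  // below since 12 == 4 * 3 and 4 == 1 * 4.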
1287 auto sizes = memRefType.getShape(); 1288 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 1289 if (ShapedType::isDynamic(sizes[index + 1]) || 1290 ShapedType::isDynamic(strides[index]) || 1291 ShapedType::isDynamic(strides[index + 1])) 1292 return std::nullopt; 1293 if (strides[index] != strides[index + 1] * sizes[index + 1]) 1294 return std::nullopt; 1295 } 1296 return strides; 1297 } 1298 1299 class VectorTypeCastOpConversion 1300 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 1301 public: 1302 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 1303 1304 LogicalResult 1305 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 1306 ConversionPatternRewriter &rewriter) const override { 1307 auto loc = castOp->getLoc(); 1308 MemRefType sourceMemRefType = 1309 cast<MemRefType>(castOp.getOperand().getType()); 1310 MemRefType targetMemRefType = castOp.getType(); 1311 1312 // Only static shape casts supported atm. 1313 if (!sourceMemRefType.hasStaticShape() || 1314 !targetMemRefType.hasStaticShape()) 1315 return failure(); 1316 1317 auto llvmSourceDescriptorTy = 1318 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType()); 1319 if (!llvmSourceDescriptorTy) 1320 return failure(); 1321 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 1322 1323 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( 1324 typeConverter->convertType(targetMemRefType)); 1325 if (!llvmTargetDescriptorTy) 1326 return failure(); 1327 1328 // Only contiguous source buffers supported atm. 1329 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1330 if (!sourceStrides) 1331 return failure(); 1332 auto targetStrides = computeContiguousStrides(targetMemRefType); 1333 if (!targetStrides) 1334 return failure(); 1335 // Only support static strides for now, regardless of contiguity. 1336 if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) 1337 return failure(); 1338 1339 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1340 1341 // Create descriptor. 1342 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1343 Type llvmTargetElementTy = desc.getElementPtrType(); 1344 // Set allocated ptr. 1345 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1346 if (!getTypeConverter()->useOpaquePointers()) 1347 allocated = 1348 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 1349 desc.setAllocatedPtr(rewriter, loc, allocated); 1350 1351 // Set aligned ptr. 1352 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1353 if (!getTypeConverter()->useOpaquePointers()) 1354 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 1355 1356 desc.setAlignedPtr(rewriter, loc, ptr); 1357 // Fill offset 0. 1358 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1359 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1360 desc.setOffset(rewriter, loc, zero); 1361 1362 // Fill size and stride descriptors in memref. 
1363 for (const auto &indexedSize : 1364 llvm::enumerate(targetMemRefType.getShape())) { 1365 int64_t index = indexedSize.index(); 1366 auto sizeAttr = 1367 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1368 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1369 desc.setSize(rewriter, loc, index, size); 1370 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1371 (*targetStrides)[index]); 1372 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1373 desc.setStride(rewriter, loc, index, stride); 1374 } 1375 1376 rewriter.replaceOp(castOp, {desc}); 1377 return success(); 1378 } 1379 }; 1380 1381 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 1382 /// Non-scalable versions of this operation are handled in Vector Transforms. 1383 class VectorCreateMaskOpRewritePattern 1384 : public OpRewritePattern<vector::CreateMaskOp> { 1385 public: 1386 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1387 bool enableIndexOpt) 1388 : OpRewritePattern<vector::CreateMaskOp>(context), 1389 force32BitVectorIndices(enableIndexOpt) {} 1390 1391 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1392 PatternRewriter &rewriter) const override { 1393 auto dstType = op.getType(); 1394 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1395 return failure(); 1396 IntegerType idxType = 1397 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1398 auto loc = op->getLoc(); 1399 Value indices = rewriter.create<LLVM::StepVectorOp>( 1400 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1401 /*isScalable=*/true)); 1402 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1403 op.getOperand(0)); 1404 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1405 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1406 indices, bounds); 1407 rewriter.replaceOp(op, comp); 1408 return success(); 1409 } 1410 1411 private: 1412 const bool force32BitVectorIndices; 1413 }; 1414 1415 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1416 public: 1417 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1418 1419 // Proof-of-concept lowering implementation that relies on a small 1420 // runtime support library, which only needs to provide a few 1421 // printing methods (single value for all data types, opening/closing 1422 // bracket, comma, newline). The lowering fully unrolls a vector 1423 // in terms of these elementary printing operations. The advantage 1424 // of this approach is that the library can remain unaware of all 1425 // low-level implementation details of vectors while still supporting 1426 // output of any shaped and dimensioned vector. Due to full unrolling, 1427 // this approach is less suited for very large vectors though. 1428 // 1429 // TODO: rely solely on libc in future? something else? 1430 // 1431 LogicalResult 1432 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1433 ConversionPatternRewriter &rewriter) const override { 1434 Type printType = printOp.getPrintType(); 1435 1436 if (typeConverter->convertType(printType) == nullptr) 1437 return failure(); 1438 1439 // Make sure element type has runtime support. 1440 PrintConversion conversion = PrintConversion::None; 1441 VectorType vectorType = dyn_cast<VectorType>(printType); 1442 Type eltType = vectorType ? 
vectorType.getElementType() : printType; 1443 auto parent = printOp->getParentOfType<ModuleOp>(); 1444 Operation *printer; 1445 if (eltType.isF32()) { 1446 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1447 } else if (eltType.isF64()) { 1448 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1449 } else if (eltType.isF16()) { 1450 conversion = PrintConversion::Bitcast16; // bits! 1451 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1452 } else if (eltType.isBF16()) { 1453 conversion = PrintConversion::Bitcast16; // bits! 1454 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1455 } else if (eltType.isIndex()) { 1456 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1457 } else if (auto intTy = dyn_cast<IntegerType>(eltType)) { 1458 // Integers need a zero or sign extension on the operand 1459 // (depending on the source type) as well as a signed or 1460 // unsigned print method. Up to 64-bit is supported. 1461 unsigned width = intTy.getWidth(); 1462 if (intTy.isUnsigned()) { 1463 if (width <= 64) { 1464 if (width < 64) 1465 conversion = PrintConversion::ZeroExt64; 1466 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1467 } else { 1468 return failure(); 1469 } 1470 } else { 1471 assert(intTy.isSignless() || intTy.isSigned()); 1472 if (width <= 64) { 1473 // Note that we *always* zero extend booleans (1-bit integers), 1474 // so that true/false is printed as 1/0 rather than -1/0. 1475 if (width == 1) 1476 conversion = PrintConversion::ZeroExt64; 1477 else if (width < 64) 1478 conversion = PrintConversion::SignExt64; 1479 printer = LLVM::lookupOrCreatePrintI64Fn(parent); 1480 } else { 1481 return failure(); 1482 } 1483 } 1484 } else { 1485 return failure(); 1486 } 1487 1488 // Unroll vector into elementary print calls. 1489 int64_t rank = vectorType ? vectorType.getRank() : 0; 1490 Type type = vectorType ? 
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure the element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = dyn_cast<VectorType>(printType);
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    auto parent = printOp->getParentOfType<ModuleOp>();
    Operation *printer;
    if (eltType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (eltType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (eltType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (eltType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (eltType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(parent));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = dyn_cast<VectorType>(type);
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::Bitcast16:
        value = rewriter.create<LLVM::BitcastOp>(
            loc, IntegerType::get(rewriter.getContext(), 16), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    auto parent = op->getParentOfType<ModuleOp>();
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
    Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);

    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
      return;
    }

    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to 0-D and 1-D vector result types are lowered
/// by this pattern.
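///
/// As a sketch (1-D case, shown loosely in LLVM dialect syntax; the ops are
/// emitted on the converted vector type):
///
///   %0 = vector.splat %x : vector<4xf32>
///
/// becomes roughly:
///
///   %u   = llvm.mlir.undef : vector<4xf32>
///   %c0  = llvm.mlir.constant(0 : i32) : i32
///   %ins = llvm.insertelement %x, %u[%c0 : i32] : vector<4xf32>
///   %0   = llvm.shufflevector %ins, %u [0, 0, 0, 0] : vector<4xf32>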
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For a 0-D vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For a 1-D vector, we additionally do a `shufflevector`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to vector result types of rank 2 or higher are
/// lowered by this pattern; the 0-D and 1-D cases are handled by
/// VectorSplatOpLowering.
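///
/// As a sketch (assuming the N-D result type converts to an LLVM array of
/// 1-D vectors):
///
///   %0 = vector.splat %x : vector<2x4xf32>
///
/// becomes roughly:
///
///   %v  = <1-D splat of %x as vector<4xf32>, as in VectorSplatOpLowering>
///   %u  = llvm.mlir.undef : !llvm.array<2 x vector<4xf32>>
///   %a0 = llvm.insertvalue %v, %u[0] : !llvm.array<2 x vector<4xf32>>
///   %0  = llvm.insertvalue %v, %a0[1] : !llvm.array<2 x vector<4xf32>>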
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index, convert it to the coordinate space, and
    // insert the splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion, VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering, VectorSplatNdOpLowering,
           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
           MaskedReductionOpConversion>(converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}
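
// A minimal usage sketch (illustrative only; the pass class and option values
// below are assumptions, not part of this file): a conversion pass would
// typically combine the populate functions above with an LLVMConversionTarget.
//
//   void MyConvertVectorToLLVMPass::runOnOperation() {
//     LowerToLLVMOptions options(&getContext());
//     LLVMTypeConverter converter(&getContext(), options);
//     RewritePatternSet patterns(&getContext());
//     populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
//     populateVectorToLLVMConversionPatterns(
//         converter, patterns, /*reassociateFPReductions=*/false,
//         /*force32BitVectorIndices=*/false);
//     LLVMConversionTarget target(getContext());
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }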