//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by *all* but one rank at back.
// E.g. a rank-3 vector type keeps only its innermost (trailing) dimension;
// the scalability flag of that trailing dimension is preserved as well.
// Asserts on rank <= 1, which callers must have handled already.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         tp.getScalableDims().take_back());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       const LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    // Rank-1 vectors lower to LLVM vectors: insert scalar `val2` into `val1`
    // at `pos` via llvm.insertelement (position must be an SSA value).
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  // Higher ranks lower to LLVM arrays: use llvm.insertvalue (static position).
  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    // Rank-0/1 vectors lower to LLVM vectors: extract the element at `pos`
    // via llvm.extractelement (position must be an SSA value).
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  // Higher ranks lower to LLVM arrays: use llvm.extractvalue.
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

// Helper that returns data layout alignment of a memref.
// On success, `align` holds the preferred alignment of the converted element
// type; fails only when the element type cannot be converted.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Check that the memref has a unit stride in its trailing dimension and an
// address space the type converter can handle.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 93 const LLVMTypeConverter &converter) { 94 if (!memRefType.isLastDimUnitStride()) 95 return failure(); 96 if (failed(converter.getMemRefAddressSpace(memRefType))) 97 return failure(); 98 return success(); 99 } 100 101 // Add an index vector component to a base pointer. 102 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 103 const LLVMTypeConverter &typeConverter, 104 MemRefType memRefType, Value llvmMemref, Value base, 105 Value index, VectorType vectorType) { 106 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 107 "unsupported memref type"); 108 assert(vectorType.getRank() == 1 && "expected a 1-d vector type"); 109 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 110 auto ptrsType = 111 LLVM::getVectorType(pType, vectorType.getDimSize(0), 112 /*isScalable=*/vectorType.getScalableDims()[0]); 113 return rewriter.create<LLVM::GEPOp>( 114 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 115 base, index); 116 } 117 118 /// Convert `foldResult` into a Value. Integer attribute is converted to 119 /// an LLVM constant op. 120 static Value getAsLLVMValue(OpBuilder &builder, Location loc, 121 OpFoldResult foldResult) { 122 if (auto attr = foldResult.dyn_cast<Attribute>()) { 123 auto intAttr = cast<IntegerAttr>(attr); 124 return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult(); 125 } 126 127 return cast<Value>(foldResult); 128 } 129 130 namespace { 131 132 /// Trivial Vector to LLVM conversions 133 using VectorScaleOpConversion = 134 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 135 136 /// Conversion pattern for a vector.bitcast. 
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    // A vector bitcast maps 1:1 onto llvm.bitcast of the converted type.
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The matrix dimensions (lhsRows, lhsColumns, rhsColumns) are forwarded
    // unchanged as attributes of the LLVM matrix intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Rows/columns are forwarded as attributes of the LLVM matrix intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Plain load: non-volatile, nontemporal flag carried over from the op.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
                                            /*volatile_=*/false,
                                            loadOp.getNontemporal());
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Masked load: mask and pass-through vector come from the adaptor.
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Plain store: non-volatile, nontemporal flag carried over from the op.
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align, /*volatile_=*/false,
                                             storeOp.getNontemporal());
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    // Dispatch to the overload matching the concrete op kind.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                         rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Base pointer at the (scalar) indices of the gather op.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address: one pointer per lane of the index vector.
      Value ptrs =
          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                         base, ptr, adaptor.getIndexVec(), vType);
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // N-D case: unroll into 1-D gathers via the multidimensional helper,
    // which invokes `callback` once per innermost 1-D slice.
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: base pointer, then one pointer per lane of the
    // index vector.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address, then map directly onto llvm.intr.masked.expandload.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address, then map directly onto llvm.intr.masked.compressstore.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading.
406 class ReductionNeutralZero {}; 407 class ReductionNeutralIntOne {}; 408 class ReductionNeutralFPOne {}; 409 class ReductionNeutralAllOnes {}; 410 class ReductionNeutralSIntMin {}; 411 class ReductionNeutralUIntMin {}; 412 class ReductionNeutralSIntMax {}; 413 class ReductionNeutralUIntMax {}; 414 class ReductionNeutralFPMin {}; 415 class ReductionNeutralFPMax {}; 416 417 /// Create the reduction neutral zero value. 418 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 419 ConversionPatternRewriter &rewriter, 420 Location loc, Type llvmType) { 421 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 422 rewriter.getZeroAttr(llvmType)); 423 } 424 425 /// Create the reduction neutral integer one value. 426 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 427 ConversionPatternRewriter &rewriter, 428 Location loc, Type llvmType) { 429 return rewriter.create<LLVM::ConstantOp>( 430 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 431 } 432 433 /// Create the reduction neutral fp one value. 434 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 435 ConversionPatternRewriter &rewriter, 436 Location loc, Type llvmType) { 437 return rewriter.create<LLVM::ConstantOp>( 438 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 439 } 440 441 /// Create the reduction neutral all-ones value. 442 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 443 ConversionPatternRewriter &rewriter, 444 Location loc, Type llvmType) { 445 return rewriter.create<LLVM::ConstantOp>( 446 loc, llvmType, 447 rewriter.getIntegerAttr( 448 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 449 } 450 451 /// Create the reduction neutral signed int minimum value. 
452 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 453 ConversionPatternRewriter &rewriter, 454 Location loc, Type llvmType) { 455 return rewriter.create<LLVM::ConstantOp>( 456 loc, llvmType, 457 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 458 llvmType.getIntOrFloatBitWidth()))); 459 } 460 461 /// Create the reduction neutral unsigned int minimum value. 462 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 463 ConversionPatternRewriter &rewriter, 464 Location loc, Type llvmType) { 465 return rewriter.create<LLVM::ConstantOp>( 466 loc, llvmType, 467 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 468 llvmType.getIntOrFloatBitWidth()))); 469 } 470 471 /// Create the reduction neutral signed int maximum value. 472 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 473 ConversionPatternRewriter &rewriter, 474 Location loc, Type llvmType) { 475 return rewriter.create<LLVM::ConstantOp>( 476 loc, llvmType, 477 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 478 llvmType.getIntOrFloatBitWidth()))); 479 } 480 481 /// Create the reduction neutral unsigned int maximum value. 482 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 483 ConversionPatternRewriter &rewriter, 484 Location loc, Type llvmType) { 485 return rewriter.create<LLVM::ConstantOp>( 486 loc, llvmType, 487 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 488 llvmType.getIntOrFloatBitWidth()))); 489 } 490 491 /// Create the reduction neutral fp minimum value. 
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  // A positive quiet NaN is used as the neutral start value here.
  auto floatType = cast<FloatType>(llvmType);
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/false)));
}

/// Create the reduction neutral fp maximum value.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  // A negative quiet NaN is used as the neutral start value here.
  auto floatType = cast<FloatType>(llvmType);
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}

/// Returns `accumulator` if it has a valid value. Otherwise, creates and
/// returns a new accumulator value using `ReductionNeutral`.
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Value accumulator) {
  if (accumulator)
    return accumulator;

  // No accumulator provided: overload resolution on the tag type picks the
  // matching neutral-element builder.
  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
                                     llvmType);
}

/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  // Static part of the length as an i32 constant.
  Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));

  if (!vType.getScalableDims()[0])
    return baseVecLength;

  // For a scalable vector type, create and return `vScale * baseVecLength`.
  Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
  vScale =
      rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
  Value scalableVecLength =
      rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
  return scalableVecLength;
}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
/// and `ScalarOp` is the scalar operation used to add the accumulation value if
/// non-null.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  // Reduce the vector, then fold in the optional accumulator with the
  // matching scalar op.
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
/// intrinsic to use and `predicate` is the predicate to use to compare+combine
/// the accumulator value if non-null.
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    // Combine with the accumulator via compare+select using `predicate`.
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}

namespace {
// Maps an fp min/max vector-reduce intrinsic to the scalar LLVM op used to
// fold in the accumulator value.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace

template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  Value result =
      rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);

  if (accumulator) {
    // Combine with the accumulator using the scalar op matching the
    // reduction intrinsic (see VectorToScalarMapper above).
    result =
        rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
            loc, result, accumulator);
  }

  return result;
}

/// Mask neutral classes for overloading
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};

/// Get the mask neutral floating point maximum value
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMaximum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getSmallest(floatSemantics,
                                    /*Negative=*/true);
}
/// Get the mask neutral floating point minimum value
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
}

/// Create the mask neutral floating point MLIR vector constant
template <typename MaskNeutral>
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Type vectorType) {
  // Splat the scalar mask-neutral value across the whole vector shape.
  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
}

/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
/// `fmaximum`/`fminimum`.
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
                                Location loc, Type llvmType,
                                Value vectorOperand, Value accumulator,
                                Value mask, LLVM::FastmathFlagsAttr fmf) {
  // Replace masked-off lanes with the mask-neutral value, then run the
  // regular (unmasked) reduction lowering on the selected vector.
  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
      rewriter, loc, llvmType, vectorOperand.getType());
  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
      loc, mask, vectorOperand, vectorMaskNeutral);
  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
      rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}

/// Lowers a reduction to an LLVM intrinsic that takes a start value,
/// materializing the `ReductionNeutral` start value when no accumulator is
/// provided.
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
                                          /*startValue=*/accumulator,
                                          vectorOperand, fmf);
}

/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

// Masked overload: also passes the mask and the effective vector length to
// the VP (vector-predication) intrinsic.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

// Dispatching overload: selects the integer or fp VP intrinsic (and its
// neutral element) based on the element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  // `reassociateFPRed` allows the lowering to tag fp reductions with the
  // `reassoc` fastmath flag (see the fmf computation below).
  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds are fp-only; not valid on integer element types.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Translate the op's arith fastmath flags and, when requested, add the
    // `reassoc` flag to allow reassociating the fp reduction.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                                  : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes.
This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Match against the maskable operation kind.
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return failure();
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  }

protected:
  /// Hook for subclasses: rewrite the `vector.mask` op given the already
  /// kind-matched maskable operation nested inside it.
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

/// Lowers a `vector.mask` wrapping a `vector.reduction` to the LLVM vector
/// predication (VP) reduction intrinsics. MINIMUMF/MAXIMUMF have no VP
/// counterpart here and are instead lowered via a regular reduction over an
/// operand with masked-off lanes replaced by a neutral value.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    // Note: operands are taken from the nested reduction op itself, not from
    // an adaptor.
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    // All CombiningKind values are covered below, so `result` is always set.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case CombiningKind::MAXIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

/// Lowers `vector.shuffle`. Rank-0/1 shuffles where both operands share the
/// exact same vector type map directly to `llvm.shufflevector`; all other
/// cases are expanded into per-element extract/insert sequences.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    ArrayRef<int64_t> mask = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          llvm::to_vector_of<int32_t>(mask));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
      eltType = arrayType.getElementType();
    else
      eltType = cast<VectorType>(llvmType).getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (int64_t extPos : mask) {
      // Mask entries >= v1Dim select from the second operand.
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Lowers `vector.extractelement` to `llvm.extractelement`. 0-D vectors are
/// converted to 1-D vectors by the type converter, so they are indexed with a
/// constant zero.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.extract` to `llvm.extractvalue` (aggregate member) and/or
/// `llvm.extractelement` (scalar out of a 1-D vector).
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
    // 1-d vectors. This nesting is modeled using arrays. We do this conversion
    // from a N-d vector extract to a nested aggregate vector extract in two
    // steps:
    //  - Extract a member from the nested aggregate. The result can be
    //    a lower rank nested aggregate or a vector (1-D).
This is done using
    //    `llvm.extractvalue`.
    //  - Extract a scalar out of the vector if needed. This is done using
    //    `llvm.extractelement`.

    // Determine if we need to extract a member out of the aggregate. We
    // always need to extract a member if the input rank >= 2.
    bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
    // Determine if we need to extract a scalar as the result. We extract
    // a scalar if the extract is full rank, i.e., the number of indices is
    // equal to source vector rank.
    bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
                          extractOp.getSourceVectorType().getRank();

    // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
    // need to add a position for this change.
    if (extractOp.getSourceVectorType().getRank() == 0) {
      Type idxType = typeConverter->convertType(rewriter.getIndexType());
      positionVec.push_back(rewriter.getZeroAttr(idxType));
    }

    Value extracted = adaptor.getVector();
    if (extractsAggregate) {
      ArrayRef<OpFoldResult> position(positionVec);
      if (extractsScalar) {
        // If we are extracting a scalar from the extracted member, we drop
        // the last index, which will be used to extract the scalar out of the
        // vector.
        position = position.drop_back();
      }
      // llvm.extractvalue does not support dynamic dimensions.
      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
        return failure();
      }
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, getAsIntegers(position));
    }

    if (extractsScalar) {
      extracted = rewriter.create<LLVM::ExtractElementOp>(
          loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
    }

    rewriter.replaceOp(extractOp, extracted);
    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Lowers `vector.insertelement` to `llvm.insertelement`. 0-D vectors are
/// converted to 1-D vectors by the type converter, so they are indexed with a
/// constant zero.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.insert` to `llvm.insertvalue` (vector into aggregate)
/// and/or `llvm.insertelement` (scalar into a 1-D vector), possibly bracketed
/// by an `llvm.extractvalue`/`llvm.insertvalue` pair for deeper nesting.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (isa<VectorType>(sourceType)) {
      // llvm.insertvalue requires static positions.
      if (insertOp.hasDynamicPosition())
        return failure();

      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto oneDVectorType = destVectorType;
    if (position.size() > 1) {
      // llvm.extractvalue requires static positions.
      if (insertOp.hasDynamicPosition())
        return failure();

      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, getAsIntegers(position.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));

    // Potential insertion of resulting 1-D vector into array.
    if (position.size() > 1) {
      // llvm.insertvalue requires static positions.
      if (insertOp.hasDynamicPosition())
        return failure();

      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          getAsIntegers(position.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///  %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    // Seed the result with an all-zero splat, then overwrite each outer
    // dimension slice with the (n-1)-D FMA result.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(memRefType.getStridesAndOffset(strides, offset)))
    return std::nullopt;
  // The innermost stride must be unit for the buffer to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Contiguity: each stride equals the product of all inner sizes.
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Lowers `vector.type_cast` by building a new memref descriptor that reuses
/// the source's allocated/aligned pointers and fills in the target's static
/// sizes and strides. Only static shapes and contiguous buffers are supported.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
1478 class VectorCreateMaskOpConversion 1479 : public OpConversionPattern<vector::CreateMaskOp> { 1480 public: 1481 explicit VectorCreateMaskOpConversion(MLIRContext *context, 1482 bool enableIndexOpt) 1483 : OpConversionPattern<vector::CreateMaskOp>(context), 1484 force32BitVectorIndices(enableIndexOpt) {} 1485 1486 LogicalResult 1487 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor, 1488 ConversionPatternRewriter &rewriter) const override { 1489 auto dstType = op.getType(); 1490 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1491 return failure(); 1492 IntegerType idxType = 1493 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1494 auto loc = op->getLoc(); 1495 Value indices = rewriter.create<LLVM::StepVectorOp>( 1496 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1497 /*isScalable=*/true)); 1498 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1499 adaptor.getOperands()[0]); 1500 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1501 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1502 indices, bounds); 1503 rewriter.replaceOp(op, comp); 1504 return success(); 1505 } 1506 1507 private: 1508 const bool force32BitVectorIndices; 1509 }; 1510 1511 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1512 public: 1513 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1514 1515 // Lowering implementation that relies on a small runtime support library, 1516 // which only needs to provide a few printing methods (single value for all 1517 // data types, opening/closing bracket, comma, newline). The lowering splits 1518 // the vector into elementary printing operations. The advantage of this 1519 // approach is that the library can remain unaware of all low-level 1520 // implementation details of vectors while still supporting output of any 1521 // shaped and dimensioned vector. 
1522 // 1523 // Note: This lowering only handles scalars, n-D vectors are broken into 1524 // printing scalars in loops in VectorToSCF. 1525 // 1526 // TODO: rely solely on libc in future? something else? 1527 // 1528 LogicalResult 1529 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1530 ConversionPatternRewriter &rewriter) const override { 1531 auto parent = printOp->getParentOfType<ModuleOp>(); 1532 if (!parent) 1533 return failure(); 1534 1535 auto loc = printOp->getLoc(); 1536 1537 if (auto value = adaptor.getSource()) { 1538 Type printType = printOp.getPrintType(); 1539 if (isa<VectorType>(printType)) { 1540 // Vectors should be broken into elementary print ops in VectorToSCF. 1541 return failure(); 1542 } 1543 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value))) 1544 return failure(); 1545 } 1546 1547 auto punct = printOp.getPunctuation(); 1548 if (auto stringLiteral = printOp.getStringLiteral()) { 1549 auto createResult = 1550 LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", 1551 *stringLiteral, *getTypeConverter(), 1552 /*addNewline=*/false); 1553 if (createResult.failed()) 1554 return failure(); 1555 1556 } else if (punct != PrintPunctuation::NoPunctuation) { 1557 FailureOr<LLVM::LLVMFuncOp> op = [&]() { 1558 switch (punct) { 1559 case PrintPunctuation::Close: 1560 return LLVM::lookupOrCreatePrintCloseFn(parent); 1561 case PrintPunctuation::Open: 1562 return LLVM::lookupOrCreatePrintOpenFn(parent); 1563 case PrintPunctuation::Comma: 1564 return LLVM::lookupOrCreatePrintCommaFn(parent); 1565 case PrintPunctuation::NewLine: 1566 return LLVM::lookupOrCreatePrintNewlineFn(parent); 1567 default: 1568 llvm_unreachable("unexpected punctuation"); 1569 } 1570 }(); 1571 if (failed(op)) 1572 return failure(); 1573 emitCall(rewriter, printOp->getLoc(), op.value()); 1574 } 1575 1576 rewriter.eraseOp(printOp); 1577 return success(); 1578 } 1579 1580 private: 1581 enum class PrintConversion { 1582 // clang-format off 1583 
None, 1584 ZeroExt64, 1585 SignExt64, 1586 Bitcast16 1587 // clang-format on 1588 }; 1589 1590 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter, 1591 ModuleOp parent, Location loc, Type printType, 1592 Value value) const { 1593 if (typeConverter->convertType(printType) == nullptr) 1594 return failure(); 1595 1596 // Make sure element type has runtime support. 1597 PrintConversion conversion = PrintConversion::None; 1598 FailureOr<Operation *> printer; 1599 if (printType.isF32()) { 1600 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1601 } else if (printType.isF64()) { 1602 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1603 } else if (printType.isF16()) { 1604 conversion = PrintConversion::Bitcast16; // bits! 1605 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1606 } else if (printType.isBF16()) { 1607 conversion = PrintConversion::Bitcast16; // bits! 1608 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1609 } else if (printType.isIndex()) { 1610 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1611 } else if (auto intTy = dyn_cast<IntegerType>(printType)) { 1612 // Integers need a zero or sign extension on the operand 1613 // (depending on the source type) as well as a signed or 1614 // unsigned print method. Up to 64-bit is supported. 1615 unsigned width = intTy.getWidth(); 1616 if (intTy.isUnsigned()) { 1617 if (width <= 64) { 1618 if (width < 64) 1619 conversion = PrintConversion::ZeroExt64; 1620 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1621 } else { 1622 return failure(); 1623 } 1624 } else { 1625 assert(intTy.isSignless() || intTy.isSigned()); 1626 if (width <= 64) { 1627 // Note that we *always* zero extend booleans (1-bit integers), 1628 // so that true/false is printed as 1/0 rather than -1/0. 
1629 if (width == 1) 1630 conversion = PrintConversion::ZeroExt64; 1631 else if (width < 64) 1632 conversion = PrintConversion::SignExt64; 1633 printer = LLVM::lookupOrCreatePrintI64Fn(parent); 1634 } else { 1635 return failure(); 1636 } 1637 } 1638 } else { 1639 return failure(); 1640 } 1641 if (failed(printer)) 1642 return failure(); 1643 1644 switch (conversion) { 1645 case PrintConversion::ZeroExt64: 1646 value = rewriter.create<arith::ExtUIOp>( 1647 loc, IntegerType::get(rewriter.getContext(), 64), value); 1648 break; 1649 case PrintConversion::SignExt64: 1650 value = rewriter.create<arith::ExtSIOp>( 1651 loc, IntegerType::get(rewriter.getContext(), 64), value); 1652 break; 1653 case PrintConversion::Bitcast16: 1654 value = rewriter.create<LLVM::BitcastOp>( 1655 loc, IntegerType::get(rewriter.getContext(), 16), value); 1656 break; 1657 case PrintConversion::None: 1658 break; 1659 } 1660 emitCall(rewriter, loc, printer.value(), value); 1661 return success(); 1662 } 1663 1664 // Helper to emit a call. 1665 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1666 Operation *ref, ValueRange params = ValueRange()) { 1667 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1668 params); 1669 } 1670 }; 1671 1672 /// The Splat operation is lowered to an insertelement + a shufflevector 1673 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 1674 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1675 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1676 1677 LogicalResult 1678 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1679 ConversionPatternRewriter &rewriter) const override { 1680 VectorType resultType = cast<VectorType>(splatOp.getType()); 1681 if (resultType.getRank() > 1) 1682 return failure(); 1683 1684 // First insert it into an undef vector so we can shuffle it. 
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    // Element 0 is the insertion position for the splatted scalar.
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    // An all-zero mask broadcasts lane 0 of `v` to every result lane.
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // Rank 0/1 splats are handled by VectorSplatOpLowering above.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    // Give up if either the aggregate or the innermost 1-D type could not be
    // derived.
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    // All-zero mask broadcasts lane 0 of `v` across the innermost dimension.
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index space, convert each linear index to a
    // coordinate, and insert the splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

/// Conversion pattern for a `vector.interleave`.
/// This supports fixed-sized vectors and scalable vectors.
1764 struct VectorInterleaveOpLowering 1765 : public ConvertOpToLLVMPattern<vector::InterleaveOp> { 1766 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 1767 1768 LogicalResult 1769 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, 1770 ConversionPatternRewriter &rewriter) const override { 1771 VectorType resultType = interleaveOp.getResultVectorType(); 1772 // n-D interleaves should have been lowered already. 1773 if (resultType.getRank() != 1) 1774 return rewriter.notifyMatchFailure(interleaveOp, 1775 "InterleaveOp not rank 1"); 1776 // If the result is rank 1, then this directly maps to LLVM. 1777 if (resultType.isScalable()) { 1778 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>( 1779 interleaveOp, typeConverter->convertType(resultType), 1780 adaptor.getLhs(), adaptor.getRhs()); 1781 return success(); 1782 } 1783 // Lower fixed-size interleaves to a shufflevector. While the 1784 // vector.interleave2 intrinsic supports fixed and scalable vectors, the 1785 // langref still recommends fixed-vectors use shufflevector, see: 1786 // https://llvm.org/docs/LangRef.html#id876. 1787 int64_t resultVectorSize = resultType.getNumElements(); 1788 SmallVector<int32_t> interleaveShuffleMask; 1789 interleaveShuffleMask.reserve(resultVectorSize); 1790 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) { 1791 interleaveShuffleMask.push_back(i); 1792 interleaveShuffleMask.push_back((resultVectorSize / 2) + i); 1793 } 1794 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>( 1795 interleaveOp, adaptor.getLhs(), adaptor.getRhs(), 1796 interleaveShuffleMask); 1797 return success(); 1798 } 1799 }; 1800 1801 /// Conversion pattern for a `vector.deinterleave`. 1802 /// This supports fixed-sized vectors and scalable vectors. 
struct VectorDeinterleaveOpLowering
    : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = deinterleaveOp.getResultVectorType();
    VectorType sourceType = deinterleaveOp.getSourceVectorType();
    auto loc = deinterleaveOp.getLoc();

    // Note: n-D deinterleave operations should be lowered to the 1-D before
    // converting to LLVM.
    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(deinterleaveOp,
                                         "DeinterleaveOp not rank 1");

    if (resultType.isScalable()) {
      // Scalable case: use llvm.vector.deinterleave2, which returns both
      // results packed in a single struct-like value.
      auto llvmTypeConverter = this->getTypeConverter();
      auto deinterleaveResults = deinterleaveOp.getResultTypes();
      auto packedOpResults =
          llvmTypeConverter->packOperationResults(deinterleaveResults);
      auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
          loc, packedOpResults, adaptor.getSource());

      // Unpack the even (index 0) and odd (index 1) halves from the packed
      // intrinsic result.
      auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
          loc, intrinsic->getResult(0), 0);
      auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
          loc, intrinsic->getResult(0), 1);

      rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
      return success();
    }
    // Lower fixed-size deinterleave to two shufflevectors. While the
    // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
    // langref still recommends fixed-vectors use shufflevector, see:
    // https://llvm.org/docs/LangRef.html#id889.
    int64_t resultVectorSize = resultType.getNumElements();
    SmallVector<int32_t> evenShuffleMask;
    SmallVector<int32_t> oddShuffleMask;

    evenShuffleMask.reserve(resultVectorSize);
    oddShuffleMask.reserve(resultVectorSize);

    // Even source lanes feed the first result, odd source lanes the second.
    for (int i = 0; i < sourceType.getNumElements(); ++i) {
      if (i % 2 == 0)
        evenShuffleMask.push_back(i);
      else
        oddShuffleMask.push_back(i);
    }

    // Both masks only select lanes of the first shuffle operand, so poison is
    // sufficient for the second operand.
    auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
    auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
        loc, adaptor.getSource(), poison, evenShuffleMask);
    auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
        loc, adaptor.getSource(), poison, oddShuffleMask);

    rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
    return success();
  }
};

/// Conversion pattern for a `vector.from_elements`.
struct VectorFromElementsLowering
    : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = fromElementsOp.getLoc();
    VectorType vectorType = fromElementsOp.getType();
    // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
    // Such ops should be handled in the same way as vector.insert.
    if (vectorType.getRank() > 1)
      return rewriter.notifyMatchFailure(fromElementsOp,
                                         "rank > 1 vectors are not supported");
    Type llvmType = typeConverter->convertType(vectorType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    // Insert each scalar operand at its position; the generated vector.insert
    // ops are converted by the insert lowering patterns in turn.
    for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
      result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
    rewriter.replaceOp(fromElementsOp, result);
    return success();
  }
};

/// Conversion pattern for vector.step.
struct VectorScalableStepOpLowering
    : public ConvertOpToLLVMPattern<vector::StepOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = cast<VectorType>(stepOp.getType());
    // This pattern only maps scalable steps onto llvm.stepvector; fixed-size
    // steps are left to other patterns (presumably folded to constants —
    // TODO confirm against the registered lowerings).
    if (!resultType.isScalable()) {
      return failure();
    }
    Type llvmType = typeConverter->convertType(stepOp.getType());
    rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorRankReducingFMAPattern(
    RewritePatternSet &patterns) {
  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
}

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  // This function populates only ConversionPatterns, not RewritePatterns.
  MLIRContext *ctx = converter.getDialect()->getContext();
  // These two patterns take extra constructor arguments beyond the converter.
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
  // All remaining patterns are constructed from the type converter alone.
  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
               VectorExtractElementOpConversion, VectorExtractOpConversion,
               VectorFMAOp1DConversion, VectorInsertElementOpConversion,
               VectorInsertOpConversion, VectorPrintOpConversion,
               VectorTypeCastOpConversion, VectorScaleOpConversion,
               VectorLoadStoreConversion<vector::LoadOp>,
               VectorLoadStoreConversion<vector::MaskedLoadOp>,
               VectorLoadStoreConversion<vector::StoreOp>,
               VectorLoadStoreConversion<vector::MaskedStoreOp>,
               VectorGatherOpConversion, VectorScatterOpConversion,
               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
               VectorSplatOpLowering, VectorSplatNdOpLowering,
               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
               MaskedReductionOpConversion, VectorInterleaveOpLowering,
               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
               VectorScalableStepOpLowering>(converter);
}

/// Populate the given list with the patterns converting the vector matrix
/// operations (matmul and flat transpose) to LLVM.
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}