//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by *all* but one rank at back.
// E.g. vector<2x3x4xf32> becomes vector<4xf32>; the trailing dimension's
// scalability flag is carried over. Asserts on rank <= 1, where there is
// nothing left to drop.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         tp.getScalableDims().take_back());
}

// Helper that picks the proper sequence for inserting.
// Helper that picks the proper sequence for inserting `val2` into `val1` at
// position `pos`: a rank-1 vector maps to an LLVM vector and uses
// llvm.insertelement (with an index-typed constant position operand), while
// higher-rank vectors map to nested LLVM arrays and use llvm.insertvalue
// (with a static position).
static Value insertOne(ConversionPatternRewriter &rewriter,
                       const LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
}

// Helper that picks the proper sequence for extracting an element: 0-D and
// 1-D vectors (rank <= 1) map to LLVM vectors and use llvm.extractelement;
// higher-rank vectors map to nested LLVM arrays and use llvm.extractvalue.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

// Helper that returns data layout alignment of a memref. On success, `align`
// holds the preferred alignment of the converted element type. Fails when the
// element type cannot be converted to an LLVM type.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Succeeds when the memref has unit stride in its innermost dimension and a
// memory space the type converter can map; fails otherwise.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           const LLVMTypeConverter &converter) {
  if (!isLastMemrefDimUnitStride(memRefType))
    return failure();
  if (failed(converter.getMemRefAddressSpace(memRefType)))
    return failure();
  return success();
}

// Add an index vector component to a base pointer: builds a vector of `vLen`
// pointers by GEP-ing `base` with the elements of the `index` vector.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            const LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, uint64_t vLen) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
  return rewriter.create<LLVM::GEPOp>(
      loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
      base, index);
}

/// Convert `foldResult` into a Value. Integer attribute is converted to
/// an LLVM constant op.
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
                            OpFoldResult foldResult) {
  if (auto attr = foldResult.dyn_cast<Attribute>()) {
    // The attribute case is expected to hold an IntegerAttr; materialize it.
    auto intAttr = cast<IntegerAttr>(attr);
    return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
  }

  return foldResult.get<Value>();
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
/// Replace a vector.load with an llvm.load, forwarding alignment and the
/// nontemporal hint.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
                                            /*volatile_=*/false,
                                            loadOp.getNontemporal());
}

/// Replace a vector.maskedload with the llvm.intr.masked.load intrinsic.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

/// Replace a vector.store with an llvm.store, forwarding alignment and the
/// nontemporal hint. `vectorTy` is unused; it is kept so all overloads share
/// one signature.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align, /*volatile_=*/false,
                                             storeOp.getNontemporal());
}

/// Replace a vector.maskedstore with the llvm.intr.masked.store intrinsic.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    // Dispatch to the matching overload for this op kind.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                         rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: unroll into 1-D gathers via handleMultidimensionalVectors.
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading. These empty tag types select the
/// matching `createReductionNeutralValue` overload at compile time.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value.
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
}

/// Create the reduction neutral fp one value.
432 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 433 ConversionPatternRewriter &rewriter, 434 Location loc, Type llvmType) { 435 return rewriter.create<LLVM::ConstantOp>( 436 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 437 } 438 439 /// Create the reduction neutral all-ones value. 440 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 441 ConversionPatternRewriter &rewriter, 442 Location loc, Type llvmType) { 443 return rewriter.create<LLVM::ConstantOp>( 444 loc, llvmType, 445 rewriter.getIntegerAttr( 446 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 447 } 448 449 /// Create the reduction neutral signed int minimum value. 450 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 451 ConversionPatternRewriter &rewriter, 452 Location loc, Type llvmType) { 453 return rewriter.create<LLVM::ConstantOp>( 454 loc, llvmType, 455 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 456 llvmType.getIntOrFloatBitWidth()))); 457 } 458 459 /// Create the reduction neutral unsigned int minimum value. 460 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 461 ConversionPatternRewriter &rewriter, 462 Location loc, Type llvmType) { 463 return rewriter.create<LLVM::ConstantOp>( 464 loc, llvmType, 465 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 466 llvmType.getIntOrFloatBitWidth()))); 467 } 468 469 /// Create the reduction neutral signed int maximum value. 470 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 471 ConversionPatternRewriter &rewriter, 472 Location loc, Type llvmType) { 473 return rewriter.create<LLVM::ConstantOp>( 474 loc, llvmType, 475 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 476 llvmType.getIntOrFloatBitWidth()))); 477 } 478 479 /// Create the reduction neutral unsigned int maximum value. 
480 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 481 ConversionPatternRewriter &rewriter, 482 Location loc, Type llvmType) { 483 return rewriter.create<LLVM::ConstantOp>( 484 loc, llvmType, 485 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 486 llvmType.getIntOrFloatBitWidth()))); 487 } 488 489 /// Create the reduction neutral fp minimum value. 490 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 491 ConversionPatternRewriter &rewriter, 492 Location loc, Type llvmType) { 493 auto floatType = cast<FloatType>(llvmType); 494 return rewriter.create<LLVM::ConstantOp>( 495 loc, llvmType, 496 rewriter.getFloatAttr( 497 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 498 /*Negative=*/false))); 499 } 500 501 /// Create the reduction neutral fp maximum value. 502 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 503 ConversionPatternRewriter &rewriter, 504 Location loc, Type llvmType) { 505 auto floatType = cast<FloatType>(llvmType); 506 return rewriter.create<LLVM::ConstantOp>( 507 loc, llvmType, 508 rewriter.getFloatAttr( 509 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 510 /*Negative=*/true))); 511 } 512 513 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 514 /// returns a new accumulator value using `ReductionNeutral`. 515 template <class ReductionNeutral> 516 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 517 Location loc, Type llvmType, 518 Value accumulator) { 519 if (accumulator) 520 return accumulator; 521 522 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 523 llvmType); 524 } 525 526 /// Creates a value with the 1-D vector shape provided in `llvmType`. 527 /// This is used as effective vector length by some intrinsics supporting 528 /// dynamic vector lengths at runtime. 
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  // Static part of the length, as an i32 constant.
  Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));

  if (!vType.getScalableDims()[0])
    return baseVecLength;

  // For a scalable vector type, create and return `vScale * baseVecLength`.
  Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
  vScale =
      rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
  Value scalableVecLength =
      rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
  return scalableVecLength;
}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector
/// intrinsic to use and `ScalarOp` is the scalar operation used to add the
/// accumulation value if non-null.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  // Fold the optional accumulator in with the scalar operation.
  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector intrinsic to use and `predicate` is the predicate to use to
/// compare+combine the accumulator value if non-null.
571 template <class LLVMRedIntrinOp> 572 static Value createIntegerReductionComparisonOpLowering( 573 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 574 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 575 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 576 if (accumulator) { 577 Value cmp = 578 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 579 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 580 } 581 return result; 582 } 583 584 namespace { 585 template <typename Source> 586 struct VectorToScalarMapper; 587 template <> 588 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { 589 using Type = LLVM::MaximumOp; 590 }; 591 template <> 592 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { 593 using Type = LLVM::MinimumOp; 594 }; 595 template <> 596 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> { 597 using Type = LLVM::MaxNumOp; 598 }; 599 template <> 600 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> { 601 using Type = LLVM::MinNumOp; 602 }; 603 } // namespace 604 605 template <class LLVMRedIntrinOp> 606 static Value createFPReductionComparisonOpLowering( 607 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 608 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { 609 Value result = 610 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf); 611 612 if (accumulator) { 613 result = 614 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( 615 loc, result, accumulator); 616 } 617 618 return result; 619 } 620 621 /// Reduction neutral classes for overloading 622 class MaskNeutralFMaximum {}; 623 class MaskNeutralFMinimum {}; 624 625 /// Get the mask neutral floating point maximum value 626 static llvm::APFloat 627 getMaskNeutralValue(MaskNeutralFMaximum, 628 const llvm::fltSemantics &floatSemantics) { 629 return llvm::APFloat::getSmallest(floatSemantics, 
/*Negative=*/true); 630 } 631 /// Get the mask neutral floating point minimum value 632 static llvm::APFloat 633 getMaskNeutralValue(MaskNeutralFMinimum, 634 const llvm::fltSemantics &floatSemantics) { 635 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false); 636 } 637 638 /// Create the mask neutral floating point MLIR vector constant 639 template <typename MaskNeutral> 640 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, 641 Location loc, Type llvmType, 642 Type vectorType) { 643 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); 644 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); 645 auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value); 646 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); 647 } 648 649 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked 650 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for 651 /// `fmaximum`/`fminimum`. 
652 /// More information: https://github.com/llvm/llvm-project/issues/64940 653 template <class LLVMRedIntrinOp, class MaskNeutral> 654 static Value 655 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, 656 Location loc, Type llvmType, 657 Value vectorOperand, Value accumulator, 658 Value mask, LLVM::FastmathFlagsAttr fmf) { 659 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>( 660 rewriter, loc, llvmType, vectorOperand.getType()); 661 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>( 662 loc, mask, vectorOperand, vectorMaskNeutral); 663 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>( 664 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); 665 } 666 667 template <class LLVMRedIntrinOp, class ReductionNeutral> 668 static Value 669 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, 670 Type llvmType, Value vectorOperand, 671 Value accumulator, LLVM::FastmathFlagsAttr fmf) { 672 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 673 llvmType, accumulator); 674 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType, 675 /*startValue=*/accumulator, 676 vectorOperand, fmf); 677 } 678 679 /// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic 680 /// that requires a start value. This start value format spans across fp 681 /// reductions without mask and all the masked reduction intrinsics. 
/// Unmasked fp variant: forwards the (possibly materialized) start value to
/// the intrinsic.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

/// Masked variant: passes the mask and an explicit vector-length operand to
/// the VP intrinsic.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

/// Dispatching variant: selects the integer or fp intrinsic/neutral pair based
/// on the element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
/// Conversion pattern that lowers `vector.reduction` to the corresponding
/// `llvm.intr.vector.reduce.*` intrinsic. Integer reductions go through the
/// plain intrinsics; FP reductions may additionally be tagged with the
/// `reassoc` fastmath flag when requested at pattern-construction time.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // FP-only combining kinds are rejected for integer element types.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Translate the op's arith fastmath flags to LLVM fastmath flags, then
    // fold in `reassoc` when FP reassociation was enabled on this pattern.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                                  : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  // When true, FP reductions are emitted with the `reassoc` fastmath flag,
  // which permits faster but not bit-exact tree reductions.
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes.
/// This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Match against the maskable operation kind.
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return failure();
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  }

protected:
  /// Hook for concrete conversions. Implementations must replace `maskOp`
  /// (not just the nested op) on success.
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

/// Lowers a `vector.reduction` nested inside a `vector.mask` to the LLVM
/// vector-predication (VP) reduction intrinsics; minimumf/maximumf have no VP
/// form and are instead lowered via a regular reduction over a
/// mask-neutralized input.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case CombiningKind::MAXIMUMF:
      // No VP intrinsic exists; neutralize masked-off lanes, then reduce.
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      // No VP intrinsic exists; neutralize masked-off lanes, then reduce.
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

/// Lowers `vector.shuffle`: either directly to `llvm.shufflevector` (rank <= 1
/// with identical operand types) or, for n-D cases, to an element-wise
/// extract/insert sequence.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    // The 0-D case produces a rank-1 result; otherwise all ranks must match.
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
      eltType = arrayType.getElementType();
    else
      eltType = cast<VectorType>(llvmType).getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
      // Mask positions >= v1's leading dim select from the second operand.
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Lowers `vector.extractelement` to `llvm.extractelement`; a 0-D source uses
/// a constant index of zero.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.extract` to `llvm.extractvalue` for the aggregate (array)
/// dimensions plus an `llvm.extractelement` for the innermost 1-D vector.
/// Dynamic positions are only supported for the innermost dimension.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // Extract entire vector. Should be handled by folder, but just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (isa<VectorType>(resultType)) {
      if (extractOp.hasDynamicPosition())
        return failure();

      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), getAsIntegers(position));
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    if (position.size() > 1) {
      if (extractOp.hasDynamicPosition())
        return failure();

      SmallVector<int64_t> nMinusOnePosition =
          getAsIntegers(position.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
    // Remaining extraction of element from 1-D LLVM vector.
    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
                                                        lastPosition);
    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///    vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///    llvm.intr.fmuladd %va, %va, %va:
///      (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///      -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // n-D FMAs are first rank-reduced by VectorFMAOpNDRewritePattern.
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Lowers `vector.insertelement` to `llvm.insertelement`; a 0-D destination
/// uses a constant index of zero.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.insert` to a combination of `llvm.insertvalue` (for vector
/// sources / outer array dimensions) and `llvm.insertelement` (for a scalar
/// inserted into the innermost 1-D vector). Dynamic positions are only
/// supported for the innermost dimension.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (isa<VectorType>(sourceType)) {
      if (insertOp.hasDynamicPosition())
        return failure();

      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto oneDVectorType = destVectorType;
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, getAsIntegers(position.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));

    // Potential insertion of resulting 1-D vector into array.
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          getAsIntegers(position.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0: vector<2x4xf32>
///   %va = vector.extractvalue %a[0] : vector<2x4xf32>
///   %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///   %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///   %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///   %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    // Peel off the leading dimension: one (n-1)-D FMA per slice, reassembled
    // into the result vector.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  // The innermost stride must be 1 for the buffer to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity form shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    // Every size/stride involved in the check must be static.
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Contiguity: each stride equals the product of all inner dimensions.
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Lowers `vector.type_cast` by rebuilding the memref descriptor for the
/// target type: it reuses the source's allocated/aligned pointers, sets a zero
/// offset, and fills in the target's static sizes and strides.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
1462 class VectorCreateMaskOpRewritePattern 1463 : public OpRewritePattern<vector::CreateMaskOp> { 1464 public: 1465 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1466 bool enableIndexOpt) 1467 : OpRewritePattern<vector::CreateMaskOp>(context), 1468 force32BitVectorIndices(enableIndexOpt) {} 1469 1470 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1471 PatternRewriter &rewriter) const override { 1472 auto dstType = op.getType(); 1473 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1474 return failure(); 1475 IntegerType idxType = 1476 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1477 auto loc = op->getLoc(); 1478 Value indices = rewriter.create<LLVM::StepVectorOp>( 1479 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1480 /*isScalable=*/true)); 1481 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1482 op.getOperand(0)); 1483 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1484 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1485 indices, bounds); 1486 rewriter.replaceOp(op, comp); 1487 return success(); 1488 } 1489 1490 private: 1491 const bool force32BitVectorIndices; 1492 }; 1493 1494 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1495 public: 1496 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1497 1498 // Lowering implementation that relies on a small runtime support library, 1499 // which only needs to provide a few printing methods (single value for all 1500 // data types, opening/closing bracket, comma, newline). The lowering splits 1501 // the vector into elementary printing operations. The advantage of this 1502 // approach is that the library can remain unaware of all low-level 1503 // implementation details of vectors while still supporting output of any 1504 // shaped and dimensioned vector. 
1505 // 1506 // Note: This lowering only handles scalars, n-D vectors are broken into 1507 // printing scalars in loops in VectorToSCF. 1508 // 1509 // TODO: rely solely on libc in future? something else? 1510 // 1511 LogicalResult 1512 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1513 ConversionPatternRewriter &rewriter) const override { 1514 auto parent = printOp->getParentOfType<ModuleOp>(); 1515 if (!parent) 1516 return failure(); 1517 1518 auto loc = printOp->getLoc(); 1519 1520 if (auto value = adaptor.getSource()) { 1521 Type printType = printOp.getPrintType(); 1522 if (isa<VectorType>(printType)) { 1523 // Vectors should be broken into elementary print ops in VectorToSCF. 1524 return failure(); 1525 } 1526 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value))) 1527 return failure(); 1528 } 1529 1530 auto punct = printOp.getPunctuation(); 1531 if (auto stringLiteral = printOp.getStringLiteral()) { 1532 LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", 1533 *stringLiteral, *getTypeConverter(), 1534 /*addNewline=*/false); 1535 } else if (punct != PrintPunctuation::NoPunctuation) { 1536 emitCall(rewriter, printOp->getLoc(), [&] { 1537 switch (punct) { 1538 case PrintPunctuation::Close: 1539 return LLVM::lookupOrCreatePrintCloseFn(parent); 1540 case PrintPunctuation::Open: 1541 return LLVM::lookupOrCreatePrintOpenFn(parent); 1542 case PrintPunctuation::Comma: 1543 return LLVM::lookupOrCreatePrintCommaFn(parent); 1544 case PrintPunctuation::NewLine: 1545 return LLVM::lookupOrCreatePrintNewlineFn(parent); 1546 default: 1547 llvm_unreachable("unexpected punctuation"); 1548 } 1549 }()); 1550 } 1551 1552 rewriter.eraseOp(printOp); 1553 return success(); 1554 } 1555 1556 private: 1557 enum class PrintConversion { 1558 // clang-format off 1559 None, 1560 ZeroExt64, 1561 SignExt64, 1562 Bitcast16 1563 // clang-format on 1564 }; 1565 1566 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter, 1567 ModuleOp 
parent, Location loc, Type printType, 1568 Value value) const { 1569 if (typeConverter->convertType(printType) == nullptr) 1570 return failure(); 1571 1572 // Make sure element type has runtime support. 1573 PrintConversion conversion = PrintConversion::None; 1574 Operation *printer; 1575 if (printType.isF32()) { 1576 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1577 } else if (printType.isF64()) { 1578 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1579 } else if (printType.isF16()) { 1580 conversion = PrintConversion::Bitcast16; // bits! 1581 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1582 } else if (printType.isBF16()) { 1583 conversion = PrintConversion::Bitcast16; // bits! 1584 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1585 } else if (printType.isIndex()) { 1586 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1587 } else if (auto intTy = dyn_cast<IntegerType>(printType)) { 1588 // Integers need a zero or sign extension on the operand 1589 // (depending on the source type) as well as a signed or 1590 // unsigned print method. Up to 64-bit is supported. 1591 unsigned width = intTy.getWidth(); 1592 if (intTy.isUnsigned()) { 1593 if (width <= 64) { 1594 if (width < 64) 1595 conversion = PrintConversion::ZeroExt64; 1596 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1597 } else { 1598 return failure(); 1599 } 1600 } else { 1601 assert(intTy.isSignless() || intTy.isSigned()); 1602 if (width <= 64) { 1603 // Note that we *always* zero extend booleans (1-bit integers), 1604 // so that true/false is printed as 1/0 rather than -1/0. 
1605 if (width == 1) 1606 conversion = PrintConversion::ZeroExt64; 1607 else if (width < 64) 1608 conversion = PrintConversion::SignExt64; 1609 printer = LLVM::lookupOrCreatePrintI64Fn(parent); 1610 } else { 1611 return failure(); 1612 } 1613 } 1614 } else { 1615 return failure(); 1616 } 1617 1618 switch (conversion) { 1619 case PrintConversion::ZeroExt64: 1620 value = rewriter.create<arith::ExtUIOp>( 1621 loc, IntegerType::get(rewriter.getContext(), 64), value); 1622 break; 1623 case PrintConversion::SignExt64: 1624 value = rewriter.create<arith::ExtSIOp>( 1625 loc, IntegerType::get(rewriter.getContext(), 64), value); 1626 break; 1627 case PrintConversion::Bitcast16: 1628 value = rewriter.create<LLVM::BitcastOp>( 1629 loc, IntegerType::get(rewriter.getContext(), 16), value); 1630 break; 1631 case PrintConversion::None: 1632 break; 1633 } 1634 emitCall(rewriter, loc, printer, value); 1635 return success(); 1636 } 1637 1638 // Helper to emit a call. 1639 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1640 Operation *ref, ValueRange params = ValueRange()) { 1641 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1642 params); 1643 } 1644 }; 1645 1646 /// The Splat operation is lowered to an insertelement + a shufflevector 1647 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 1648 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1649 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1650 1651 LogicalResult 1652 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1653 ConversionPatternRewriter &rewriter) const override { 1654 VectorType resultType = cast<VectorType>(splatOp.getType()); 1655 if (resultType.getRank() > 1) 1656 return failure(); 1657 1658 // First insert it into an undef vector so we can shuffle it. 
1659 auto vectorType = typeConverter->convertType(splatOp.getType()); 1660 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1661 auto zero = rewriter.create<LLVM::ConstantOp>( 1662 splatOp.getLoc(), 1663 typeConverter->convertType(rewriter.getIntegerType(32)), 1664 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1665 1666 // For 0-d vector, we simply do `insertelement`. 1667 if (resultType.getRank() == 0) { 1668 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1669 splatOp, vectorType, undef, adaptor.getInput(), zero); 1670 return success(); 1671 } 1672 1673 // For 1-d vector, we additionally do a `vectorshuffle`. 1674 auto v = rewriter.create<LLVM::InsertElementOp>( 1675 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1676 1677 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0); 1678 SmallVector<int32_t> zeroValues(width, 0); 1679 1680 // Shuffle the value across the desired number of elements. 1681 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1682 zeroValues); 1683 return success(); 1684 } 1685 }; 1686 1687 /// The Splat operation is lowered to an insertelement + a shufflevector 1688 /// operation. Splat to only 2+-d vector result types are lowered by the 1689 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 1690 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { 1691 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; 1692 1693 LogicalResult 1694 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, 1695 ConversionPatternRewriter &rewriter) const override { 1696 VectorType resultType = splatOp.getType(); 1697 if (resultType.getRank() <= 1) 1698 return failure(); 1699 1700 // First insert it into an undef vector so we can shuffle it. 
1701 auto loc = splatOp.getLoc(); 1702 auto vectorTypeInfo = 1703 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); 1704 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; 1705 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; 1706 if (!llvmNDVectorTy || !llvm1DVectorTy) 1707 return failure(); 1708 1709 // Construct returned value. 1710 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); 1711 1712 // Construct a 1-D vector with the splatted value that we insert in all the 1713 // places within the returned descriptor. 1714 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); 1715 auto zero = rewriter.create<LLVM::ConstantOp>( 1716 loc, typeConverter->convertType(rewriter.getIntegerType(32)), 1717 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1718 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, 1719 adaptor.getInput(), zero); 1720 1721 // Shuffle the value across the desired number of elements. 1722 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1723 SmallVector<int32_t> zeroValues(width, 0); 1724 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1725 1726 // Iterate of linear index, convert to coords space and insert splatted 1-D 1727 // vector in each position. 1728 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1729 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1730 }); 1731 rewriter.replaceOp(splatOp, desc); 1732 return success(); 1733 } 1734 }; 1735 1736 /// Conversion pattern for a `vector.interleave`. 1737 /// This supports fixed-sized vectors and scalable vectors. 
1738 struct VectorInterleaveOpLowering 1739 : public ConvertOpToLLVMPattern<vector::InterleaveOp> { 1740 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 1741 1742 LogicalResult 1743 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, 1744 ConversionPatternRewriter &rewriter) const override { 1745 VectorType resultType = interleaveOp.getResultVectorType(); 1746 // n-D interleaves should have been lowered already. 1747 if (resultType.getRank() != 1) 1748 return rewriter.notifyMatchFailure(interleaveOp, 1749 "InterleaveOp not rank 1"); 1750 // If the result is rank 1, then this directly maps to LLVM. 1751 if (resultType.isScalable()) { 1752 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>( 1753 interleaveOp, typeConverter->convertType(resultType), 1754 adaptor.getLhs(), adaptor.getRhs()); 1755 return success(); 1756 } 1757 // Lower fixed-size interleaves to a shufflevector. While the 1758 // vector.interleave2 intrinsic supports fixed and scalable vectors, the 1759 // langref still recommends fixed-vectors use shufflevector, see: 1760 // https://llvm.org/docs/LangRef.html#id876. 1761 int64_t resultVectorSize = resultType.getNumElements(); 1762 SmallVector<int32_t> interleaveShuffleMask; 1763 interleaveShuffleMask.reserve(resultVectorSize); 1764 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) { 1765 interleaveShuffleMask.push_back(i); 1766 interleaveShuffleMask.push_back((resultVectorSize / 2) + i); 1767 } 1768 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>( 1769 interleaveOp, adaptor.getLhs(), adaptor.getRhs(), 1770 interleaveShuffleMask); 1771 return success(); 1772 } 1773 }; 1774 1775 /// Conversion pattern for a `vector.deinterleave`. 1776 /// This supports fixed-sized vectors and scalable vectors. 
1777 struct VectorDeinterleaveOpLowering 1778 : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> { 1779 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 1780 1781 LogicalResult 1782 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor, 1783 ConversionPatternRewriter &rewriter) const override { 1784 VectorType resultType = deinterleaveOp.getResultVectorType(); 1785 VectorType sourceType = deinterleaveOp.getSourceVectorType(); 1786 auto loc = deinterleaveOp.getLoc(); 1787 1788 // Note: n-D deinterleave operations should be lowered to the 1-D before 1789 // converting to LLVM. 1790 if (resultType.getRank() != 1) 1791 return rewriter.notifyMatchFailure(deinterleaveOp, 1792 "DeinterleaveOp not rank 1"); 1793 1794 if (resultType.isScalable()) { 1795 auto llvmTypeConverter = this->getTypeConverter(); 1796 auto deinterleaveResults = deinterleaveOp.getResultTypes(); 1797 auto packedOpResults = 1798 llvmTypeConverter->packOperationResults(deinterleaveResults); 1799 auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>( 1800 loc, packedOpResults, adaptor.getSource()); 1801 1802 auto evenResult = rewriter.create<LLVM::ExtractValueOp>( 1803 loc, intrinsic->getResult(0), 0); 1804 auto oddResult = rewriter.create<LLVM::ExtractValueOp>( 1805 loc, intrinsic->getResult(0), 1); 1806 1807 rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult}); 1808 return success(); 1809 } 1810 // Lower fixed-size deinterleave to two shufflevectors. While the 1811 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the 1812 // langref still recommends fixed-vectors use shufflevector, see: 1813 // https://llvm.org/docs/LangRef.html#id889. 
1814 int64_t resultVectorSize = resultType.getNumElements(); 1815 SmallVector<int32_t> evenShuffleMask; 1816 SmallVector<int32_t> oddShuffleMask; 1817 1818 evenShuffleMask.reserve(resultVectorSize); 1819 oddShuffleMask.reserve(resultVectorSize); 1820 1821 for (int i = 0; i < sourceType.getNumElements(); ++i) { 1822 if (i % 2 == 0) 1823 evenShuffleMask.push_back(i); 1824 else 1825 oddShuffleMask.push_back(i); 1826 } 1827 1828 auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType); 1829 auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>( 1830 loc, adaptor.getSource(), poison, evenShuffleMask); 1831 auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>( 1832 loc, adaptor.getSource(), poison, oddShuffleMask); 1833 1834 rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle}); 1835 return success(); 1836 } 1837 }; 1838 1839 } // namespace 1840 1841 /// Populate the given list with patterns that convert from Vector to LLVM. 1842 void mlir::populateVectorToLLVMConversionPatterns( 1843 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1844 bool reassociateFPReductions, bool force32BitVectorIndices) { 1845 MLIRContext *ctx = converter.getDialect()->getContext(); 1846 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1847 populateVectorInsertExtractStridedSliceTransforms(patterns); 1848 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1849 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1850 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1851 VectorExtractElementOpConversion, VectorExtractOpConversion, 1852 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1853 VectorInsertOpConversion, VectorPrintOpConversion, 1854 VectorTypeCastOpConversion, VectorScaleOpConversion, 1855 VectorLoadStoreConversion<vector::LoadOp>, 1856 VectorLoadStoreConversion<vector::MaskedLoadOp>, 1857 VectorLoadStoreConversion<vector::StoreOp>, 1858 
VectorLoadStoreConversion<vector::MaskedStoreOp>, 1859 VectorGatherOpConversion, VectorScatterOpConversion, 1860 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1861 VectorSplatOpLowering, VectorSplatNdOpLowering, 1862 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1863 MaskedReductionOpConversion, VectorInterleaveOpLowering, 1864 VectorDeinterleaveOpLowering>(converter); 1865 // Transfer ops with rank > 1 are handled by VectorToSCF. 1866 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1867 } 1868 1869 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1870 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1871 patterns.add<VectorMatmulOpConversion>(converter); 1872 patterns.add<VectorFlatTransposeOpConversion>(converter); 1873 } 1874