//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by *all* but one rank at back.
// E.g. a rank-3 vector type keeps only its trailing dimension (shape and
// scalable-dims flag) and its element type.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         tp.getScalableDims().take_back());
}

// Helper that picks the proper sequence for inserting.
44 static Value insertOne(ConversionPatternRewriter &rewriter, 45 const LLVMTypeConverter &typeConverter, Location loc, 46 Value val1, Value val2, Type llvmType, int64_t rank, 47 int64_t pos) { 48 assert(rank > 0 && "0-D vector corner case should have been handled already"); 49 if (rank == 1) { 50 auto idxType = rewriter.getIndexType(); 51 auto constant = rewriter.create<LLVM::ConstantOp>( 52 loc, typeConverter.convertType(idxType), 53 rewriter.getIntegerAttr(idxType, pos)); 54 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 55 constant); 56 } 57 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 58 } 59 60 // Helper that picks the proper sequence for extracting. 61 static Value extractOne(ConversionPatternRewriter &rewriter, 62 const LLVMTypeConverter &typeConverter, Location loc, 63 Value val, Type llvmType, int64_t rank, int64_t pos) { 64 if (rank <= 1) { 65 auto idxType = rewriter.getIndexType(); 66 auto constant = rewriter.create<LLVM::ConstantOp>( 67 loc, typeConverter.convertType(idxType), 68 rewriter.getIntegerAttr(idxType, pos)); 69 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 70 constant); 71 } 72 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 73 } 74 75 // Helper that returns data layout alignment of a memref. 76 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, 77 MemRefType memrefType, unsigned &align) { 78 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 79 if (!elementTy) 80 return failure(); 81 82 // TODO: this should use the MLIR data layout when it becomes available and 83 // stop depending on translation. 84 llvm::LLVMContext llvmContext; 85 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 86 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 87 return success(); 88 } 89 90 // Check if the last stride is non-unit or the memory space is not zero. 
91 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 92 const LLVMTypeConverter &converter) { 93 if (!isLastMemrefDimUnitStride(memRefType)) 94 return failure(); 95 FailureOr<unsigned> addressSpace = 96 converter.getMemRefAddressSpace(memRefType); 97 if (failed(addressSpace) || *addressSpace != 0) 98 return failure(); 99 return success(); 100 } 101 102 // Add an index vector component to a base pointer. 103 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 104 const LLVMTypeConverter &typeConverter, 105 MemRefType memRefType, Value llvmMemref, Value base, 106 Value index, uint64_t vLen) { 107 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 108 "unsupported memref type"); 109 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 110 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 111 return rewriter.create<LLVM::GEPOp>( 112 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 113 base, index); 114 } 115 116 // Casts a strided element pointer to a vector pointer. The vector pointer 117 // will be in the same address space as the incoming memref type. 118 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 119 Value ptr, MemRefType memRefType, Type vt, 120 const LLVMTypeConverter &converter) { 121 if (converter.useOpaquePointers()) 122 return ptr; 123 124 unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); 125 auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); 126 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 127 } 128 129 /// Convert `foldResult` into a Value. Integer attribute is converted to 130 /// an LLVM constant op. 
131 static Value getAsLLVMValue(OpBuilder &builder, Location loc, 132 OpFoldResult foldResult) { 133 if (auto attr = foldResult.dyn_cast<Attribute>()) { 134 auto intAttr = cast<IntegerAttr>(attr); 135 return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult(); 136 } 137 138 return foldResult.get<Value>(); 139 } 140 141 namespace { 142 143 /// Trivial Vector to LLVM conversions 144 using VectorScaleOpConversion = 145 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 146 147 /// Conversion pattern for a vector.bitcast. 148 class VectorBitCastOpConversion 149 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 150 public: 151 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 152 153 LogicalResult 154 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 155 ConversionPatternRewriter &rewriter) const override { 156 // Only 0-D and 1-D vectors can be lowered to LLVM. 157 VectorType resultTy = bitCastOp.getResultVectorType(); 158 if (resultTy.getRank() > 1) 159 return failure(); 160 Type newResultTy = typeConverter->convertType(resultTy); 161 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 162 adaptor.getOperands()[0]); 163 return success(); 164 } 165 }; 166 167 /// Conversion pattern for a vector.matrix_multiply. 168 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
169 class VectorMatmulOpConversion 170 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 171 public: 172 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 173 174 LogicalResult 175 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 176 ConversionPatternRewriter &rewriter) const override { 177 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 178 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 179 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 180 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 181 return success(); 182 } 183 }; 184 185 /// Conversion pattern for a vector.flat_transpose. 186 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 187 class VectorFlatTransposeOpConversion 188 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 189 public: 190 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 191 192 LogicalResult 193 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 194 ConversionPatternRewriter &rewriter) const override { 195 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 196 transOp, typeConverter->convertType(transOp.getRes().getType()), 197 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 198 return success(); 199 } 200 }; 201 202 /// Overloaded utility that replaces a vector.load, vector.store, 203 /// vector.maskedload and vector.maskedstore with their respective LLVM 204 /// couterparts. 
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    // Dispatch to the matching (masked) load/store overload above.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: delegate to handleMultidimensionalVectors, which invokes the
    // callback once per innermost 1-D vector slice of the operands.
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Bail out on non-unit innermost stride or non-zero memory space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading. Each empty tag type selects the
/// corresponding createReductionNeutralValue overload below.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value.
437 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 438 ConversionPatternRewriter &rewriter, 439 Location loc, Type llvmType) { 440 return rewriter.create<LLVM::ConstantOp>( 441 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 442 } 443 444 /// Create the reduction neutral fp one value. 445 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 446 ConversionPatternRewriter &rewriter, 447 Location loc, Type llvmType) { 448 return rewriter.create<LLVM::ConstantOp>( 449 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 450 } 451 452 /// Create the reduction neutral all-ones value. 453 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 454 ConversionPatternRewriter &rewriter, 455 Location loc, Type llvmType) { 456 return rewriter.create<LLVM::ConstantOp>( 457 loc, llvmType, 458 rewriter.getIntegerAttr( 459 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 460 } 461 462 /// Create the reduction neutral signed int minimum value. 463 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 464 ConversionPatternRewriter &rewriter, 465 Location loc, Type llvmType) { 466 return rewriter.create<LLVM::ConstantOp>( 467 loc, llvmType, 468 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 469 llvmType.getIntOrFloatBitWidth()))); 470 } 471 472 /// Create the reduction neutral unsigned int minimum value. 473 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 474 ConversionPatternRewriter &rewriter, 475 Location loc, Type llvmType) { 476 return rewriter.create<LLVM::ConstantOp>( 477 loc, llvmType, 478 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 479 llvmType.getIntOrFloatBitWidth()))); 480 } 481 482 /// Create the reduction neutral signed int maximum value. 
483 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 484 ConversionPatternRewriter &rewriter, 485 Location loc, Type llvmType) { 486 return rewriter.create<LLVM::ConstantOp>( 487 loc, llvmType, 488 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 489 llvmType.getIntOrFloatBitWidth()))); 490 } 491 492 /// Create the reduction neutral unsigned int maximum value. 493 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 494 ConversionPatternRewriter &rewriter, 495 Location loc, Type llvmType) { 496 return rewriter.create<LLVM::ConstantOp>( 497 loc, llvmType, 498 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 499 llvmType.getIntOrFloatBitWidth()))); 500 } 501 502 /// Create the reduction neutral fp minimum value. 503 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 504 ConversionPatternRewriter &rewriter, 505 Location loc, Type llvmType) { 506 auto floatType = cast<FloatType>(llvmType); 507 return rewriter.create<LLVM::ConstantOp>( 508 loc, llvmType, 509 rewriter.getFloatAttr( 510 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 511 /*Negative=*/false))); 512 } 513 514 /// Create the reduction neutral fp maximum value. 515 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 516 ConversionPatternRewriter &rewriter, 517 Location loc, Type llvmType) { 518 auto floatType = cast<FloatType>(llvmType); 519 return rewriter.create<LLVM::ConstantOp>( 520 loc, llvmType, 521 rewriter.getFloatAttr( 522 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 523 /*Negative=*/true))); 524 } 525 526 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 527 /// returns a new accumulator value using `ReductionNeutral`. 
528 template <class ReductionNeutral> 529 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 530 Location loc, Type llvmType, 531 Value accumulator) { 532 if (accumulator) 533 return accumulator; 534 535 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 536 llvmType); 537 } 538 539 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 540 /// This is used as effective vector length by some intrinsics supporting 541 /// dynamic vector lengths at runtime. 542 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 543 Location loc, Type llvmType) { 544 VectorType vType = cast<VectorType>(llvmType); 545 auto vShape = vType.getShape(); 546 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 547 548 return rewriter.create<LLVM::ConstantOp>( 549 loc, rewriter.getI32Type(), 550 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 551 } 552 553 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 554 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 555 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 556 /// non-null. 557 template <class LLVMRedIntrinOp, class ScalarOp> 558 static Value createIntegerReductionArithmeticOpLowering( 559 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 560 Value vectorOperand, Value accumulator) { 561 562 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 563 564 if (accumulator) 565 result = rewriter.create<ScalarOp>(loc, accumulator, result); 566 return result; 567 } 568 569 /// Helper method to lower a `vector.reduction` operation that performs 570 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 571 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 572 /// the accumulator value if non-null. 
573 template <class LLVMRedIntrinOp> 574 static Value createIntegerReductionComparisonOpLowering( 575 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 576 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 577 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 578 if (accumulator) { 579 Value cmp = 580 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 581 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 582 } 583 return result; 584 } 585 586 namespace { 587 template <typename Source> 588 struct VectorToScalarMapper; 589 template <> 590 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { 591 using Type = LLVM::MaximumOp; 592 }; 593 template <> 594 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { 595 using Type = LLVM::MinimumOp; 596 }; 597 template <> 598 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> { 599 using Type = LLVM::MaxNumOp; 600 }; 601 template <> 602 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> { 603 using Type = LLVM::MinNumOp; 604 }; 605 } // namespace 606 607 template <class LLVMRedIntrinOp> 608 static Value createFPReductionComparisonOpLowering( 609 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 610 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { 611 Value result = 612 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf); 613 614 if (accumulator) { 615 result = 616 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( 617 loc, result, accumulator); 618 } 619 620 return result; 621 } 622 623 /// Reduction neutral classes for overloading 624 class MaskNeutralFMaximum {}; 625 class MaskNeutralFMinimum {}; 626 627 /// Get the mask neutral floating point maximum value 628 static llvm::APFloat 629 getMaskNeutralValue(MaskNeutralFMaximum, 630 const llvm::fltSemantics &floatSemantics) { 631 return llvm::APFloat::getSmallest(floatSemantics, 
/*Negative=*/true); 632 } 633 /// Get the mask neutral floating point minimum value 634 static llvm::APFloat 635 getMaskNeutralValue(MaskNeutralFMinimum, 636 const llvm::fltSemantics &floatSemantics) { 637 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false); 638 } 639 640 /// Create the mask neutral floating point MLIR vector constant 641 template <typename MaskNeutral> 642 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, 643 Location loc, Type llvmType, 644 Type vectorType) { 645 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); 646 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); 647 auto denseValue = 648 DenseElementsAttr::get(vectorType.cast<ShapedType>(), value); 649 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); 650 } 651 652 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked 653 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for 654 /// `fmaximum`/`fminimum`. 
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
                                Location loc, Type llvmType,
                                Value vectorOperand, Value accumulator,
                                Value mask, LLVM::FastmathFlagsAttr fmf) {
  // Replace masked-off lanes with the mask-neutral value so that a regular
  // (non-masked) reduction intrinsic can be applied.
  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
      rewriter, loc, llvmType, vectorOperand.getType());
  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
      loc, mask, vectorOperand, vectorMaskNeutral);
  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
      rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}

template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  // Seed the reduction with the neutral element if no accumulator was given.
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
                                          /*startValue=*/accumulator,
                                          vectorOperand, fmf);
}

/// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  // VP intrinsics take an explicit vector length operand; use the vector's
  // static 1-D length.
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  // Integer dispatch.
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
/// Conversion pattern that lowers `vector.reduction` to the matching LLVM
/// dialect reduction intrinsic (`llvm.vector.reduce.*`), combining the
/// intrinsic's result with the optional accumulator operand.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  /// When `reassociateFPRed` is set, the `reassoc` fastmath flag is added to
  /// floating-point reductions, letting LLVM reorder the reduction tree.
  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    // The scalar result type selects between the integer and FP lowerings.
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        // Arithmetic kinds fold the accumulator with the matching binary op.
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        // Comparison kinds fold the accumulator via icmp+select with the
        // predicate given here.
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // FP-only combining kinds reach here for an integer element type.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Translate the op's arith fastmath flags to LLVM fastmath flags, then
    // optionally add `reassoc` as requested at pattern-construction time.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                                  : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      // MINIMUMF/MAXIMUMF propagate NaN via the fminimum/fmaximum intrinsics.
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes.
/// This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Match against the maskable operation kind.
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return failure();
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  }

protected:
  /// Hook for concrete conversions. Implementations must replace `maskOp`
  /// itself (not just the nested maskable op).
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

/// Lowers a `vector.reduction` nested inside `vector.mask` to an LLVM
/// vector-predicated (VP) reduction intrinsic where one exists; for
/// minimumf/maximumf it falls back to a regular reduction over a vector whose
/// masked-off lanes were replaced by a neutral value.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    // NOTE(review): operands are read from the original op rather than from an
    // adaptor — presumably because the reduction lives inside the
    // `vector.mask` region; confirm against the conversion driver.
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      // ADD/MUL pass both an integer and an FP VP intrinsic; the helper picks
      // the right one based on the element type.
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case CombiningKind::MAXIMUMF:
      // No VP intrinsic for fmaximum/fminimum: substitute a mask-neutral
      // value into inactive lanes, then do a regular reduction.
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

/// Lowers `vector.shuffle`: rank-0/1 operands of identical type map directly
/// to an LLVM `shufflevector`; all other cases are expanded into per-element
/// extract/insert pairs driven by the mask.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    // Debug-only well-formedness check: either both sources are 0-D (result
    // rank 1), or sources and result all share the same rank.
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
      eltType = arrayType.getElementType();
    else
      eltType = cast<VectorType>(llvmType).getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
      // Mask indices >= v1Dim select from the second source vector.
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Lowers `vector.extractelement` to LLVM `extractelement`; a 0-D source is
/// treated as a 1-element vector indexed by the constant 0.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D vectors have no position operand; extract lane 0 explicitly.
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.extract`. Vector results use `llvm.extractvalue` on the
/// n-D array representation (static positions only); scalar results end with
/// an `llvm.extractelement` from the innermost 1-D vector.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Gather the mixed static/dynamic position, substituting already-converted
    // values for the dynamic entries.
    SmallVector<OpFoldResult> positionVec;
    for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      if (pos.is<Value>())
        // Make sure we use the value that has been already converted to LLVM.
        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
      else
        positionVec.push_back(pos);
    }

    // Extract entire vector. Should be handled by folder, but just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (isa<VectorType>(resultType)) {
      // llvm.extractvalue requires compile-time-constant indices.
      if (extractOp.hasDynamicPosition())
        return failure();

      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), getAsIntegers(position));
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    if (position.size() > 1) {
      if (extractOp.hasDynamicPosition())
        return failure();

      SmallVector<int64_t> nMinusOnePosition =
          getAsIntegers(position.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
    // Remaining extraction of element from 1-D LLVM vector.
    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
                                                        lastPosition);
    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///    vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // n-D FMAs are first rank-reduced by VectorFMAOpNDRewritePattern.
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Lowers `vector.insertelement` to LLVM `insertelement`; a 0-D destination
/// is treated as a 1-element vector indexed by the constant 0.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D vectors have no position operand; insert at lane 0 explicitly.
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.insert`. Vector sources use `llvm.insertvalue` on the n-D
/// array representation (static positions only); scalar sources go through an
/// extractvalue / insertelement / insertvalue sequence on the innermost 1-D
/// vector.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Gather the mixed static/dynamic position, substituting already-converted
    // values for the dynamic entries.
    SmallVector<OpFoldResult> positionVec;
    for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
      if (pos.is<Value>())
        // Make sure we use the value that has been already converted to LLVM.
        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
      else
        positionVec.push_back(pos);
    }

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (isa<VectorType>(sourceType)) {
      // llvm.insertvalue requires compile-time-constant indices.
      if (insertOp.hasDynamicPosition())
        return failure();

      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto oneDVectorType = destVectorType;
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, getAsIntegers(position.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));

    // Potential insertion of resulting 1-D vector into array.
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          getAsIntegers(position.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // 1-D (and 0-D) FMAs are handled directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    // Accumulate the per-slice FMA results into a zero-splatted vector.
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  // The innermost stride must be 1 for the layout to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Contiguity: each stride equals the product of the next stride and the
    // next dimension size.
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Lowers `vector.type_cast` on statically-shaped, contiguous memrefs by
/// building a new memref descriptor that reuses the source's allocated and
/// aligned pointers with the target's sizes/strides and a zero offset.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    // Typed-pointer mode needs an explicit bitcast to the target element type.
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
1484 class VectorCreateMaskOpRewritePattern 1485 : public OpRewritePattern<vector::CreateMaskOp> { 1486 public: 1487 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1488 bool enableIndexOpt) 1489 : OpRewritePattern<vector::CreateMaskOp>(context), 1490 force32BitVectorIndices(enableIndexOpt) {} 1491 1492 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1493 PatternRewriter &rewriter) const override { 1494 auto dstType = op.getType(); 1495 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1496 return failure(); 1497 IntegerType idxType = 1498 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1499 auto loc = op->getLoc(); 1500 Value indices = rewriter.create<LLVM::StepVectorOp>( 1501 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1502 /*isScalable=*/true)); 1503 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1504 op.getOperand(0)); 1505 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1506 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1507 indices, bounds); 1508 rewriter.replaceOp(op, comp); 1509 return success(); 1510 } 1511 1512 private: 1513 const bool force32BitVectorIndices; 1514 }; 1515 1516 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1517 public: 1518 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1519 1520 // Lowering implementation that relies on a small runtime support library, 1521 // which only needs to provide a few printing methods (single value for all 1522 // data types, opening/closing bracket, comma, newline). The lowering splits 1523 // the vector into elementary printing operations. The advantage of this 1524 // approach is that the library can remain unaware of all low-level 1525 // implementation details of vectors while still supporting output of any 1526 // shaped and dimensioned vector. 
1527 // 1528 // Note: This lowering only handles scalars, n-D vectors are broken into 1529 // printing scalars in loops in VectorToSCF. 1530 // 1531 // TODO: rely solely on libc in future? something else? 1532 // 1533 LogicalResult 1534 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1535 ConversionPatternRewriter &rewriter) const override { 1536 auto parent = printOp->getParentOfType<ModuleOp>(); 1537 if (!parent) 1538 return failure(); 1539 1540 auto loc = printOp->getLoc(); 1541 1542 if (auto value = adaptor.getSource()) { 1543 Type printType = printOp.getPrintType(); 1544 if (isa<VectorType>(printType)) { 1545 // Vectors should be broken into elementary print ops in VectorToSCF. 1546 return failure(); 1547 } 1548 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value))) 1549 return failure(); 1550 } 1551 1552 auto punct = printOp.getPunctuation(); 1553 if (punct != PrintPunctuation::NoPunctuation) { 1554 emitCall(rewriter, printOp->getLoc(), [&] { 1555 switch (punct) { 1556 case PrintPunctuation::Close: 1557 return LLVM::lookupOrCreatePrintCloseFn(parent); 1558 case PrintPunctuation::Open: 1559 return LLVM::lookupOrCreatePrintOpenFn(parent); 1560 case PrintPunctuation::Comma: 1561 return LLVM::lookupOrCreatePrintCommaFn(parent); 1562 case PrintPunctuation::NewLine: 1563 return LLVM::lookupOrCreatePrintNewlineFn(parent); 1564 default: 1565 llvm_unreachable("unexpected punctuation"); 1566 } 1567 }()); 1568 } 1569 1570 rewriter.eraseOp(printOp); 1571 return success(); 1572 } 1573 1574 private: 1575 enum class PrintConversion { 1576 // clang-format off 1577 None, 1578 ZeroExt64, 1579 SignExt64, 1580 Bitcast16 1581 // clang-format on 1582 }; 1583 1584 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter, 1585 ModuleOp parent, Location loc, Type printType, 1586 Value value) const { 1587 if (typeConverter->convertType(printType) == nullptr) 1588 return failure(); 1589 1590 // Make sure element type has runtime support. 
1591 PrintConversion conversion = PrintConversion::None; 1592 Operation *printer; 1593 if (printType.isF32()) { 1594 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1595 } else if (printType.isF64()) { 1596 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1597 } else if (printType.isF16()) { 1598 conversion = PrintConversion::Bitcast16; // bits! 1599 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1600 } else if (printType.isBF16()) { 1601 conversion = PrintConversion::Bitcast16; // bits! 1602 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1603 } else if (printType.isIndex()) { 1604 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1605 } else if (auto intTy = dyn_cast<IntegerType>(printType)) { 1606 // Integers need a zero or sign extension on the operand 1607 // (depending on the source type) as well as a signed or 1608 // unsigned print method. Up to 64-bit is supported. 1609 unsigned width = intTy.getWidth(); 1610 if (intTy.isUnsigned()) { 1611 if (width <= 64) { 1612 if (width < 64) 1613 conversion = PrintConversion::ZeroExt64; 1614 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1615 } else { 1616 return failure(); 1617 } 1618 } else { 1619 assert(intTy.isSignless() || intTy.isSigned()); 1620 if (width <= 64) { 1621 // Note that we *always* zero extend booleans (1-bit integers), 1622 // so that true/false is printed as 1/0 rather than -1/0. 
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          // Wider-than-64-bit signed integers are not supported.
          return failure();
        }
      }
    } else {
      // Unsupported element type for printing.
      return failure();
    }

    // Emit the selected conversion (if any) on the scalar operand, then call
    // the runtime print function.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      // Reinterpret the 16-bit float as its raw bits for printing.
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  // Helper to emit a zero-result call to the print function `ref`.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    // Rank > 1 splats are handled by VectorSplatNdOpLowering below.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    // An all-zeros shuffle mask broadcasts element 0 across the vector.
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // The 0-d and 1-d cases are handled by VectorSplatOpLowering above.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index space, convert each linear index to an
    // n-D coordinate, and insert the splatted 1-D vector at each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
1757 void mlir::populateVectorToLLVMConversionPatterns( 1758 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1759 bool reassociateFPReductions, bool force32BitVectorIndices) { 1760 MLIRContext *ctx = converter.getDialect()->getContext(); 1761 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1762 populateVectorInsertExtractStridedSliceTransforms(patterns); 1763 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1764 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1765 patterns 1766 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1767 VectorExtractElementOpConversion, VectorExtractOpConversion, 1768 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1769 VectorInsertOpConversion, VectorPrintOpConversion, 1770 VectorTypeCastOpConversion, VectorScaleOpConversion, 1771 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1772 VectorLoadStoreConversion<vector::MaskedLoadOp, 1773 vector::MaskedLoadOpAdaptor>, 1774 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1775 VectorLoadStoreConversion<vector::MaskedStoreOp, 1776 vector::MaskedStoreOpAdaptor>, 1777 VectorGatherOpConversion, VectorScatterOpConversion, 1778 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1779 VectorSplatOpLowering, VectorSplatNdOpLowering, 1780 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1781 MaskedReductionOpConversion>(converter); 1782 // Transfer ops with rank > 1 are handled by VectorToSCF. 1783 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1784 } 1785 1786 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1787 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1788 patterns.add<VectorMatmulOpConversion>(converter); 1789 patterns.add<VectorFlatTransposeOpConversion>(converter); 1790 } 1791