//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by *all* but one rank at back.
// E.g. a rank-3 vector keeps only its innermost dimension (take_back()
// defaults to one element); the matching scalable-dim flag is preserved too.
static VectorType reducedVectorTypeBack(VectorType tp) {
  // A rank-0/1 vector has no outer dims to strip.
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         tp.getScalableDims().take_back());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       const LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  // Inserts `val2` into `val1` at position `pos`. A rank-1 value is a real
  // LLVM vector, so use llvm.insertelement with a constant index; higher
  // ranks are lowered to LLVM aggregates (arrays of vectors), which take
  // llvm.insertvalue instead.
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    // llvm.insertelement needs the index as an SSA constant of the
    // converted index type.
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  // Unlike insertOne, rank 0 is accepted here: the `<=` also routes the 0-D
  // case through llvm.extractelement.
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

// Helper that returns data layout alignment of a memref.
// On success, `align` receives the preferred alignment of the converted
// element type; fails if the element type cannot be converted.
// NOTE(review): this file-local helper has external linkage while its
// siblings above are `static` — consider marking it `static` too.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Succeeds when the memref's last dimension has unit stride and its memory
// space can be converted to an LLVM address space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           const LLVMTypeConverter &converter) {
  // Element-wise address computation below is only valid when the innermost
  // dimension is contiguous (unit stride).
  if (!isLastMemrefDimUnitStride(memRefType))
    return failure();
  // The memref's memory space must map onto an LLVM address space.
  if (failed(converter.getMemRefAddressSpace(memRefType)))
    return failure();
  return success();
}

// Add an index vector component to a base pointer.
// Returns a vector of `vLen` pointers: one GEP per lane, base[index[i]],
// typed with the element pointer type of the memref descriptor (so the
// address space is preserved).
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            const LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, uint64_t vLen) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
  return rewriter.create<LLVM::GEPOp>(
      loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
      base, index);
}

// Casts a strided element pointer to a vector pointer.  The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt,
                         const LLVMTypeConverter &converter) {
  // With opaque pointers there is nothing to cast.
  if (converter.useOpaquePointers())
    return ptr;

  // Typed-pointer mode: bitcast the element pointer to pointer-to-vector in
  // the memref's address space.
  unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
  auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

/// Convert `foldResult` into a Value. Integer attribute is converted to
/// an LLVM constant op.
129 static Value getAsLLVMValue(OpBuilder &builder, Location loc, 130 OpFoldResult foldResult) { 131 if (auto attr = foldResult.dyn_cast<Attribute>()) { 132 auto intAttr = cast<IntegerAttr>(attr); 133 return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult(); 134 } 135 136 return foldResult.get<Value>(); 137 } 138 139 namespace { 140 141 /// Trivial Vector to LLVM conversions 142 using VectorScaleOpConversion = 143 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 144 145 /// Conversion pattern for a vector.bitcast. 146 class VectorBitCastOpConversion 147 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 148 public: 149 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 150 151 LogicalResult 152 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 153 ConversionPatternRewriter &rewriter) const override { 154 // Only 0-D and 1-D vectors can be lowered to LLVM. 155 VectorType resultTy = bitCastOp.getResultVectorType(); 156 if (resultTy.getRank() > 1) 157 return failure(); 158 Type newResultTy = typeConverter->convertType(resultTy); 159 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 160 adaptor.getOperands()[0]); 161 return success(); 162 } 163 }; 164 165 /// Conversion pattern for a vector.matrix_multiply. 166 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
167 class VectorMatmulOpConversion 168 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 169 public: 170 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 171 172 LogicalResult 173 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 174 ConversionPatternRewriter &rewriter) const override { 175 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 176 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 177 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 178 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 179 return success(); 180 } 181 }; 182 183 /// Conversion pattern for a vector.flat_transpose. 184 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 185 class VectorFlatTransposeOpConversion 186 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 187 public: 188 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 189 190 LogicalResult 191 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 192 ConversionPatternRewriter &rewriter) const override { 193 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 194 transOp, typeConverter->convertType(transOp.getRes().getType()), 195 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 196 return success(); 197 } 198 }; 199 200 /// Overloaded utility that replaces a vector.load, vector.store, 201 /// vector.maskedload and vector.maskedstore with their respective LLVM 202 /// couterparts. 
// Plain load: becomes an aligned llvm.load of the vector type.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

// Masked load: disabled lanes take their value from the pass-through vector.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

// Plain store: becomes an aligned llvm.store.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

// Masked store: disabled lanes leave memory untouched.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore. The common address/alignment resolution lives here;
/// the op-specific rewrite is dispatched through the overloads above.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: element pointer at the access indices, then cast to a
    // vector pointer (a no-op under opaque pointers).
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Base element pointer at the (scalar) access indices; per-lane pointers
    // are derived from it and the index vector below.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // N-D case: unroll into 1-D gathers; the callback is invoked once per
    // innermost 1-D slice with the matching slices of index/mask/pass-through.
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Unit innermost stride and a convertible memory space are required.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: base element pointer, then one pointer per lane via
    // the index vector.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    // llvm.masked.expandload reads consecutive elements from `ptr` into the
    // enabled lanes; disabled lanes take the pass-through value.
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    // llvm.masked.compressstore packs the enabled lanes and writes them to
    // consecutive locations starting at `ptr`.
    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading. Empty tag types that select,
/// via overload resolution, which neutral (identity) element seeds a
/// reduction that has no explicit accumulator.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value.
435 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 436 ConversionPatternRewriter &rewriter, 437 Location loc, Type llvmType) { 438 return rewriter.create<LLVM::ConstantOp>( 439 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 440 } 441 442 /// Create the reduction neutral fp one value. 443 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 444 ConversionPatternRewriter &rewriter, 445 Location loc, Type llvmType) { 446 return rewriter.create<LLVM::ConstantOp>( 447 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 448 } 449 450 /// Create the reduction neutral all-ones value. 451 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 452 ConversionPatternRewriter &rewriter, 453 Location loc, Type llvmType) { 454 return rewriter.create<LLVM::ConstantOp>( 455 loc, llvmType, 456 rewriter.getIntegerAttr( 457 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 458 } 459 460 /// Create the reduction neutral signed int minimum value. 461 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 462 ConversionPatternRewriter &rewriter, 463 Location loc, Type llvmType) { 464 return rewriter.create<LLVM::ConstantOp>( 465 loc, llvmType, 466 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 467 llvmType.getIntOrFloatBitWidth()))); 468 } 469 470 /// Create the reduction neutral unsigned int minimum value. 471 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 472 ConversionPatternRewriter &rewriter, 473 Location loc, Type llvmType) { 474 return rewriter.create<LLVM::ConstantOp>( 475 loc, llvmType, 476 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 477 llvmType.getIntOrFloatBitWidth()))); 478 } 479 480 /// Create the reduction neutral signed int maximum value. 
481 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 482 ConversionPatternRewriter &rewriter, 483 Location loc, Type llvmType) { 484 return rewriter.create<LLVM::ConstantOp>( 485 loc, llvmType, 486 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 487 llvmType.getIntOrFloatBitWidth()))); 488 } 489 490 /// Create the reduction neutral unsigned int maximum value. 491 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 492 ConversionPatternRewriter &rewriter, 493 Location loc, Type llvmType) { 494 return rewriter.create<LLVM::ConstantOp>( 495 loc, llvmType, 496 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 497 llvmType.getIntOrFloatBitWidth()))); 498 } 499 500 /// Create the reduction neutral fp minimum value. 501 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 502 ConversionPatternRewriter &rewriter, 503 Location loc, Type llvmType) { 504 auto floatType = cast<FloatType>(llvmType); 505 return rewriter.create<LLVM::ConstantOp>( 506 loc, llvmType, 507 rewriter.getFloatAttr( 508 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 509 /*Negative=*/false))); 510 } 511 512 /// Create the reduction neutral fp maximum value. 513 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 514 ConversionPatternRewriter &rewriter, 515 Location loc, Type llvmType) { 516 auto floatType = cast<FloatType>(llvmType); 517 return rewriter.create<LLVM::ConstantOp>( 518 loc, llvmType, 519 rewriter.getFloatAttr( 520 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 521 /*Negative=*/true))); 522 } 523 524 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 525 /// returns a new accumulator value using `ReductionNeutral`. 
526 template <class ReductionNeutral> 527 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 528 Location loc, Type llvmType, 529 Value accumulator) { 530 if (accumulator) 531 return accumulator; 532 533 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 534 llvmType); 535 } 536 537 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 538 /// This is used as effective vector length by some intrinsics supporting 539 /// dynamic vector lengths at runtime. 540 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 541 Location loc, Type llvmType) { 542 VectorType vType = cast<VectorType>(llvmType); 543 auto vShape = vType.getShape(); 544 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 545 546 return rewriter.create<LLVM::ConstantOp>( 547 loc, rewriter.getI32Type(), 548 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 549 } 550 551 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 552 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 553 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 554 /// non-null. 555 template <class LLVMRedIntrinOp, class ScalarOp> 556 static Value createIntegerReductionArithmeticOpLowering( 557 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 558 Value vectorOperand, Value accumulator) { 559 560 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 561 562 if (accumulator) 563 result = rewriter.create<ScalarOp>(loc, accumulator, result); 564 return result; 565 } 566 567 /// Helper method to lower a `vector.reduction` operation that performs 568 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 569 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 570 /// the accumulator value if non-null. 
571 template <class LLVMRedIntrinOp> 572 static Value createIntegerReductionComparisonOpLowering( 573 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 574 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 575 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 576 if (accumulator) { 577 Value cmp = 578 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 579 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 580 } 581 return result; 582 } 583 584 namespace { 585 template <typename Source> 586 struct VectorToScalarMapper; 587 template <> 588 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { 589 using Type = LLVM::MaximumOp; 590 }; 591 template <> 592 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { 593 using Type = LLVM::MinimumOp; 594 }; 595 template <> 596 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> { 597 using Type = LLVM::MaxNumOp; 598 }; 599 template <> 600 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> { 601 using Type = LLVM::MinNumOp; 602 }; 603 } // namespace 604 605 template <class LLVMRedIntrinOp> 606 static Value createFPReductionComparisonOpLowering( 607 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 608 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { 609 Value result = 610 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf); 611 612 if (accumulator) { 613 result = 614 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( 615 loc, result, accumulator); 616 } 617 618 return result; 619 } 620 621 /// Reduction neutral classes for overloading 622 class MaskNeutralFMaximum {}; 623 class MaskNeutralFMinimum {}; 624 625 /// Get the mask neutral floating point maximum value 626 static llvm::APFloat 627 getMaskNeutralValue(MaskNeutralFMaximum, 628 const llvm::fltSemantics &floatSemantics) { 629 return llvm::APFloat::getSmallest(floatSemantics, 
/*Negative=*/true); 630 } 631 /// Get the mask neutral floating point minimum value 632 static llvm::APFloat 633 getMaskNeutralValue(MaskNeutralFMinimum, 634 const llvm::fltSemantics &floatSemantics) { 635 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false); 636 } 637 638 /// Create the mask neutral floating point MLIR vector constant 639 template <typename MaskNeutral> 640 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, 641 Location loc, Type llvmType, 642 Type vectorType) { 643 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); 644 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); 645 auto denseValue = 646 DenseElementsAttr::get(vectorType.cast<ShapedType>(), value); 647 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); 648 } 649 650 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked 651 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for 652 /// `fmaximum`/`fminimum`. 
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
                                Location loc, Type llvmType,
                                Value vectorOperand, Value accumulator,
                                Value mask, LLVM::FastmathFlagsAttr fmf) {
  // Disabled lanes are replaced with the mask-neutral element so they cannot
  // influence the result of the (unmasked) reduction intrinsic.
  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
      rewriter, loc, llvmType, vectorOperand.getType());
  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
      loc, mask, vectorOperand, vectorMaskNeutral);
  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
      rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}

/// Lowers a reduction to an LLVM intrinsic that takes a start value, seeding
/// it with `accumulator` or, when absent, with `ReductionNeutral`.
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
                                          /*startValue=*/accumulator,
                                          vectorOperand, fmf);
}

/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  // Unmasked form: only the start value needs to be materialized.
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

/// Masked form: lowers to a VP (vector-predicated) intrinsic, which takes the
/// mask plus an explicit vector length spanning all lanes.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

/// Dispatcher that picks the integer or floating-point VP intrinsic (and the
/// matching neutral element) based on the element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  // `reassociateFPRed` allows FP reductions to be lowered with the `reassoc`
  // fastmath flag (faster tree reductions, relaxed FP semantics).
  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      // Min/max kinds combine the accumulator via compare+select with the
      // matching predicate.
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds are FP-only and invalid for integer elements.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Map the arith fastmath flags onto the LLVM ones, then optionally add
    // `reassoc` when FP reduction reassociation was requested.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions
                              ? LLVM::FastmathFlags::reassoc
                              : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      // NaN-propagating minimum.
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      // NaN-propagating maximum.
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes.
This method can match 858 /// against specific traits of the `vector.mask` and the maskable operation. It 859 /// must replace the `vector.mask` operation. 860 template <class MaskedOp> 861 class VectorMaskOpConversionBase 862 : public ConvertOpToLLVMPattern<vector::MaskOp> { 863 public: 864 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern; 865 866 LogicalResult 867 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, 868 ConversionPatternRewriter &rewriter) const final { 869 // Match against the maskable operation kind. 870 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp()); 871 if (!maskedOp) 872 return failure(); 873 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter); 874 } 875 876 protected: 877 virtual LogicalResult 878 matchAndRewriteMaskableOp(vector::MaskOp maskOp, 879 vector::MaskableOpInterface maskableOp, 880 ConversionPatternRewriter &rewriter) const = 0; 881 }; 882 883 class MaskedReductionOpConversion 884 : public VectorMaskOpConversionBase<vector::ReductionOp> { 885 886 public: 887 using VectorMaskOpConversionBase< 888 vector::ReductionOp>::VectorMaskOpConversionBase; 889 890 LogicalResult matchAndRewriteMaskableOp( 891 vector::MaskOp maskOp, MaskableOpInterface maskableOp, 892 ConversionPatternRewriter &rewriter) const override { 893 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation()); 894 auto kind = reductionOp.getKind(); 895 Type eltType = reductionOp.getDest().getType(); 896 Type llvmType = typeConverter->convertType(eltType); 897 Value operand = reductionOp.getVector(); 898 Value acc = reductionOp.getAcc(); 899 Location loc = reductionOp.getLoc(); 900 901 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); 902 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 903 reductionOp.getContext(), 904 convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); 905 906 Value result; 907 switch (kind) { 908 case vector::CombiningKind::ADD: 909 result = 
lowerPredicatedReductionWithStartValue< 910 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, 911 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, 912 maskOp.getMask()); 913 break; 914 case vector::CombiningKind::MUL: 915 result = lowerPredicatedReductionWithStartValue< 916 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, 917 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, 918 maskOp.getMask()); 919 break; 920 case vector::CombiningKind::MINUI: 921 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp, 922 ReductionNeutralUIntMax>( 923 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 924 break; 925 case vector::CombiningKind::MINSI: 926 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp, 927 ReductionNeutralSIntMax>( 928 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 929 break; 930 case vector::CombiningKind::MAXUI: 931 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp, 932 ReductionNeutralUIntMin>( 933 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 934 break; 935 case vector::CombiningKind::MAXSI: 936 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp, 937 ReductionNeutralSIntMin>( 938 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 939 break; 940 case vector::CombiningKind::AND: 941 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp, 942 ReductionNeutralAllOnes>( 943 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 944 break; 945 case vector::CombiningKind::OR: 946 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp, 947 ReductionNeutralZero>( 948 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 949 break; 950 case vector::CombiningKind::XOR: 951 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp, 952 ReductionNeutralZero>( 953 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 954 break; 955 case 
vector::CombiningKind::MINF: 956 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp, 957 ReductionNeutralFPMax>( 958 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 959 break; 960 case vector::CombiningKind::MAXF: 961 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp, 962 ReductionNeutralFPMin>( 963 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 964 break; 965 case CombiningKind::MAXIMUMF: 966 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum, 967 MaskNeutralFMaximum>( 968 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 969 break; 970 case CombiningKind::MINIMUMF: 971 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum, 972 MaskNeutralFMinimum>( 973 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 974 break; 975 } 976 977 // Replace `vector.mask` operation altogether. 978 rewriter.replaceOp(maskOp, result); 979 return success(); 980 } 981 }; 982 983 class VectorShuffleOpConversion 984 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 985 public: 986 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 987 988 LogicalResult 989 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 990 ConversionPatternRewriter &rewriter) const override { 991 auto loc = shuffleOp->getLoc(); 992 auto v1Type = shuffleOp.getV1VectorType(); 993 auto v2Type = shuffleOp.getV2VectorType(); 994 auto vectorType = shuffleOp.getResultVectorType(); 995 Type llvmType = typeConverter->convertType(vectorType); 996 auto maskArrayAttr = shuffleOp.getMask(); 997 998 // Bail if result type cannot be lowered. 999 if (!llvmType) 1000 return failure(); 1001 1002 // Get rank and dimension sizes. 
1003 int64_t rank = vectorType.getRank(); 1004 #ifndef NDEBUG 1005 bool wellFormed0DCase = 1006 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1; 1007 bool wellFormedNDCase = 1008 v1Type.getRank() == rank && v2Type.getRank() == rank; 1009 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed"); 1010 #endif 1011 1012 // For rank 0 and 1, where both operands have *exactly* the same vector 1013 // type, there is direct shuffle support in LLVM. Use it! 1014 if (rank <= 1 && v1Type == v2Type) { 1015 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 1016 loc, adaptor.getV1(), adaptor.getV2(), 1017 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 1018 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 1019 return success(); 1020 } 1021 1022 // For all other cases, insert the individual values individually. 1023 int64_t v1Dim = v1Type.getDimSize(0); 1024 Type eltType; 1025 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType)) 1026 eltType = arrayType.getElementType(); 1027 else 1028 eltType = cast<VectorType>(llvmType).getElementType(); 1029 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 1030 int64_t insPos = 0; 1031 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 1032 int64_t extPos = cast<IntegerAttr>(en.value()).getInt(); 1033 Value value = adaptor.getV1(); 1034 if (extPos >= v1Dim) { 1035 extPos -= v1Dim; 1036 value = adaptor.getV2(); 1037 } 1038 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 1039 eltType, rank, extPos); 1040 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 1041 llvmType, rank, insPos++); 1042 } 1043 rewriter.replaceOp(shuffleOp, insert); 1044 return success(); 1045 } 1046 }; 1047 1048 class VectorExtractElementOpConversion 1049 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 1050 public: 1051 using ConvertOpToLLVMPattern< 1052 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 1053 1054 LogicalResult 1055 
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 1056 ConversionPatternRewriter &rewriter) const override { 1057 auto vectorType = extractEltOp.getSourceVectorType(); 1058 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 1059 1060 // Bail if result type cannot be lowered. 1061 if (!llvmType) 1062 return failure(); 1063 1064 if (vectorType.getRank() == 0) { 1065 Location loc = extractEltOp.getLoc(); 1066 auto idxType = rewriter.getIndexType(); 1067 auto zero = rewriter.create<LLVM::ConstantOp>( 1068 loc, typeConverter->convertType(idxType), 1069 rewriter.getIntegerAttr(idxType, 0)); 1070 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1071 extractEltOp, llvmType, adaptor.getVector(), zero); 1072 return success(); 1073 } 1074 1075 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1076 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 1077 return success(); 1078 } 1079 }; 1080 1081 class VectorExtractOpConversion 1082 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 1083 public: 1084 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 1085 1086 LogicalResult 1087 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 1088 ConversionPatternRewriter &rewriter) const override { 1089 auto loc = extractOp->getLoc(); 1090 auto resultType = extractOp.getResult().getType(); 1091 auto llvmResultType = typeConverter->convertType(resultType); 1092 // Bail if result type cannot be lowered. 1093 if (!llvmResultType) 1094 return failure(); 1095 1096 SmallVector<OpFoldResult> positionVec; 1097 for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) { 1098 if (pos.is<Value>()) 1099 // Make sure we use the value that has been already converted to LLVM. 1100 positionVec.push_back(adaptor.getDynamicPosition()[idx]); 1101 else 1102 positionVec.push_back(pos); 1103 } 1104 1105 // Extract entire vector. Should be handled by folder, but just to be safe. 
1106 ArrayRef<OpFoldResult> position(positionVec); 1107 if (position.empty()) { 1108 rewriter.replaceOp(extractOp, adaptor.getVector()); 1109 return success(); 1110 } 1111 1112 // One-shot extraction of vector from array (only requires extractvalue). 1113 if (isa<VectorType>(resultType)) { 1114 if (extractOp.hasDynamicPosition()) 1115 return failure(); 1116 1117 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 1118 loc, adaptor.getVector(), getAsIntegers(position)); 1119 rewriter.replaceOp(extractOp, extracted); 1120 return success(); 1121 } 1122 1123 // Potential extraction of 1-D vector from array. 1124 Value extracted = adaptor.getVector(); 1125 if (position.size() > 1) { 1126 if (extractOp.hasDynamicPosition()) 1127 return failure(); 1128 1129 SmallVector<int64_t> nMinusOnePosition = 1130 getAsIntegers(position.drop_back()); 1131 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted, 1132 nMinusOnePosition); 1133 } 1134 1135 Value lastPosition = getAsLLVMValue(rewriter, loc, position.back()); 1136 // Remaining extraction of element from 1-D LLVM vector. 1137 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted, 1138 lastPosition); 1139 return success(); 1140 } 1141 }; 1142 1143 /// Conversion pattern that turns a vector.fma on a 1-D vector 1144 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 1145 /// This does not match vectors of n >= 2 rank. 
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Rank > 1 is handled by VectorFMAOpNDRewritePattern (rank reduction).
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Lowers `vector.insertelement` to `llvm.insertelement`. A 0-D destination
/// is treated as a 1-element vector indexed with a constant zero.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D case: there is no position operand; index at 0.
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.insert`. A vector source becomes an `llvm.insertvalue`
/// into the array-of-vectors representation (static positions only); a scalar
/// source is inserted into the innermost 1-D vector with
/// `llvm.insertelement`, extracting/re-inserting that vector through
/// `extractvalue`/`insertvalue` when the destination rank is > 1.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    SmallVector<OpFoldResult> positionVec;
    for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
      if (pos.is<Value>())
        // Make sure we use the value that has been already converted to LLVM.
        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
      else
        positionVec.push_back(pos);
    }

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (isa<VectorType>(sourceType)) {
      // llvm.insertvalue requires compile-time constant indices.
      if (insertOp.hasDynamicPosition())
        return failure();

      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto oneDVectorType = destVectorType;
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, getAsIntegers(position.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));

    // Potential insertion of resulting 1-D vector into array.
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          getAsIntegers(position.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // Rank 0/1 FMAs are lowered directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    // Seed the result with a zero splat, then overwrite every outer slice
    // with the rank-reduced FMA of the corresponding operand slices.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  // The innermost dimension must be unit-stride for contiguity.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity form shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Contiguity invariant: stride[i] == stride[i+1] * size[i+1].
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Lowers `vector.type_cast` by rebuilding the target memref descriptor:
/// the source's allocated/aligned pointers are reused (bitcast when opaque
/// pointers are disabled), offset is set to 0, and the static sizes/strides
/// of the target type are materialized as constants. Only static shapes and
/// contiguous source buffers are supported.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
1482 class VectorCreateMaskOpRewritePattern 1483 : public OpRewritePattern<vector::CreateMaskOp> { 1484 public: 1485 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1486 bool enableIndexOpt) 1487 : OpRewritePattern<vector::CreateMaskOp>(context), 1488 force32BitVectorIndices(enableIndexOpt) {} 1489 1490 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1491 PatternRewriter &rewriter) const override { 1492 auto dstType = op.getType(); 1493 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1494 return failure(); 1495 IntegerType idxType = 1496 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1497 auto loc = op->getLoc(); 1498 Value indices = rewriter.create<LLVM::StepVectorOp>( 1499 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1500 /*isScalable=*/true)); 1501 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1502 op.getOperand(0)); 1503 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1504 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1505 indices, bounds); 1506 rewriter.replaceOp(op, comp); 1507 return success(); 1508 } 1509 1510 private: 1511 const bool force32BitVectorIndices; 1512 }; 1513 1514 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1515 public: 1516 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1517 1518 // Lowering implementation that relies on a small runtime support library, 1519 // which only needs to provide a few printing methods (single value for all 1520 // data types, opening/closing bracket, comma, newline). The lowering splits 1521 // the vector into elementary printing operations. The advantage of this 1522 // approach is that the library can remain unaware of all low-level 1523 // implementation details of vectors while still supporting output of any 1524 // shaped and dimensioned vector. 
1525 // 1526 // Note: This lowering only handles scalars, n-D vectors are broken into 1527 // printing scalars in loops in VectorToSCF. 1528 // 1529 // TODO: rely solely on libc in future? something else? 1530 // 1531 LogicalResult 1532 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1533 ConversionPatternRewriter &rewriter) const override { 1534 auto parent = printOp->getParentOfType<ModuleOp>(); 1535 if (!parent) 1536 return failure(); 1537 1538 auto loc = printOp->getLoc(); 1539 1540 if (auto value = adaptor.getSource()) { 1541 Type printType = printOp.getPrintType(); 1542 if (isa<VectorType>(printType)) { 1543 // Vectors should be broken into elementary print ops in VectorToSCF. 1544 return failure(); 1545 } 1546 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value))) 1547 return failure(); 1548 } 1549 1550 auto punct = printOp.getPunctuation(); 1551 if (punct != PrintPunctuation::NoPunctuation) { 1552 emitCall(rewriter, printOp->getLoc(), [&] { 1553 switch (punct) { 1554 case PrintPunctuation::Close: 1555 return LLVM::lookupOrCreatePrintCloseFn(parent); 1556 case PrintPunctuation::Open: 1557 return LLVM::lookupOrCreatePrintOpenFn(parent); 1558 case PrintPunctuation::Comma: 1559 return LLVM::lookupOrCreatePrintCommaFn(parent); 1560 case PrintPunctuation::NewLine: 1561 return LLVM::lookupOrCreatePrintNewlineFn(parent); 1562 default: 1563 llvm_unreachable("unexpected punctuation"); 1564 } 1565 }()); 1566 } 1567 1568 rewriter.eraseOp(printOp); 1569 return success(); 1570 } 1571 1572 private: 1573 enum class PrintConversion { 1574 // clang-format off 1575 None, 1576 ZeroExt64, 1577 SignExt64, 1578 Bitcast16 1579 // clang-format on 1580 }; 1581 1582 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter, 1583 ModuleOp parent, Location loc, Type printType, 1584 Value value) const { 1585 if (typeConverter->convertType(printType) == nullptr) 1586 return failure(); 1587 1588 // Make sure element type has runtime support. 
1589 PrintConversion conversion = PrintConversion::None; 1590 Operation *printer; 1591 if (printType.isF32()) { 1592 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1593 } else if (printType.isF64()) { 1594 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1595 } else if (printType.isF16()) { 1596 conversion = PrintConversion::Bitcast16; // bits! 1597 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1598 } else if (printType.isBF16()) { 1599 conversion = PrintConversion::Bitcast16; // bits! 1600 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1601 } else if (printType.isIndex()) { 1602 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1603 } else if (auto intTy = dyn_cast<IntegerType>(printType)) { 1604 // Integers need a zero or sign extension on the operand 1605 // (depending on the source type) as well as a signed or 1606 // unsigned print method. Up to 64-bit is supported. 1607 unsigned width = intTy.getWidth(); 1608 if (intTy.isUnsigned()) { 1609 if (width <= 64) { 1610 if (width < 64) 1611 conversion = PrintConversion::ZeroExt64; 1612 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1613 } else { 1614 return failure(); 1615 } 1616 } else { 1617 assert(intTy.isSignless() || intTy.isSigned()); 1618 if (width <= 64) { 1619 // Note that we *always* zero extend booleans (1-bit integers), 1620 // so that true/false is printed as 1/0 rather than -1/0. 
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          // Integers wider than 64 bits are not supported.
          return failure();
        }
      }
    } else {
      // Unsupported element type: no print routine available.
      return failure();
    }

    // Apply the operand conversion selected above, then call the printer.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      // f16/bf16: reinterpret the 16-bit float as an i16 bit pattern.
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  // Helper to emit a call to the runtime function `ref` with `params`,
  // referenced by symbol (the call returns no results).
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    // Ranks > 1 are handled by VectorSplatNdOpLowering below.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    // Position 0, as an LLVM-converted i32 constant.
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    // All-zero shuffle mask: broadcast lane 0 across `width` lanes.
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // Ranks <= 1 are handled by VectorSplatOpLowering above.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    // N-D vectors lower to an LLVM array-of-vectors descriptor; bail out if
    // either the N-D or the innermost 1-D type fails to convert.
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index space, convert to coordinate space, and
    // insert the splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
1755 void mlir::populateVectorToLLVMConversionPatterns( 1756 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1757 bool reassociateFPReductions, bool force32BitVectorIndices) { 1758 MLIRContext *ctx = converter.getDialect()->getContext(); 1759 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1760 populateVectorInsertExtractStridedSliceTransforms(patterns); 1761 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1762 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1763 patterns 1764 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1765 VectorExtractElementOpConversion, VectorExtractOpConversion, 1766 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1767 VectorInsertOpConversion, VectorPrintOpConversion, 1768 VectorTypeCastOpConversion, VectorScaleOpConversion, 1769 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1770 VectorLoadStoreConversion<vector::MaskedLoadOp, 1771 vector::MaskedLoadOpAdaptor>, 1772 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1773 VectorLoadStoreConversion<vector::MaskedStoreOp, 1774 vector::MaskedStoreOpAdaptor>, 1775 VectorGatherOpConversion, VectorScatterOpConversion, 1776 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1777 VectorSplatOpLowering, VectorSplatNdOpLowering, 1778 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1779 MaskedReductionOpConversion>(converter); 1780 // Transfer ops with rank > 1 are handled by VectorToSCF. 1781 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1782 } 1783 1784 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1785 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1786 patterns.add<VectorMatmulOpConversion>(converter); 1787 patterns.add<VectorFlatTransposeOpConversion>(converter); 1788 } 1789