//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by *all* but one rank at back.
// Keeps only the trailing dimension (including its scalability flag), e.g.
// a rank-3 vector type becomes the rank-1 vector type of its innermost dim.
// Asserts for rank <= 1, where there is nothing left to strip.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         tp.getScalableDims().take_back());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       const LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    // Rank-1 vectors lower to LLVM vectors: use insertelement with the
    // position materialized as an index-typed constant.
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  // Higher ranks lower to LLVM aggregates (arrays of vectors): use
  // insertvalue, which takes the position as a static attribute.
  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
}

// Helper that picks the proper sequence for extracting.
// Note the `rank <= 1` test: unlike insertOne, this also accepts rank 0
// (0-D vectors), which are extracted with extractelement as well.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

// Helper that returns data layout alignment of a memref.
// On success, `align` holds the preferred alignment of the converted element
// type; fails when the element type cannot be converted to an LLVM type.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  // NOTE: this constructs a fresh llvm::LLVMContext on every call, which is
  // the cost accepted by the TODO above.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Check that the memref has a unit stride in its last dimension and an
// address space the type converter can handle; fails otherwise.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 93 const LLVMTypeConverter &converter) { 94 if (!isLastMemrefDimUnitStride(memRefType)) 95 return failure(); 96 if (failed(converter.getMemRefAddressSpace(memRefType))) 97 return failure(); 98 return success(); 99 } 100 101 // Add an index vector component to a base pointer. 102 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 103 const LLVMTypeConverter &typeConverter, 104 MemRefType memRefType, Value llvmMemref, Value base, 105 Value index, uint64_t vLen) { 106 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 107 "unsupported memref type"); 108 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 109 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 110 return rewriter.create<LLVM::GEPOp>( 111 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 112 base, index); 113 } 114 115 // Casts a strided element pointer to a vector pointer. The vector pointer 116 // will be in the same address space as the incoming memref type. 117 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 118 Value ptr, MemRefType memRefType, Type vt, 119 const LLVMTypeConverter &converter) { 120 if (converter.useOpaquePointers()) 121 return ptr; 122 123 unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); 124 auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); 125 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 126 } 127 128 /// Convert `foldResult` into a Value. Integer attribute is converted to 129 /// an LLVM constant op. 
130 static Value getAsLLVMValue(OpBuilder &builder, Location loc, 131 OpFoldResult foldResult) { 132 if (auto attr = foldResult.dyn_cast<Attribute>()) { 133 auto intAttr = cast<IntegerAttr>(attr); 134 return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult(); 135 } 136 137 return foldResult.get<Value>(); 138 } 139 140 namespace { 141 142 /// Trivial Vector to LLVM conversions 143 using VectorScaleOpConversion = 144 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 145 146 /// Conversion pattern for a vector.bitcast. 147 class VectorBitCastOpConversion 148 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 149 public: 150 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 151 152 LogicalResult 153 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 154 ConversionPatternRewriter &rewriter) const override { 155 // Only 0-D and 1-D vectors can be lowered to LLVM. 156 VectorType resultTy = bitCastOp.getResultVectorType(); 157 if (resultTy.getRank() > 1) 158 return failure(); 159 Type newResultTy = typeConverter->convertType(resultTy); 160 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 161 adaptor.getOperands()[0]); 162 return success(); 163 } 164 }; 165 166 /// Conversion pattern for a vector.matrix_multiply. 167 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // One-to-one mapping: shapes are carried over as the intrinsic's
    // rows/columns attributes.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
// vector.load -> llvm.load of the vector type.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

// vector.maskedload -> llvm.intr.masked.load (mask + pass-through carried
// over from the adaptor).
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

// vector.store -> llvm.store of the vector value.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

// vector.maskedstore -> llvm.intr.masked.store.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: strided element pointer, then (for typed pointers)
    // cast to a pointer-to-vector in the same address space.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    // Dispatch to the matching replaceLoadOrStoreOp overload above.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Base element pointer after applying the op's leading indices.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // N-D case: unroll into 1-D gathers, one per innermost vector.
    // Note: this local reference intentionally shadows the member pointer so
    // the callback can capture it by reference.
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: element pointer, then a vector of per-lane pointers.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
/// NOTE(review): unlike gather/scatter, no isMemRefTypeSupported or alignment
/// resolution happens here — the intrinsic takes a plain scalar base pointer;
/// confirm that the intrinsic's default alignment is acceptable.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading.
/// Empty tag types; each selects a createReductionNeutralValue overload below.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value.
436 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 437 ConversionPatternRewriter &rewriter, 438 Location loc, Type llvmType) { 439 return rewriter.create<LLVM::ConstantOp>( 440 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 441 } 442 443 /// Create the reduction neutral fp one value. 444 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 445 ConversionPatternRewriter &rewriter, 446 Location loc, Type llvmType) { 447 return rewriter.create<LLVM::ConstantOp>( 448 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 449 } 450 451 /// Create the reduction neutral all-ones value. 452 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 453 ConversionPatternRewriter &rewriter, 454 Location loc, Type llvmType) { 455 return rewriter.create<LLVM::ConstantOp>( 456 loc, llvmType, 457 rewriter.getIntegerAttr( 458 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 459 } 460 461 /// Create the reduction neutral signed int minimum value. 462 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 463 ConversionPatternRewriter &rewriter, 464 Location loc, Type llvmType) { 465 return rewriter.create<LLVM::ConstantOp>( 466 loc, llvmType, 467 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 468 llvmType.getIntOrFloatBitWidth()))); 469 } 470 471 /// Create the reduction neutral unsigned int minimum value. 472 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 473 ConversionPatternRewriter &rewriter, 474 Location loc, Type llvmType) { 475 return rewriter.create<LLVM::ConstantOp>( 476 loc, llvmType, 477 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 478 llvmType.getIntOrFloatBitWidth()))); 479 } 480 481 /// Create the reduction neutral signed int maximum value. 
482 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 483 ConversionPatternRewriter &rewriter, 484 Location loc, Type llvmType) { 485 return rewriter.create<LLVM::ConstantOp>( 486 loc, llvmType, 487 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 488 llvmType.getIntOrFloatBitWidth()))); 489 } 490 491 /// Create the reduction neutral unsigned int maximum value. 492 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 493 ConversionPatternRewriter &rewriter, 494 Location loc, Type llvmType) { 495 return rewriter.create<LLVM::ConstantOp>( 496 loc, llvmType, 497 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 498 llvmType.getIntOrFloatBitWidth()))); 499 } 500 501 /// Create the reduction neutral fp minimum value. 502 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 503 ConversionPatternRewriter &rewriter, 504 Location loc, Type llvmType) { 505 auto floatType = cast<FloatType>(llvmType); 506 return rewriter.create<LLVM::ConstantOp>( 507 loc, llvmType, 508 rewriter.getFloatAttr( 509 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 510 /*Negative=*/false))); 511 } 512 513 /// Create the reduction neutral fp maximum value. 514 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 515 ConversionPatternRewriter &rewriter, 516 Location loc, Type llvmType) { 517 auto floatType = cast<FloatType>(llvmType); 518 return rewriter.create<LLVM::ConstantOp>( 519 loc, llvmType, 520 rewriter.getFloatAttr( 521 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 522 /*Negative=*/true))); 523 } 524 525 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 526 /// returns a new accumulator value using `ReductionNeutral`. 
527 template <class ReductionNeutral> 528 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 529 Location loc, Type llvmType, 530 Value accumulator) { 531 if (accumulator) 532 return accumulator; 533 534 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 535 llvmType); 536 } 537 538 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 539 /// This is used as effective vector length by some intrinsics supporting 540 /// dynamic vector lengths at runtime. 541 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 542 Location loc, Type llvmType) { 543 VectorType vType = cast<VectorType>(llvmType); 544 auto vShape = vType.getShape(); 545 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 546 547 return rewriter.create<LLVM::ConstantOp>( 548 loc, rewriter.getI32Type(), 549 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 550 } 551 552 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 553 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 554 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 555 /// non-null. 556 template <class LLVMRedIntrinOp, class ScalarOp> 557 static Value createIntegerReductionArithmeticOpLowering( 558 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 559 Value vectorOperand, Value accumulator) { 560 561 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 562 563 if (accumulator) 564 result = rewriter.create<ScalarOp>(loc, accumulator, result); 565 return result; 566 } 567 568 /// Helper method to lower a `vector.reduction` operation that performs 569 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 570 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 571 /// the accumulator value if non-null. 
572 template <class LLVMRedIntrinOp> 573 static Value createIntegerReductionComparisonOpLowering( 574 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 575 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 576 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 577 if (accumulator) { 578 Value cmp = 579 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 580 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 581 } 582 return result; 583 } 584 585 namespace { 586 template <typename Source> 587 struct VectorToScalarMapper; 588 template <> 589 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { 590 using Type = LLVM::MaximumOp; 591 }; 592 template <> 593 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { 594 using Type = LLVM::MinimumOp; 595 }; 596 template <> 597 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> { 598 using Type = LLVM::MaxNumOp; 599 }; 600 template <> 601 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> { 602 using Type = LLVM::MinNumOp; 603 }; 604 } // namespace 605 606 template <class LLVMRedIntrinOp> 607 static Value createFPReductionComparisonOpLowering( 608 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 609 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { 610 Value result = 611 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf); 612 613 if (accumulator) { 614 result = 615 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( 616 loc, result, accumulator); 617 } 618 619 return result; 620 } 621 622 /// Reduction neutral classes for overloading 623 class MaskNeutralFMaximum {}; 624 class MaskNeutralFMinimum {}; 625 626 /// Get the mask neutral floating point maximum value 627 static llvm::APFloat 628 getMaskNeutralValue(MaskNeutralFMaximum, 629 const llvm::fltSemantics &floatSemantics) { 630 return llvm::APFloat::getSmallest(floatSemantics, 
/*Negative=*/true); 631 } 632 /// Get the mask neutral floating point minimum value 633 static llvm::APFloat 634 getMaskNeutralValue(MaskNeutralFMinimum, 635 const llvm::fltSemantics &floatSemantics) { 636 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false); 637 } 638 639 /// Create the mask neutral floating point MLIR vector constant 640 template <typename MaskNeutral> 641 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, 642 Location loc, Type llvmType, 643 Type vectorType) { 644 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); 645 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); 646 auto denseValue = 647 DenseElementsAttr::get(vectorType.cast<ShapedType>(), value); 648 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); 649 } 650 651 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked 652 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for 653 /// `fmaximum`/`fminimum`. 
654 /// More information: https://github.com/llvm/llvm-project/issues/64940 655 template <class LLVMRedIntrinOp, class MaskNeutral> 656 static Value 657 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, 658 Location loc, Type llvmType, 659 Value vectorOperand, Value accumulator, 660 Value mask, LLVM::FastmathFlagsAttr fmf) { 661 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>( 662 rewriter, loc, llvmType, vectorOperand.getType()); 663 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>( 664 loc, mask, vectorOperand, vectorMaskNeutral); 665 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>( 666 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); 667 } 668 669 template <class LLVMRedIntrinOp, class ReductionNeutral> 670 static Value 671 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, 672 Type llvmType, Value vectorOperand, 673 Value accumulator, LLVM::FastmathFlagsAttr fmf) { 674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 675 llvmType, accumulator); 676 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType, 677 /*startValue=*/accumulator, 678 vectorOperand, fmf); 679 } 680 681 /// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic 682 /// that requires a start value. This start value format spans across fp 683 /// reductions without mask and all the masked reduction intrinsics. 
// Unmasked variant: start value + operand only.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

// Masked variant: also passes the mask and an effective vector length
// derived from the operand's (static 1-D) shape.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

// Dispatcher: picks the integer or fp intrinsic/neutral pair based on the
// element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  // `reassociateFPRed` additionally sets the `reassoc` fastmath flag on the
  // generated fp reduction intrinsics.
  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds are fp-only; reject them for integer elements.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Translate the op's arith fastmath flags, then merge in `reassoc` when
    // the pass was configured to reassociate fp reductions.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                                  : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes.
This method can match 859 /// against specific traits of the `vector.mask` and the maskable operation. It 860 /// must replace the `vector.mask` operation. 861 template <class MaskedOp> 862 class VectorMaskOpConversionBase 863 : public ConvertOpToLLVMPattern<vector::MaskOp> { 864 public: 865 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern; 866 867 LogicalResult 868 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, 869 ConversionPatternRewriter &rewriter) const final { 870 // Match against the maskable operation kind. 871 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp()); 872 if (!maskedOp) 873 return failure(); 874 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter); 875 } 876 877 protected: 878 virtual LogicalResult 879 matchAndRewriteMaskableOp(vector::MaskOp maskOp, 880 vector::MaskableOpInterface maskableOp, 881 ConversionPatternRewriter &rewriter) const = 0; 882 }; 883 884 class MaskedReductionOpConversion 885 : public VectorMaskOpConversionBase<vector::ReductionOp> { 886 887 public: 888 using VectorMaskOpConversionBase< 889 vector::ReductionOp>::VectorMaskOpConversionBase; 890 891 LogicalResult matchAndRewriteMaskableOp( 892 vector::MaskOp maskOp, MaskableOpInterface maskableOp, 893 ConversionPatternRewriter &rewriter) const override { 894 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation()); 895 auto kind = reductionOp.getKind(); 896 Type eltType = reductionOp.getDest().getType(); 897 Type llvmType = typeConverter->convertType(eltType); 898 Value operand = reductionOp.getVector(); 899 Value acc = reductionOp.getAcc(); 900 Location loc = reductionOp.getLoc(); 901 902 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); 903 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 904 reductionOp.getContext(), 905 convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); 906 907 Value result; 908 switch (kind) { 909 case vector::CombiningKind::ADD: 910 result = 
lowerPredicatedReductionWithStartValue< 911 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, 912 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, 913 maskOp.getMask()); 914 break; 915 case vector::CombiningKind::MUL: 916 result = lowerPredicatedReductionWithStartValue< 917 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, 918 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, 919 maskOp.getMask()); 920 break; 921 case vector::CombiningKind::MINUI: 922 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp, 923 ReductionNeutralUIntMax>( 924 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 925 break; 926 case vector::CombiningKind::MINSI: 927 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp, 928 ReductionNeutralSIntMax>( 929 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 930 break; 931 case vector::CombiningKind::MAXUI: 932 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp, 933 ReductionNeutralUIntMin>( 934 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 935 break; 936 case vector::CombiningKind::MAXSI: 937 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp, 938 ReductionNeutralSIntMin>( 939 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 940 break; 941 case vector::CombiningKind::AND: 942 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp, 943 ReductionNeutralAllOnes>( 944 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 945 break; 946 case vector::CombiningKind::OR: 947 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp, 948 ReductionNeutralZero>( 949 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 950 break; 951 case vector::CombiningKind::XOR: 952 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp, 953 ReductionNeutralZero>( 954 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 955 break; 956 case 
vector::CombiningKind::MINF: 957 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp, 958 ReductionNeutralFPMax>( 959 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 960 break; 961 case vector::CombiningKind::MAXF: 962 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp, 963 ReductionNeutralFPMin>( 964 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 965 break; 966 case CombiningKind::MAXIMUMF: 967 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum, 968 MaskNeutralFMaximum>( 969 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 970 break; 971 case CombiningKind::MINIMUMF: 972 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum, 973 MaskNeutralFMinimum>( 974 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 975 break; 976 } 977 978 // Replace `vector.mask` operation altogether. 979 rewriter.replaceOp(maskOp, result); 980 return success(); 981 } 982 }; 983 984 class VectorShuffleOpConversion 985 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 986 public: 987 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 988 989 LogicalResult 990 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 991 ConversionPatternRewriter &rewriter) const override { 992 auto loc = shuffleOp->getLoc(); 993 auto v1Type = shuffleOp.getV1VectorType(); 994 auto v2Type = shuffleOp.getV2VectorType(); 995 auto vectorType = shuffleOp.getResultVectorType(); 996 Type llvmType = typeConverter->convertType(vectorType); 997 auto maskArrayAttr = shuffleOp.getMask(); 998 999 // Bail if result type cannot be lowered. 1000 if (!llvmType) 1001 return failure(); 1002 1003 // Get rank and dimension sizes. 
1004 int64_t rank = vectorType.getRank(); 1005 #ifndef NDEBUG 1006 bool wellFormed0DCase = 1007 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1; 1008 bool wellFormedNDCase = 1009 v1Type.getRank() == rank && v2Type.getRank() == rank; 1010 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed"); 1011 #endif 1012 1013 // For rank 0 and 1, where both operands have *exactly* the same vector 1014 // type, there is direct shuffle support in LLVM. Use it! 1015 if (rank <= 1 && v1Type == v2Type) { 1016 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 1017 loc, adaptor.getV1(), adaptor.getV2(), 1018 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 1019 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 1020 return success(); 1021 } 1022 1023 // For all other cases, insert the individual values individually. 1024 int64_t v1Dim = v1Type.getDimSize(0); 1025 Type eltType; 1026 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType)) 1027 eltType = arrayType.getElementType(); 1028 else 1029 eltType = cast<VectorType>(llvmType).getElementType(); 1030 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 1031 int64_t insPos = 0; 1032 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 1033 int64_t extPos = cast<IntegerAttr>(en.value()).getInt(); 1034 Value value = adaptor.getV1(); 1035 if (extPos >= v1Dim) { 1036 extPos -= v1Dim; 1037 value = adaptor.getV2(); 1038 } 1039 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 1040 eltType, rank, extPos); 1041 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 1042 llvmType, rank, insPos++); 1043 } 1044 rewriter.replaceOp(shuffleOp, insert); 1045 return success(); 1046 } 1047 }; 1048 1049 class VectorExtractElementOpConversion 1050 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 1051 public: 1052 using ConvertOpToLLVMPattern< 1053 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 1054 1055 LogicalResult 1056 
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 1057 ConversionPatternRewriter &rewriter) const override { 1058 auto vectorType = extractEltOp.getSourceVectorType(); 1059 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 1060 1061 // Bail if result type cannot be lowered. 1062 if (!llvmType) 1063 return failure(); 1064 1065 if (vectorType.getRank() == 0) { 1066 Location loc = extractEltOp.getLoc(); 1067 auto idxType = rewriter.getIndexType(); 1068 auto zero = rewriter.create<LLVM::ConstantOp>( 1069 loc, typeConverter->convertType(idxType), 1070 rewriter.getIntegerAttr(idxType, 0)); 1071 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1072 extractEltOp, llvmType, adaptor.getVector(), zero); 1073 return success(); 1074 } 1075 1076 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1077 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 1078 return success(); 1079 } 1080 }; 1081 1082 class VectorExtractOpConversion 1083 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 1084 public: 1085 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 1086 1087 LogicalResult 1088 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 1089 ConversionPatternRewriter &rewriter) const override { 1090 auto loc = extractOp->getLoc(); 1091 auto resultType = extractOp.getResult().getType(); 1092 auto llvmResultType = typeConverter->convertType(resultType); 1093 // Bail if result type cannot be lowered. 1094 if (!llvmResultType) 1095 return failure(); 1096 1097 SmallVector<OpFoldResult> positionVec; 1098 for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) { 1099 if (pos.is<Value>()) 1100 // Make sure we use the value that has been already converted to LLVM. 1101 positionVec.push_back(adaptor.getDynamicPosition()[idx]); 1102 else 1103 positionVec.push_back(pos); 1104 } 1105 1106 // Extract entire vector. Should be handled by folder, but just to be safe. 
1107 ArrayRef<OpFoldResult> position(positionVec); 1108 if (position.empty()) { 1109 rewriter.replaceOp(extractOp, adaptor.getVector()); 1110 return success(); 1111 } 1112 1113 // One-shot extraction of vector from array (only requires extractvalue). 1114 if (isa<VectorType>(resultType)) { 1115 if (extractOp.hasDynamicPosition()) 1116 return failure(); 1117 1118 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 1119 loc, adaptor.getVector(), getAsIntegers(position)); 1120 rewriter.replaceOp(extractOp, extracted); 1121 return success(); 1122 } 1123 1124 // Potential extraction of 1-D vector from array. 1125 Value extracted = adaptor.getVector(); 1126 if (position.size() > 1) { 1127 if (extractOp.hasDynamicPosition()) 1128 return failure(); 1129 1130 SmallVector<int64_t> nMinusOnePosition = 1131 getAsIntegers(position.drop_back()); 1132 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted, 1133 nMinusOnePosition); 1134 } 1135 1136 Value lastPosition = getAsLLVMValue(rewriter, loc, position.back()); 1137 // Remaining extraction of element from 1-D LLVM vector. 1138 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted, 1139 lastPosition); 1140 return success(); 1141 } 1142 }; 1143 1144 /// Conversion pattern that turns a vector.fma on a 1-D vector 1145 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 1146 /// This does not match vectors of n >= 2 rank. 
1147 /// 1148 /// Example: 1149 /// ``` 1150 /// vector.fma %a, %a, %a : vector<8xf32> 1151 /// ``` 1152 /// is converted to: 1153 /// ``` 1154 /// llvm.intr.fmuladd %va, %va, %va: 1155 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 1156 /// -> !llvm."<8 x f32>"> 1157 /// ``` 1158 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 1159 public: 1160 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 1161 1162 LogicalResult 1163 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 1164 ConversionPatternRewriter &rewriter) const override { 1165 VectorType vType = fmaOp.getVectorType(); 1166 if (vType.getRank() > 1) 1167 return failure(); 1168 1169 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( 1170 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); 1171 return success(); 1172 } 1173 }; 1174 1175 class VectorInsertElementOpConversion 1176 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 1177 public: 1178 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 1179 1180 LogicalResult 1181 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 1182 ConversionPatternRewriter &rewriter) const override { 1183 auto vectorType = insertEltOp.getDestVectorType(); 1184 auto llvmType = typeConverter->convertType(vectorType); 1185 1186 // Bail if result type cannot be lowered. 
1187 if (!llvmType) 1188 return failure(); 1189 1190 if (vectorType.getRank() == 0) { 1191 Location loc = insertEltOp.getLoc(); 1192 auto idxType = rewriter.getIndexType(); 1193 auto zero = rewriter.create<LLVM::ConstantOp>( 1194 loc, typeConverter->convertType(idxType), 1195 rewriter.getIntegerAttr(idxType, 0)); 1196 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1197 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); 1198 return success(); 1199 } 1200 1201 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1202 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), 1203 adaptor.getPosition()); 1204 return success(); 1205 } 1206 }; 1207 1208 class VectorInsertOpConversion 1209 : public ConvertOpToLLVMPattern<vector::InsertOp> { 1210 public: 1211 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 1212 1213 LogicalResult 1214 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 1215 ConversionPatternRewriter &rewriter) const override { 1216 auto loc = insertOp->getLoc(); 1217 auto sourceType = insertOp.getSourceType(); 1218 auto destVectorType = insertOp.getDestVectorType(); 1219 auto llvmResultType = typeConverter->convertType(destVectorType); 1220 // Bail if result type cannot be lowered. 1221 if (!llvmResultType) 1222 return failure(); 1223 1224 SmallVector<OpFoldResult> positionVec; 1225 for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) { 1226 if (pos.is<Value>()) 1227 // Make sure we use the value that has been already converted to LLVM. 1228 positionVec.push_back(adaptor.getDynamicPosition()[idx]); 1229 else 1230 positionVec.push_back(pos); 1231 } 1232 1233 // Overwrite entire vector with value. Should be handled by folder, but 1234 // just to be safe. 
1235 ArrayRef<OpFoldResult> position(positionVec); 1236 if (position.empty()) { 1237 rewriter.replaceOp(insertOp, adaptor.getSource()); 1238 return success(); 1239 } 1240 1241 // One-shot insertion of a vector into an array (only requires insertvalue). 1242 if (isa<VectorType>(sourceType)) { 1243 if (insertOp.hasDynamicPosition()) 1244 return failure(); 1245 1246 Value inserted = rewriter.create<LLVM::InsertValueOp>( 1247 loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position)); 1248 rewriter.replaceOp(insertOp, inserted); 1249 return success(); 1250 } 1251 1252 // Potential extraction of 1-D vector from array. 1253 Value extracted = adaptor.getDest(); 1254 auto oneDVectorType = destVectorType; 1255 if (position.size() > 1) { 1256 if (insertOp.hasDynamicPosition()) 1257 return failure(); 1258 1259 oneDVectorType = reducedVectorTypeBack(destVectorType); 1260 extracted = rewriter.create<LLVM::ExtractValueOp>( 1261 loc, extracted, getAsIntegers(position.drop_back())); 1262 } 1263 1264 // Insertion of an element into a 1-D LLVM vector. 1265 Value inserted = rewriter.create<LLVM::InsertElementOp>( 1266 loc, typeConverter->convertType(oneDVectorType), extracted, 1267 adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back())); 1268 1269 // Potential insertion of resulting 1-D vector into array. 
1270 if (position.size() > 1) { 1271 if (insertOp.hasDynamicPosition()) 1272 return failure(); 1273 1274 inserted = rewriter.create<LLVM::InsertValueOp>( 1275 loc, adaptor.getDest(), inserted, 1276 getAsIntegers(position.drop_back())); 1277 } 1278 1279 rewriter.replaceOp(insertOp, inserted); 1280 return success(); 1281 } 1282 }; 1283 1284 /// Lower vector.scalable.insert ops to LLVM vector.insert 1285 struct VectorScalableInsertOpLowering 1286 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> { 1287 using ConvertOpToLLVMPattern< 1288 vector::ScalableInsertOp>::ConvertOpToLLVMPattern; 1289 1290 LogicalResult 1291 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor, 1292 ConversionPatternRewriter &rewriter) const override { 1293 rewriter.replaceOpWithNewOp<LLVM::vector_insert>( 1294 insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos()); 1295 return success(); 1296 } 1297 }; 1298 1299 /// Lower vector.scalable.extract ops to LLVM vector.extract 1300 struct VectorScalableExtractOpLowering 1301 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> { 1302 using ConvertOpToLLVMPattern< 1303 vector::ScalableExtractOp>::ConvertOpToLLVMPattern; 1304 1305 LogicalResult 1306 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor, 1307 ConversionPatternRewriter &rewriter) const override { 1308 rewriter.replaceOpWithNewOp<LLVM::vector_extract>( 1309 extOp, typeConverter->convertType(extOp.getResultVectorType()), 1310 adaptor.getSource(), adaptor.getPos()); 1311 return success(); 1312 } 1313 }; 1314 1315 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 
1316 /// 1317 /// Example: 1318 /// ``` 1319 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 1320 /// ``` 1321 /// is rewritten into: 1322 /// ``` 1323 /// %r = splat %f0: vector<2x4xf32> 1324 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 1325 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 1326 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 1327 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 1328 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 1329 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 1330 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 1331 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 1332 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 1333 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 1334 /// // %r3 holds the final value. 1335 /// ``` 1336 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 1337 public: 1338 using OpRewritePattern<FMAOp>::OpRewritePattern; 1339 1340 void initialize() { 1341 // This pattern recursively unpacks one dimension at a time. The recursion 1342 // bounded as the rank is strictly decreasing. 
1343 setHasBoundedRewriteRecursion(); 1344 } 1345 1346 LogicalResult matchAndRewrite(FMAOp op, 1347 PatternRewriter &rewriter) const override { 1348 auto vType = op.getVectorType(); 1349 if (vType.getRank() < 2) 1350 return failure(); 1351 1352 auto loc = op.getLoc(); 1353 auto elemType = vType.getElementType(); 1354 Value zero = rewriter.create<arith::ConstantOp>( 1355 loc, elemType, rewriter.getZeroAttr(elemType)); 1356 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 1357 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 1358 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 1359 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 1360 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 1361 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 1362 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 1363 } 1364 rewriter.replaceOp(op, desc); 1365 return success(); 1366 } 1367 }; 1368 1369 /// Returns the strides if the memory underlying `memRefType` has a contiguous 1370 /// static layout. 1371 static std::optional<SmallVector<int64_t, 4>> 1372 computeContiguousStrides(MemRefType memRefType) { 1373 int64_t offset; 1374 SmallVector<int64_t, 4> strides; 1375 if (failed(getStridesAndOffset(memRefType, strides, offset))) 1376 return std::nullopt; 1377 if (!strides.empty() && strides.back() != 1) 1378 return std::nullopt; 1379 // If no layout or identity layout, this is contiguous by definition. 1380 if (memRefType.getLayout().isIdentity()) 1381 return strides; 1382 1383 // Otherwise, we must determine contiguity form shapes. This can only ever 1384 // work in static cases because MemRefType is underspecified to represent 1385 // contiguous dynamic shapes in other ways than with just empty/identity 1386 // layout. 
1387 auto sizes = memRefType.getShape(); 1388 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 1389 if (ShapedType::isDynamic(sizes[index + 1]) || 1390 ShapedType::isDynamic(strides[index]) || 1391 ShapedType::isDynamic(strides[index + 1])) 1392 return std::nullopt; 1393 if (strides[index] != strides[index + 1] * sizes[index + 1]) 1394 return std::nullopt; 1395 } 1396 return strides; 1397 } 1398 1399 class VectorTypeCastOpConversion 1400 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 1401 public: 1402 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 1403 1404 LogicalResult 1405 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 1406 ConversionPatternRewriter &rewriter) const override { 1407 auto loc = castOp->getLoc(); 1408 MemRefType sourceMemRefType = 1409 cast<MemRefType>(castOp.getOperand().getType()); 1410 MemRefType targetMemRefType = castOp.getType(); 1411 1412 // Only static shape casts supported atm. 1413 if (!sourceMemRefType.hasStaticShape() || 1414 !targetMemRefType.hasStaticShape()) 1415 return failure(); 1416 1417 auto llvmSourceDescriptorTy = 1418 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType()); 1419 if (!llvmSourceDescriptorTy) 1420 return failure(); 1421 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 1422 1423 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( 1424 typeConverter->convertType(targetMemRefType)); 1425 if (!llvmTargetDescriptorTy) 1426 return failure(); 1427 1428 // Only contiguous source buffers supported atm. 1429 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1430 if (!sourceStrides) 1431 return failure(); 1432 auto targetStrides = computeContiguousStrides(targetMemRefType); 1433 if (!targetStrides) 1434 return failure(); 1435 // Only support static strides for now, regardless of contiguity. 
1436 if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) 1437 return failure(); 1438 1439 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1440 1441 // Create descriptor. 1442 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1443 Type llvmTargetElementTy = desc.getElementPtrType(); 1444 // Set allocated ptr. 1445 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1446 if (!getTypeConverter()->useOpaquePointers()) 1447 allocated = 1448 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 1449 desc.setAllocatedPtr(rewriter, loc, allocated); 1450 1451 // Set aligned ptr. 1452 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1453 if (!getTypeConverter()->useOpaquePointers()) 1454 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 1455 1456 desc.setAlignedPtr(rewriter, loc, ptr); 1457 // Fill offset 0. 1458 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1459 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1460 desc.setOffset(rewriter, loc, zero); 1461 1462 // Fill size and stride descriptors in memref. 1463 for (const auto &indexedSize : 1464 llvm::enumerate(targetMemRefType.getShape())) { 1465 int64_t index = indexedSize.index(); 1466 auto sizeAttr = 1467 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1468 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1469 desc.setSize(rewriter, loc, index, size); 1470 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1471 (*targetStrides)[index]); 1472 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1473 desc.setStride(rewriter, loc, index, stride); 1474 } 1475 1476 rewriter.replaceOp(castOp, {desc}); 1477 return success(); 1478 } 1479 }; 1480 1481 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 1482 /// Non-scalable versions of this operation are handled in Vector Transforms. 
1483 class VectorCreateMaskOpRewritePattern 1484 : public OpRewritePattern<vector::CreateMaskOp> { 1485 public: 1486 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1487 bool enableIndexOpt) 1488 : OpRewritePattern<vector::CreateMaskOp>(context), 1489 force32BitVectorIndices(enableIndexOpt) {} 1490 1491 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1492 PatternRewriter &rewriter) const override { 1493 auto dstType = op.getType(); 1494 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1495 return failure(); 1496 IntegerType idxType = 1497 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1498 auto loc = op->getLoc(); 1499 Value indices = rewriter.create<LLVM::StepVectorOp>( 1500 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1501 /*isScalable=*/true)); 1502 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1503 op.getOperand(0)); 1504 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1505 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1506 indices, bounds); 1507 rewriter.replaceOp(op, comp); 1508 return success(); 1509 } 1510 1511 private: 1512 const bool force32BitVectorIndices; 1513 }; 1514 1515 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1516 public: 1517 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1518 1519 // Lowering implementation that relies on a small runtime support library, 1520 // which only needs to provide a few printing methods (single value for all 1521 // data types, opening/closing bracket, comma, newline). The lowering splits 1522 // the vector into elementary printing operations. The advantage of this 1523 // approach is that the library can remain unaware of all low-level 1524 // implementation details of vectors while still supporting output of any 1525 // shaped and dimensioned vector. 
1526 // 1527 // Note: This lowering only handles scalars, n-D vectors are broken into 1528 // printing scalars in loops in VectorToSCF. 1529 // 1530 // TODO: rely solely on libc in future? something else? 1531 // 1532 LogicalResult 1533 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1534 ConversionPatternRewriter &rewriter) const override { 1535 auto parent = printOp->getParentOfType<ModuleOp>(); 1536 if (!parent) 1537 return failure(); 1538 1539 auto loc = printOp->getLoc(); 1540 1541 if (auto value = adaptor.getSource()) { 1542 Type printType = printOp.getPrintType(); 1543 if (isa<VectorType>(printType)) { 1544 // Vectors should be broken into elementary print ops in VectorToSCF. 1545 return failure(); 1546 } 1547 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value))) 1548 return failure(); 1549 } 1550 1551 auto punct = printOp.getPunctuation(); 1552 if (auto stringLiteral = printOp.getStringLiteral()) { 1553 LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", 1554 *stringLiteral, *getTypeConverter()); 1555 } else if (punct != PrintPunctuation::NoPunctuation) { 1556 emitCall(rewriter, printOp->getLoc(), [&] { 1557 switch (punct) { 1558 case PrintPunctuation::Close: 1559 return LLVM::lookupOrCreatePrintCloseFn(parent); 1560 case PrintPunctuation::Open: 1561 return LLVM::lookupOrCreatePrintOpenFn(parent); 1562 case PrintPunctuation::Comma: 1563 return LLVM::lookupOrCreatePrintCommaFn(parent); 1564 case PrintPunctuation::NewLine: 1565 return LLVM::lookupOrCreatePrintNewlineFn(parent); 1566 default: 1567 llvm_unreachable("unexpected punctuation"); 1568 } 1569 }()); 1570 } 1571 1572 rewriter.eraseOp(printOp); 1573 return success(); 1574 } 1575 1576 private: 1577 enum class PrintConversion { 1578 // clang-format off 1579 None, 1580 ZeroExt64, 1581 SignExt64, 1582 Bitcast16 1583 // clang-format on 1584 }; 1585 1586 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter, 1587 ModuleOp parent, Location loc, Type 
printType, 1588 Value value) const { 1589 if (typeConverter->convertType(printType) == nullptr) 1590 return failure(); 1591 1592 // Make sure element type has runtime support. 1593 PrintConversion conversion = PrintConversion::None; 1594 Operation *printer; 1595 if (printType.isF32()) { 1596 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1597 } else if (printType.isF64()) { 1598 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1599 } else if (printType.isF16()) { 1600 conversion = PrintConversion::Bitcast16; // bits! 1601 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1602 } else if (printType.isBF16()) { 1603 conversion = PrintConversion::Bitcast16; // bits! 1604 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1605 } else if (printType.isIndex()) { 1606 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1607 } else if (auto intTy = dyn_cast<IntegerType>(printType)) { 1608 // Integers need a zero or sign extension on the operand 1609 // (depending on the source type) as well as a signed or 1610 // unsigned print method. Up to 64-bit is supported. 1611 unsigned width = intTy.getWidth(); 1612 if (intTy.isUnsigned()) { 1613 if (width <= 64) { 1614 if (width < 64) 1615 conversion = PrintConversion::ZeroExt64; 1616 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1617 } else { 1618 return failure(); 1619 } 1620 } else { 1621 assert(intTy.isSignless() || intTy.isSigned()); 1622 if (width <= 64) { 1623 // Note that we *always* zero extend booleans (1-bit integers), 1624 // so that true/false is printed as 1/0 rather than -1/0. 
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          // Integers wider than 64 bits have no runtime print support.
          return failure();
        }
      }
    } else {
      // Any other element type has no runtime print support.
      return failure();
    }

    // Materialize the operand conversion selected above, if any, so the
    // value matches the runtime print function's expected argument type.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      // Unsigned (and boolean) sources widen with zero extension.
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      // Signed/signless sources widen with sign extension.
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      // f16/bf16 are printed as their raw bit pattern: reinterpret the 16
      // bits as i16 without changing them.
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  // Helper to emit a (void-returning) call to the given runtime print
  // function symbol with the given operands.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    // Ranks > 1 are handled by VectorSplatNdOpLowering.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    // Index 0 (as i32) of the vector we insert into / shuffle from.
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    // All-zeros mask: every result lane reads element 0 of `v`.
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // Ranks 0 and 1 are handled by VectorSplatOpLowering.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value: an N-D aggregate of 1-D vectors.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    // `width` is the innermost (minor) dimension; the all-zeros mask
    // broadcasts lane 0 into every lane.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate of linear index, convert to coords space and insert splatted 1-D
    // vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
1759 void mlir::populateVectorToLLVMConversionPatterns( 1760 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1761 bool reassociateFPReductions, bool force32BitVectorIndices) { 1762 MLIRContext *ctx = converter.getDialect()->getContext(); 1763 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1764 populateVectorInsertExtractStridedSliceTransforms(patterns); 1765 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1766 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1767 patterns 1768 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1769 VectorExtractElementOpConversion, VectorExtractOpConversion, 1770 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1771 VectorInsertOpConversion, VectorPrintOpConversion, 1772 VectorTypeCastOpConversion, VectorScaleOpConversion, 1773 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1774 VectorLoadStoreConversion<vector::MaskedLoadOp, 1775 vector::MaskedLoadOpAdaptor>, 1776 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1777 VectorLoadStoreConversion<vector::MaskedStoreOp, 1778 vector::MaskedStoreOpAdaptor>, 1779 VectorGatherOpConversion, VectorScatterOpConversion, 1780 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1781 VectorSplatOpLowering, VectorSplatNdOpLowering, 1782 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1783 MaskedReductionOpConversion>(converter); 1784 // Transfer ops with rank > 1 are handled by VectorToSCF. 1785 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1786 } 1787 1788 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1789 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1790 patterns.add<VectorMatmulOpConversion>(converter); 1791 patterns.add<VectorFlatTransposeOpConversion>(converter); 1792 } 1793