1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" 12 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Arith/Utils/Utils.h" 17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 24 #include "mlir/IR/BuiltinAttributes.h" 25 #include "mlir/IR/BuiltinTypeInterfaces.h" 26 #include "mlir/IR/BuiltinTypes.h" 27 #include "mlir/IR/TypeUtilities.h" 28 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 29 #include "mlir/Transforms/DialectConversion.h" 30 #include "llvm/ADT/APFloat.h" 31 #include "llvm/Support/Casting.h" 32 #include <optional> 33 34 using namespace mlir; 35 using namespace mlir::vector; 36 37 // Helper to reduce vector type by *all* but one rank at back. 38 static VectorType reducedVectorTypeBack(VectorType tp) { 39 assert((tp.getRank() > 1) && "unlowerable vector type"); 40 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 41 tp.getScalableDims().take_back()); 42 } 43 44 // Helper that picks the proper sequence for inserting. 
45 static Value insertOne(ConversionPatternRewriter &rewriter, 46 const LLVMTypeConverter &typeConverter, Location loc, 47 Value val1, Value val2, Type llvmType, int64_t rank, 48 int64_t pos) { 49 assert(rank > 0 && "0-D vector corner case should have been handled already"); 50 if (rank == 1) { 51 auto idxType = rewriter.getIndexType(); 52 auto constant = rewriter.create<LLVM::ConstantOp>( 53 loc, typeConverter.convertType(idxType), 54 rewriter.getIntegerAttr(idxType, pos)); 55 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 56 constant); 57 } 58 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 59 } 60 61 // Helper that picks the proper sequence for extracting. 62 static Value extractOne(ConversionPatternRewriter &rewriter, 63 const LLVMTypeConverter &typeConverter, Location loc, 64 Value val, Type llvmType, int64_t rank, int64_t pos) { 65 if (rank <= 1) { 66 auto idxType = rewriter.getIndexType(); 67 auto constant = rewriter.create<LLVM::ConstantOp>( 68 loc, typeConverter.convertType(idxType), 69 rewriter.getIntegerAttr(idxType, pos)); 70 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 71 constant); 72 } 73 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 74 } 75 76 // Helper that returns data layout alignment of a memref. 77 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, 78 MemRefType memrefType, unsigned &align) { 79 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 80 if (!elementTy) 81 return failure(); 82 83 // TODO: this should use the MLIR data layout when it becomes available and 84 // stop depending on translation. 85 llvm::LLVMContext llvmContext; 86 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 87 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 88 return success(); 89 } 90 91 // Check if the last stride is non-unit and has a valid memory space. 
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 93 const LLVMTypeConverter &converter) { 94 if (!isLastMemrefDimUnitStride(memRefType)) 95 return failure(); 96 if (failed(converter.getMemRefAddressSpace(memRefType))) 97 return failure(); 98 return success(); 99 } 100 101 // Add an index vector component to a base pointer. 102 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 103 const LLVMTypeConverter &typeConverter, 104 MemRefType memRefType, Value llvmMemref, Value base, 105 Value index, uint64_t vLen) { 106 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 107 "unsupported memref type"); 108 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 109 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 110 return rewriter.create<LLVM::GEPOp>( 111 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 112 base, index); 113 } 114 115 /// Convert `foldResult` into a Value. Integer attribute is converted to 116 /// an LLVM constant op. 117 static Value getAsLLVMValue(OpBuilder &builder, Location loc, 118 OpFoldResult foldResult) { 119 if (auto attr = foldResult.dyn_cast<Attribute>()) { 120 auto intAttr = cast<IntegerAttr>(attr); 121 return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult(); 122 } 123 124 return foldResult.get<Value>(); 125 } 126 127 namespace { 128 129 /// Trivial Vector to LLVM conversions 130 using VectorScaleOpConversion = 131 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 132 133 /// Conversion pattern for a vector.bitcast. 134 class VectorBitCastOpConversion 135 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 136 public: 137 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 138 139 LogicalResult 140 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 141 ConversionPatternRewriter &rewriter) const override { 142 // Only 0-D and 1-D vectors can be lowered to LLVM. 
143 VectorType resultTy = bitCastOp.getResultVectorType(); 144 if (resultTy.getRank() > 1) 145 return failure(); 146 Type newResultTy = typeConverter->convertType(resultTy); 147 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 148 adaptor.getOperands()[0]); 149 return success(); 150 } 151 }; 152 153 /// Conversion pattern for a vector.matrix_multiply. 154 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 155 class VectorMatmulOpConversion 156 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 157 public: 158 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 159 160 LogicalResult 161 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 162 ConversionPatternRewriter &rewriter) const override { 163 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 164 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 165 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 166 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 167 return success(); 168 } 169 }; 170 171 /// Conversion pattern for a vector.flat_transpose. 172 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 173 class VectorFlatTransposeOpConversion 174 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 175 public: 176 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 177 178 LogicalResult 179 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 180 ConversionPatternRewriter &rewriter) const override { 181 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 182 transOp, typeConverter->convertType(transOp.getRes().getType()), 183 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 184 return success(); 185 } 186 }; 187 188 /// Overloaded utility that replaces a vector.load, vector.store, 189 /// vector.maskedload and vector.maskedstore with their respective LLVM 190 /// couterparts. 
191 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 192 vector::LoadOpAdaptor adaptor, 193 VectorType vectorTy, Value ptr, unsigned align, 194 ConversionPatternRewriter &rewriter) { 195 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align); 196 } 197 198 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 199 vector::MaskedLoadOpAdaptor adaptor, 200 VectorType vectorTy, Value ptr, unsigned align, 201 ConversionPatternRewriter &rewriter) { 202 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 203 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 204 } 205 206 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 207 vector::StoreOpAdaptor adaptor, 208 VectorType vectorTy, Value ptr, unsigned align, 209 ConversionPatternRewriter &rewriter) { 210 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 211 ptr, align); 212 } 213 214 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 215 vector::MaskedStoreOpAdaptor adaptor, 216 VectorType vectorTy, Value ptr, unsigned align, 217 ConversionPatternRewriter &rewriter) { 218 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 219 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); 220 } 221 222 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 223 /// vector.maskedstore. 224 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 225 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 226 public: 227 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 228 229 LogicalResult 230 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 231 typename LoadOrStoreOp::Adaptor adaptor, 232 ConversionPatternRewriter &rewriter) const override { 233 // Only 1-D vectors can be lowered to LLVM. 
234 VectorType vectorTy = loadOrStoreOp.getVectorType(); 235 if (vectorTy.getRank() > 1) 236 return failure(); 237 238 auto loc = loadOrStoreOp->getLoc(); 239 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 240 241 // Resolve alignment. 242 unsigned align; 243 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 244 return failure(); 245 246 // Resolve address. 247 auto vtype = cast<VectorType>( 248 this->typeConverter->convertType(loadOrStoreOp.getVectorType())); 249 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 250 adaptor.getIndices(), rewriter); 251 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align, 252 rewriter); 253 return success(); 254 } 255 }; 256 257 /// Conversion pattern for a vector.gather. 258 class VectorGatherOpConversion 259 : public ConvertOpToLLVMPattern<vector::GatherOp> { 260 public: 261 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 262 263 LogicalResult 264 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 265 ConversionPatternRewriter &rewriter) const override { 266 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType()); 267 assert(memRefType && "The base should be bufferized"); 268 269 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 270 return failure(); 271 272 auto loc = gather->getLoc(); 273 274 // Resolve alignment. 275 unsigned align; 276 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 277 return failure(); 278 279 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 280 adaptor.getIndices(), rewriter); 281 Value base = adaptor.getBase(); 282 283 auto llvmNDVectorTy = adaptor.getIndexVec().getType(); 284 // Handle the simple case of 1-D vector. 285 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) { 286 auto vType = gather.getVectorType(); 287 // Resolve address. 
288 Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), 289 memRefType, base, ptr, adaptor.getIndexVec(), 290 /*vLen=*/vType.getDimSize(0)); 291 // Replace with the gather intrinsic. 292 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 293 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 294 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 295 return success(); 296 } 297 298 const LLVMTypeConverter &typeConverter = *this->getTypeConverter(); 299 auto callback = [align, memRefType, base, ptr, loc, &rewriter, 300 &typeConverter](Type llvm1DVectorTy, 301 ValueRange vectorOperands) { 302 // Resolve address. 303 Value ptrs = getIndexedPtrs( 304 rewriter, loc, typeConverter, memRefType, base, ptr, 305 /*index=*/vectorOperands[0], 306 LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); 307 // Create the gather intrinsic. 308 return rewriter.create<LLVM::masked_gather>( 309 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1], 310 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align)); 311 }; 312 SmallVector<Value> vectorOperands = { 313 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()}; 314 return LLVM::detail::handleMultidimensionalVectors( 315 gather, vectorOperands, *getTypeConverter(), callback, rewriter); 316 } 317 }; 318 319 /// Conversion pattern for a vector.scatter. 320 class VectorScatterOpConversion 321 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 322 public: 323 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 324 325 LogicalResult 326 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 327 ConversionPatternRewriter &rewriter) const override { 328 auto loc = scatter->getLoc(); 329 MemRefType memRefType = scatter.getMemRefType(); 330 331 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 332 return failure(); 333 334 // Resolve alignment. 
335 unsigned align; 336 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 337 return failure(); 338 339 // Resolve address. 340 VectorType vType = scatter.getVectorType(); 341 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 342 adaptor.getIndices(), rewriter); 343 Value ptrs = getIndexedPtrs( 344 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(), 345 ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); 346 347 // Replace with the scatter intrinsic. 348 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 349 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 350 rewriter.getI32IntegerAttr(align)); 351 return success(); 352 } 353 }; 354 355 /// Conversion pattern for a vector.expandload. 356 class VectorExpandLoadOpConversion 357 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 358 public: 359 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 360 361 LogicalResult 362 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 363 ConversionPatternRewriter &rewriter) const override { 364 auto loc = expand->getLoc(); 365 MemRefType memRefType = expand.getMemRefType(); 366 367 // Resolve address. 368 auto vtype = typeConverter->convertType(expand.getVectorType()); 369 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 370 adaptor.getIndices(), rewriter); 371 372 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 373 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 374 return success(); 375 } 376 }; 377 378 /// Conversion pattern for a vector.compressstore. 
379 class VectorCompressStoreOpConversion 380 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 381 public: 382 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 383 384 LogicalResult 385 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 386 ConversionPatternRewriter &rewriter) const override { 387 auto loc = compress->getLoc(); 388 MemRefType memRefType = compress.getMemRefType(); 389 390 // Resolve address. 391 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 392 adaptor.getIndices(), rewriter); 393 394 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 395 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 396 return success(); 397 } 398 }; 399 400 /// Reduction neutral classes for overloading. 401 class ReductionNeutralZero {}; 402 class ReductionNeutralIntOne {}; 403 class ReductionNeutralFPOne {}; 404 class ReductionNeutralAllOnes {}; 405 class ReductionNeutralSIntMin {}; 406 class ReductionNeutralUIntMin {}; 407 class ReductionNeutralSIntMax {}; 408 class ReductionNeutralUIntMax {}; 409 class ReductionNeutralFPMin {}; 410 class ReductionNeutralFPMax {}; 411 412 /// Create the reduction neutral zero value. 413 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 414 ConversionPatternRewriter &rewriter, 415 Location loc, Type llvmType) { 416 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 417 rewriter.getZeroAttr(llvmType)); 418 } 419 420 /// Create the reduction neutral integer one value. 421 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 422 ConversionPatternRewriter &rewriter, 423 Location loc, Type llvmType) { 424 return rewriter.create<LLVM::ConstantOp>( 425 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 426 } 427 428 /// Create the reduction neutral fp one value. 
429 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 430 ConversionPatternRewriter &rewriter, 431 Location loc, Type llvmType) { 432 return rewriter.create<LLVM::ConstantOp>( 433 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 434 } 435 436 /// Create the reduction neutral all-ones value. 437 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 438 ConversionPatternRewriter &rewriter, 439 Location loc, Type llvmType) { 440 return rewriter.create<LLVM::ConstantOp>( 441 loc, llvmType, 442 rewriter.getIntegerAttr( 443 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 444 } 445 446 /// Create the reduction neutral signed int minimum value. 447 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 448 ConversionPatternRewriter &rewriter, 449 Location loc, Type llvmType) { 450 return rewriter.create<LLVM::ConstantOp>( 451 loc, llvmType, 452 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 453 llvmType.getIntOrFloatBitWidth()))); 454 } 455 456 /// Create the reduction neutral unsigned int minimum value. 457 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 458 ConversionPatternRewriter &rewriter, 459 Location loc, Type llvmType) { 460 return rewriter.create<LLVM::ConstantOp>( 461 loc, llvmType, 462 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 463 llvmType.getIntOrFloatBitWidth()))); 464 } 465 466 /// Create the reduction neutral signed int maximum value. 467 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 468 ConversionPatternRewriter &rewriter, 469 Location loc, Type llvmType) { 470 return rewriter.create<LLVM::ConstantOp>( 471 loc, llvmType, 472 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 473 llvmType.getIntOrFloatBitWidth()))); 474 } 475 476 /// Create the reduction neutral unsigned int maximum value. 
477 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 478 ConversionPatternRewriter &rewriter, 479 Location loc, Type llvmType) { 480 return rewriter.create<LLVM::ConstantOp>( 481 loc, llvmType, 482 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 483 llvmType.getIntOrFloatBitWidth()))); 484 } 485 486 /// Create the reduction neutral fp minimum value. 487 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 488 ConversionPatternRewriter &rewriter, 489 Location loc, Type llvmType) { 490 auto floatType = cast<FloatType>(llvmType); 491 return rewriter.create<LLVM::ConstantOp>( 492 loc, llvmType, 493 rewriter.getFloatAttr( 494 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 495 /*Negative=*/false))); 496 } 497 498 /// Create the reduction neutral fp maximum value. 499 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 500 ConversionPatternRewriter &rewriter, 501 Location loc, Type llvmType) { 502 auto floatType = cast<FloatType>(llvmType); 503 return rewriter.create<LLVM::ConstantOp>( 504 loc, llvmType, 505 rewriter.getFloatAttr( 506 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 507 /*Negative=*/true))); 508 } 509 510 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 511 /// returns a new accumulator value using `ReductionNeutral`. 512 template <class ReductionNeutral> 513 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 514 Location loc, Type llvmType, 515 Value accumulator) { 516 if (accumulator) 517 return accumulator; 518 519 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 520 llvmType); 521 } 522 523 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 524 /// This is used as effective vector length by some intrinsics supporting 525 /// dynamic vector lengths at runtime. 
526 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 527 Location loc, Type llvmType) { 528 VectorType vType = cast<VectorType>(llvmType); 529 auto vShape = vType.getShape(); 530 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 531 532 return rewriter.create<LLVM::ConstantOp>( 533 loc, rewriter.getI32Type(), 534 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 535 } 536 537 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 538 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 539 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 540 /// non-null. 541 template <class LLVMRedIntrinOp, class ScalarOp> 542 static Value createIntegerReductionArithmeticOpLowering( 543 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 544 Value vectorOperand, Value accumulator) { 545 546 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 547 548 if (accumulator) 549 result = rewriter.create<ScalarOp>(loc, accumulator, result); 550 return result; 551 } 552 553 /// Helper method to lower a `vector.reduction` operation that performs 554 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 555 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 556 /// the accumulator value if non-null. 
557 template <class LLVMRedIntrinOp> 558 static Value createIntegerReductionComparisonOpLowering( 559 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 560 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 561 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 562 if (accumulator) { 563 Value cmp = 564 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 565 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 566 } 567 return result; 568 } 569 570 namespace { 571 template <typename Source> 572 struct VectorToScalarMapper; 573 template <> 574 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { 575 using Type = LLVM::MaximumOp; 576 }; 577 template <> 578 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { 579 using Type = LLVM::MinimumOp; 580 }; 581 template <> 582 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> { 583 using Type = LLVM::MaxNumOp; 584 }; 585 template <> 586 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> { 587 using Type = LLVM::MinNumOp; 588 }; 589 } // namespace 590 591 template <class LLVMRedIntrinOp> 592 static Value createFPReductionComparisonOpLowering( 593 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 594 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { 595 Value result = 596 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf); 597 598 if (accumulator) { 599 result = 600 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( 601 loc, result, accumulator); 602 } 603 604 return result; 605 } 606 607 /// Reduction neutral classes for overloading 608 class MaskNeutralFMaximum {}; 609 class MaskNeutralFMinimum {}; 610 611 /// Get the mask neutral floating point maximum value 612 static llvm::APFloat 613 getMaskNeutralValue(MaskNeutralFMaximum, 614 const llvm::fltSemantics &floatSemantics) { 615 return llvm::APFloat::getSmallest(floatSemantics, 
/*Negative=*/true); 616 } 617 /// Get the mask neutral floating point minimum value 618 static llvm::APFloat 619 getMaskNeutralValue(MaskNeutralFMinimum, 620 const llvm::fltSemantics &floatSemantics) { 621 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false); 622 } 623 624 /// Create the mask neutral floating point MLIR vector constant 625 template <typename MaskNeutral> 626 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, 627 Location loc, Type llvmType, 628 Type vectorType) { 629 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); 630 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); 631 auto denseValue = 632 DenseElementsAttr::get(vectorType.cast<ShapedType>(), value); 633 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); 634 } 635 636 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked 637 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for 638 /// `fmaximum`/`fminimum`. 
639 /// More information: https://github.com/llvm/llvm-project/issues/64940 640 template <class LLVMRedIntrinOp, class MaskNeutral> 641 static Value 642 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, 643 Location loc, Type llvmType, 644 Value vectorOperand, Value accumulator, 645 Value mask, LLVM::FastmathFlagsAttr fmf) { 646 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>( 647 rewriter, loc, llvmType, vectorOperand.getType()); 648 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>( 649 loc, mask, vectorOperand, vectorMaskNeutral); 650 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>( 651 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); 652 } 653 654 template <class LLVMRedIntrinOp, class ReductionNeutral> 655 static Value 656 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, 657 Type llvmType, Value vectorOperand, 658 Value accumulator, LLVM::FastmathFlagsAttr fmf) { 659 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 660 llvmType, accumulator); 661 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType, 662 /*startValue=*/accumulator, 663 vectorOperand, fmf); 664 } 665 666 /// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic 667 /// that requires a start value. This start value format spans across fp 668 /// reductions without mask and all the masked reduction intrinsics. 
669 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 670 static Value 671 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter, 672 Location loc, Type llvmType, 673 Value vectorOperand, Value accumulator) { 674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 675 llvmType, accumulator); 676 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 677 /*startValue=*/accumulator, 678 vectorOperand); 679 } 680 681 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 682 static Value lowerPredicatedReductionWithStartValue( 683 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 684 Value vectorOperand, Value accumulator, Value mask) { 685 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 686 llvmType, accumulator); 687 Value vectorLength = 688 createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 689 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 690 /*startValue=*/accumulator, 691 vectorOperand, mask, vectorLength); 692 } 693 694 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral, 695 class LLVMFPVPRedIntrinOp, class FPReductionNeutral> 696 static Value lowerPredicatedReductionWithStartValue( 697 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 698 Value vectorOperand, Value accumulator, Value mask) { 699 if (llvmType.isIntOrIndex()) 700 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp, 701 IntReductionNeutral>( 702 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 703 704 // FP dispatch. 705 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp, 706 FPReductionNeutral>( 707 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 708 } 709 710 /// Conversion pattern for all vector reductions. 
711 class VectorReductionOpConversion 712 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 713 public: 714 explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv, 715 bool reassociateFPRed) 716 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 717 reassociateFPReductions(reassociateFPRed) {} 718 719 LogicalResult 720 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 721 ConversionPatternRewriter &rewriter) const override { 722 auto kind = reductionOp.getKind(); 723 Type eltType = reductionOp.getDest().getType(); 724 Type llvmType = typeConverter->convertType(eltType); 725 Value operand = adaptor.getVector(); 726 Value acc = adaptor.getAcc(); 727 Location loc = reductionOp.getLoc(); 728 729 if (eltType.isIntOrIndex()) { 730 // Integer reductions: add/mul/min/max/and/or/xor. 731 Value result; 732 switch (kind) { 733 case vector::CombiningKind::ADD: 734 result = 735 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 736 LLVM::AddOp>( 737 rewriter, loc, llvmType, operand, acc); 738 break; 739 case vector::CombiningKind::MUL: 740 result = 741 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 742 LLVM::MulOp>( 743 rewriter, loc, llvmType, operand, acc); 744 break; 745 case vector::CombiningKind::MINUI: 746 result = createIntegerReductionComparisonOpLowering< 747 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 748 LLVM::ICmpPredicate::ule); 749 break; 750 case vector::CombiningKind::MINSI: 751 result = createIntegerReductionComparisonOpLowering< 752 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 753 LLVM::ICmpPredicate::sle); 754 break; 755 case vector::CombiningKind::MAXUI: 756 result = createIntegerReductionComparisonOpLowering< 757 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 758 LLVM::ICmpPredicate::uge); 759 break; 760 case vector::CombiningKind::MAXSI: 761 result = createIntegerReductionComparisonOpLowering< 762 
LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 763 LLVM::ICmpPredicate::sge); 764 break; 765 case vector::CombiningKind::AND: 766 result = 767 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 768 LLVM::AndOp>( 769 rewriter, loc, llvmType, operand, acc); 770 break; 771 case vector::CombiningKind::OR: 772 result = 773 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 774 LLVM::OrOp>( 775 rewriter, loc, llvmType, operand, acc); 776 break; 777 case vector::CombiningKind::XOR: 778 result = 779 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 780 LLVM::XOrOp>( 781 rewriter, loc, llvmType, operand, acc); 782 break; 783 default: 784 return failure(); 785 } 786 rewriter.replaceOp(reductionOp, result); 787 788 return success(); 789 } 790 791 if (!isa<FloatType>(eltType)) 792 return failure(); 793 794 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); 795 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 796 reductionOp.getContext(), 797 convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); 798 fmf = LLVM::FastmathFlagsAttr::get( 799 reductionOp.getContext(), 800 fmf.getValue() | (reassociateFPReductions ? 
LLVM::FastmathFlags::reassoc 801 : LLVM::FastmathFlags::none)); 802 803 // Floating-point reductions: add/mul/min/max 804 Value result; 805 if (kind == vector::CombiningKind::ADD) { 806 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd, 807 ReductionNeutralZero>( 808 rewriter, loc, llvmType, operand, acc, fmf); 809 } else if (kind == vector::CombiningKind::MUL) { 810 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul, 811 ReductionNeutralFPOne>( 812 rewriter, loc, llvmType, operand, acc, fmf); 813 } else if (kind == vector::CombiningKind::MINIMUMF) { 814 result = 815 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>( 816 rewriter, loc, llvmType, operand, acc, fmf); 817 } else if (kind == vector::CombiningKind::MAXIMUMF) { 818 result = 819 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>( 820 rewriter, loc, llvmType, operand, acc, fmf); 821 } else if (kind == vector::CombiningKind::MINF) { 822 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>( 823 rewriter, loc, llvmType, operand, acc, fmf); 824 } else if (kind == vector::CombiningKind::MAXF) { 825 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>( 826 rewriter, loc, llvmType, operand, acc, fmf); 827 } else 828 return failure(); 829 830 rewriter.replaceOp(reductionOp, result); 831 return success(); 832 } 833 834 private: 835 const bool reassociateFPReductions; 836 }; 837 838 /// Base class to convert a `vector.mask` operation while matching traits 839 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase` 840 /// instance matches against a `vector.mask` operation. The `matchAndRewrite` 841 /// method performs a second match against the maskable operation `MaskedOp`. 842 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be 843 /// implemented by the concrete conversion classes. 
This method can match 844 /// against specific traits of the `vector.mask` and the maskable operation. It 845 /// must replace the `vector.mask` operation. 846 template <class MaskedOp> 847 class VectorMaskOpConversionBase 848 : public ConvertOpToLLVMPattern<vector::MaskOp> { 849 public: 850 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern; 851 852 LogicalResult 853 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, 854 ConversionPatternRewriter &rewriter) const final { 855 // Match against the maskable operation kind. 856 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp()); 857 if (!maskedOp) 858 return failure(); 859 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter); 860 } 861 862 protected: 863 virtual LogicalResult 864 matchAndRewriteMaskableOp(vector::MaskOp maskOp, 865 vector::MaskableOpInterface maskableOp, 866 ConversionPatternRewriter &rewriter) const = 0; 867 }; 868 869 class MaskedReductionOpConversion 870 : public VectorMaskOpConversionBase<vector::ReductionOp> { 871 872 public: 873 using VectorMaskOpConversionBase< 874 vector::ReductionOp>::VectorMaskOpConversionBase; 875 876 LogicalResult matchAndRewriteMaskableOp( 877 vector::MaskOp maskOp, MaskableOpInterface maskableOp, 878 ConversionPatternRewriter &rewriter) const override { 879 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation()); 880 auto kind = reductionOp.getKind(); 881 Type eltType = reductionOp.getDest().getType(); 882 Type llvmType = typeConverter->convertType(eltType); 883 Value operand = reductionOp.getVector(); 884 Value acc = reductionOp.getAcc(); 885 Location loc = reductionOp.getLoc(); 886 887 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); 888 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 889 reductionOp.getContext(), 890 convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); 891 892 Value result; 893 switch (kind) { 894 case vector::CombiningKind::ADD: 895 result = 
lowerPredicatedReductionWithStartValue< 896 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, 897 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, 898 maskOp.getMask()); 899 break; 900 case vector::CombiningKind::MUL: 901 result = lowerPredicatedReductionWithStartValue< 902 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, 903 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, 904 maskOp.getMask()); 905 break; 906 case vector::CombiningKind::MINUI: 907 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp, 908 ReductionNeutralUIntMax>( 909 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 910 break; 911 case vector::CombiningKind::MINSI: 912 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp, 913 ReductionNeutralSIntMax>( 914 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 915 break; 916 case vector::CombiningKind::MAXUI: 917 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp, 918 ReductionNeutralUIntMin>( 919 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 920 break; 921 case vector::CombiningKind::MAXSI: 922 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp, 923 ReductionNeutralSIntMin>( 924 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 925 break; 926 case vector::CombiningKind::AND: 927 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp, 928 ReductionNeutralAllOnes>( 929 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 930 break; 931 case vector::CombiningKind::OR: 932 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp, 933 ReductionNeutralZero>( 934 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 935 break; 936 case vector::CombiningKind::XOR: 937 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp, 938 ReductionNeutralZero>( 939 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 940 break; 941 case 
vector::CombiningKind::MINF: 942 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp, 943 ReductionNeutralFPMax>( 944 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 945 break; 946 case vector::CombiningKind::MAXF: 947 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp, 948 ReductionNeutralFPMin>( 949 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 950 break; 951 case CombiningKind::MAXIMUMF: 952 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum, 953 MaskNeutralFMaximum>( 954 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 955 break; 956 case CombiningKind::MINIMUMF: 957 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum, 958 MaskNeutralFMinimum>( 959 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 960 break; 961 } 962 963 // Replace `vector.mask` operation altogether. 964 rewriter.replaceOp(maskOp, result); 965 return success(); 966 } 967 }; 968 969 class VectorShuffleOpConversion 970 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 971 public: 972 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 973 974 LogicalResult 975 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 976 ConversionPatternRewriter &rewriter) const override { 977 auto loc = shuffleOp->getLoc(); 978 auto v1Type = shuffleOp.getV1VectorType(); 979 auto v2Type = shuffleOp.getV2VectorType(); 980 auto vectorType = shuffleOp.getResultVectorType(); 981 Type llvmType = typeConverter->convertType(vectorType); 982 auto maskArrayAttr = shuffleOp.getMask(); 983 984 // Bail if result type cannot be lowered. 985 if (!llvmType) 986 return failure(); 987 988 // Get rank and dimension sizes. 
989 int64_t rank = vectorType.getRank(); 990 #ifndef NDEBUG 991 bool wellFormed0DCase = 992 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1; 993 bool wellFormedNDCase = 994 v1Type.getRank() == rank && v2Type.getRank() == rank; 995 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed"); 996 #endif 997 998 // For rank 0 and 1, where both operands have *exactly* the same vector 999 // type, there is direct shuffle support in LLVM. Use it! 1000 if (rank <= 1 && v1Type == v2Type) { 1001 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 1002 loc, adaptor.getV1(), adaptor.getV2(), 1003 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 1004 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 1005 return success(); 1006 } 1007 1008 // For all other cases, insert the individual values individually. 1009 int64_t v1Dim = v1Type.getDimSize(0); 1010 Type eltType; 1011 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType)) 1012 eltType = arrayType.getElementType(); 1013 else 1014 eltType = cast<VectorType>(llvmType).getElementType(); 1015 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 1016 int64_t insPos = 0; 1017 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 1018 int64_t extPos = cast<IntegerAttr>(en.value()).getInt(); 1019 Value value = adaptor.getV1(); 1020 if (extPos >= v1Dim) { 1021 extPos -= v1Dim; 1022 value = adaptor.getV2(); 1023 } 1024 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 1025 eltType, rank, extPos); 1026 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 1027 llvmType, rank, insPos++); 1028 } 1029 rewriter.replaceOp(shuffleOp, insert); 1030 return success(); 1031 } 1032 }; 1033 1034 class VectorExtractElementOpConversion 1035 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 1036 public: 1037 using ConvertOpToLLVMPattern< 1038 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 1039 1040 LogicalResult 1041 
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 1042 ConversionPatternRewriter &rewriter) const override { 1043 auto vectorType = extractEltOp.getSourceVectorType(); 1044 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 1045 1046 // Bail if result type cannot be lowered. 1047 if (!llvmType) 1048 return failure(); 1049 1050 if (vectorType.getRank() == 0) { 1051 Location loc = extractEltOp.getLoc(); 1052 auto idxType = rewriter.getIndexType(); 1053 auto zero = rewriter.create<LLVM::ConstantOp>( 1054 loc, typeConverter->convertType(idxType), 1055 rewriter.getIntegerAttr(idxType, 0)); 1056 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1057 extractEltOp, llvmType, adaptor.getVector(), zero); 1058 return success(); 1059 } 1060 1061 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1062 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 1063 return success(); 1064 } 1065 }; 1066 1067 class VectorExtractOpConversion 1068 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 1069 public: 1070 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 1071 1072 LogicalResult 1073 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 1074 ConversionPatternRewriter &rewriter) const override { 1075 auto loc = extractOp->getLoc(); 1076 auto resultType = extractOp.getResult().getType(); 1077 auto llvmResultType = typeConverter->convertType(resultType); 1078 // Bail if result type cannot be lowered. 1079 if (!llvmResultType) 1080 return failure(); 1081 1082 SmallVector<OpFoldResult> positionVec; 1083 for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) { 1084 if (pos.is<Value>()) 1085 // Make sure we use the value that has been already converted to LLVM. 1086 positionVec.push_back(adaptor.getDynamicPosition()[idx]); 1087 else 1088 positionVec.push_back(pos); 1089 } 1090 1091 // Extract entire vector. Should be handled by folder, but just to be safe. 
1092 ArrayRef<OpFoldResult> position(positionVec); 1093 if (position.empty()) { 1094 rewriter.replaceOp(extractOp, adaptor.getVector()); 1095 return success(); 1096 } 1097 1098 // One-shot extraction of vector from array (only requires extractvalue). 1099 if (isa<VectorType>(resultType)) { 1100 if (extractOp.hasDynamicPosition()) 1101 return failure(); 1102 1103 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 1104 loc, adaptor.getVector(), getAsIntegers(position)); 1105 rewriter.replaceOp(extractOp, extracted); 1106 return success(); 1107 } 1108 1109 // Potential extraction of 1-D vector from array. 1110 Value extracted = adaptor.getVector(); 1111 if (position.size() > 1) { 1112 if (extractOp.hasDynamicPosition()) 1113 return failure(); 1114 1115 SmallVector<int64_t> nMinusOnePosition = 1116 getAsIntegers(position.drop_back()); 1117 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted, 1118 nMinusOnePosition); 1119 } 1120 1121 Value lastPosition = getAsLLVMValue(rewriter, loc, position.back()); 1122 // Remaining extraction of element from 1-D LLVM vector. 1123 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted, 1124 lastPosition); 1125 return success(); 1126 } 1127 }; 1128 1129 /// Conversion pattern that turns a vector.fma on a 1-D vector 1130 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 1131 /// This does not match vectors of n >= 2 rank. 
1132 /// 1133 /// Example: 1134 /// ``` 1135 /// vector.fma %a, %a, %a : vector<8xf32> 1136 /// ``` 1137 /// is converted to: 1138 /// ``` 1139 /// llvm.intr.fmuladd %va, %va, %va: 1140 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 1141 /// -> !llvm."<8 x f32>"> 1142 /// ``` 1143 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 1144 public: 1145 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 1146 1147 LogicalResult 1148 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 1149 ConversionPatternRewriter &rewriter) const override { 1150 VectorType vType = fmaOp.getVectorType(); 1151 if (vType.getRank() > 1) 1152 return failure(); 1153 1154 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( 1155 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); 1156 return success(); 1157 } 1158 }; 1159 1160 class VectorInsertElementOpConversion 1161 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 1162 public: 1163 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 1164 1165 LogicalResult 1166 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 1167 ConversionPatternRewriter &rewriter) const override { 1168 auto vectorType = insertEltOp.getDestVectorType(); 1169 auto llvmType = typeConverter->convertType(vectorType); 1170 1171 // Bail if result type cannot be lowered. 
1172 if (!llvmType) 1173 return failure(); 1174 1175 if (vectorType.getRank() == 0) { 1176 Location loc = insertEltOp.getLoc(); 1177 auto idxType = rewriter.getIndexType(); 1178 auto zero = rewriter.create<LLVM::ConstantOp>( 1179 loc, typeConverter->convertType(idxType), 1180 rewriter.getIntegerAttr(idxType, 0)); 1181 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1182 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); 1183 return success(); 1184 } 1185 1186 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1187 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), 1188 adaptor.getPosition()); 1189 return success(); 1190 } 1191 }; 1192 1193 class VectorInsertOpConversion 1194 : public ConvertOpToLLVMPattern<vector::InsertOp> { 1195 public: 1196 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 1197 1198 LogicalResult 1199 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 1200 ConversionPatternRewriter &rewriter) const override { 1201 auto loc = insertOp->getLoc(); 1202 auto sourceType = insertOp.getSourceType(); 1203 auto destVectorType = insertOp.getDestVectorType(); 1204 auto llvmResultType = typeConverter->convertType(destVectorType); 1205 // Bail if result type cannot be lowered. 1206 if (!llvmResultType) 1207 return failure(); 1208 1209 SmallVector<OpFoldResult> positionVec; 1210 for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) { 1211 if (pos.is<Value>()) 1212 // Make sure we use the value that has been already converted to LLVM. 1213 positionVec.push_back(adaptor.getDynamicPosition()[idx]); 1214 else 1215 positionVec.push_back(pos); 1216 } 1217 1218 // Overwrite entire vector with value. Should be handled by folder, but 1219 // just to be safe. 
1220 ArrayRef<OpFoldResult> position(positionVec); 1221 if (position.empty()) { 1222 rewriter.replaceOp(insertOp, adaptor.getSource()); 1223 return success(); 1224 } 1225 1226 // One-shot insertion of a vector into an array (only requires insertvalue). 1227 if (isa<VectorType>(sourceType)) { 1228 if (insertOp.hasDynamicPosition()) 1229 return failure(); 1230 1231 Value inserted = rewriter.create<LLVM::InsertValueOp>( 1232 loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position)); 1233 rewriter.replaceOp(insertOp, inserted); 1234 return success(); 1235 } 1236 1237 // Potential extraction of 1-D vector from array. 1238 Value extracted = adaptor.getDest(); 1239 auto oneDVectorType = destVectorType; 1240 if (position.size() > 1) { 1241 if (insertOp.hasDynamicPosition()) 1242 return failure(); 1243 1244 oneDVectorType = reducedVectorTypeBack(destVectorType); 1245 extracted = rewriter.create<LLVM::ExtractValueOp>( 1246 loc, extracted, getAsIntegers(position.drop_back())); 1247 } 1248 1249 // Insertion of an element into a 1-D LLVM vector. 1250 Value inserted = rewriter.create<LLVM::InsertElementOp>( 1251 loc, typeConverter->convertType(oneDVectorType), extracted, 1252 adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back())); 1253 1254 // Potential insertion of resulting 1-D vector into array. 
1255 if (position.size() > 1) { 1256 if (insertOp.hasDynamicPosition()) 1257 return failure(); 1258 1259 inserted = rewriter.create<LLVM::InsertValueOp>( 1260 loc, adaptor.getDest(), inserted, 1261 getAsIntegers(position.drop_back())); 1262 } 1263 1264 rewriter.replaceOp(insertOp, inserted); 1265 return success(); 1266 } 1267 }; 1268 1269 /// Lower vector.scalable.insert ops to LLVM vector.insert 1270 struct VectorScalableInsertOpLowering 1271 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> { 1272 using ConvertOpToLLVMPattern< 1273 vector::ScalableInsertOp>::ConvertOpToLLVMPattern; 1274 1275 LogicalResult 1276 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor, 1277 ConversionPatternRewriter &rewriter) const override { 1278 rewriter.replaceOpWithNewOp<LLVM::vector_insert>( 1279 insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos()); 1280 return success(); 1281 } 1282 }; 1283 1284 /// Lower vector.scalable.extract ops to LLVM vector.extract 1285 struct VectorScalableExtractOpLowering 1286 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> { 1287 using ConvertOpToLLVMPattern< 1288 vector::ScalableExtractOp>::ConvertOpToLLVMPattern; 1289 1290 LogicalResult 1291 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor, 1292 ConversionPatternRewriter &rewriter) const override { 1293 rewriter.replaceOpWithNewOp<LLVM::vector_extract>( 1294 extOp, typeConverter->convertType(extOp.getResultVectorType()), 1295 adaptor.getSource(), adaptor.getPos()); 1296 return success(); 1297 } 1298 }; 1299 1300 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 
1301 /// 1302 /// Example: 1303 /// ``` 1304 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 1305 /// ``` 1306 /// is rewritten into: 1307 /// ``` 1308 /// %r = splat %f0: vector<2x4xf32> 1309 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 1310 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 1311 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 1312 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 1313 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 1314 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 1315 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 1316 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 1317 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 1318 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 1319 /// // %r3 holds the final value. 1320 /// ``` 1321 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 1322 public: 1323 using OpRewritePattern<FMAOp>::OpRewritePattern; 1324 1325 void initialize() { 1326 // This pattern recursively unpacks one dimension at a time. The recursion 1327 // bounded as the rank is strictly decreasing. 
1328 setHasBoundedRewriteRecursion(); 1329 } 1330 1331 LogicalResult matchAndRewrite(FMAOp op, 1332 PatternRewriter &rewriter) const override { 1333 auto vType = op.getVectorType(); 1334 if (vType.getRank() < 2) 1335 return failure(); 1336 1337 auto loc = op.getLoc(); 1338 auto elemType = vType.getElementType(); 1339 Value zero = rewriter.create<arith::ConstantOp>( 1340 loc, elemType, rewriter.getZeroAttr(elemType)); 1341 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 1342 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 1343 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 1344 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 1345 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 1346 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 1347 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 1348 } 1349 rewriter.replaceOp(op, desc); 1350 return success(); 1351 } 1352 }; 1353 1354 /// Returns the strides if the memory underlying `memRefType` has a contiguous 1355 /// static layout. 1356 static std::optional<SmallVector<int64_t, 4>> 1357 computeContiguousStrides(MemRefType memRefType) { 1358 int64_t offset; 1359 SmallVector<int64_t, 4> strides; 1360 if (failed(getStridesAndOffset(memRefType, strides, offset))) 1361 return std::nullopt; 1362 if (!strides.empty() && strides.back() != 1) 1363 return std::nullopt; 1364 // If no layout or identity layout, this is contiguous by definition. 1365 if (memRefType.getLayout().isIdentity()) 1366 return strides; 1367 1368 // Otherwise, we must determine contiguity form shapes. This can only ever 1369 // work in static cases because MemRefType is underspecified to represent 1370 // contiguous dynamic shapes in other ways than with just empty/identity 1371 // layout. 
1372 auto sizes = memRefType.getShape(); 1373 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 1374 if (ShapedType::isDynamic(sizes[index + 1]) || 1375 ShapedType::isDynamic(strides[index]) || 1376 ShapedType::isDynamic(strides[index + 1])) 1377 return std::nullopt; 1378 if (strides[index] != strides[index + 1] * sizes[index + 1]) 1379 return std::nullopt; 1380 } 1381 return strides; 1382 } 1383 1384 class VectorTypeCastOpConversion 1385 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 1386 public: 1387 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 1388 1389 LogicalResult 1390 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 1391 ConversionPatternRewriter &rewriter) const override { 1392 auto loc = castOp->getLoc(); 1393 MemRefType sourceMemRefType = 1394 cast<MemRefType>(castOp.getOperand().getType()); 1395 MemRefType targetMemRefType = castOp.getType(); 1396 1397 // Only static shape casts supported atm. 1398 if (!sourceMemRefType.hasStaticShape() || 1399 !targetMemRefType.hasStaticShape()) 1400 return failure(); 1401 1402 auto llvmSourceDescriptorTy = 1403 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType()); 1404 if (!llvmSourceDescriptorTy) 1405 return failure(); 1406 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 1407 1408 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( 1409 typeConverter->convertType(targetMemRefType)); 1410 if (!llvmTargetDescriptorTy) 1411 return failure(); 1412 1413 // Only contiguous source buffers supported atm. 1414 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1415 if (!sourceStrides) 1416 return failure(); 1417 auto targetStrides = computeContiguousStrides(targetMemRefType); 1418 if (!targetStrides) 1419 return failure(); 1420 // Only support static strides for now, regardless of contiguity. 
1421 if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) 1422 return failure(); 1423 1424 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1425 1426 // Create descriptor. 1427 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1428 // Set allocated ptr. 1429 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1430 desc.setAllocatedPtr(rewriter, loc, allocated); 1431 1432 // Set aligned ptr. 1433 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1434 desc.setAlignedPtr(rewriter, loc, ptr); 1435 // Fill offset 0. 1436 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1437 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1438 desc.setOffset(rewriter, loc, zero); 1439 1440 // Fill size and stride descriptors in memref. 1441 for (const auto &indexedSize : 1442 llvm::enumerate(targetMemRefType.getShape())) { 1443 int64_t index = indexedSize.index(); 1444 auto sizeAttr = 1445 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1446 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1447 desc.setSize(rewriter, loc, index, size); 1448 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1449 (*targetStrides)[index]); 1450 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1451 desc.setStride(rewriter, loc, index, stride); 1452 } 1453 1454 rewriter.replaceOp(castOp, {desc}); 1455 return success(); 1456 } 1457 }; 1458 1459 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 1460 /// Non-scalable versions of this operation are handled in Vector Transforms. 
1461 class VectorCreateMaskOpRewritePattern 1462 : public OpRewritePattern<vector::CreateMaskOp> { 1463 public: 1464 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1465 bool enableIndexOpt) 1466 : OpRewritePattern<vector::CreateMaskOp>(context), 1467 force32BitVectorIndices(enableIndexOpt) {} 1468 1469 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1470 PatternRewriter &rewriter) const override { 1471 auto dstType = op.getType(); 1472 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1473 return failure(); 1474 IntegerType idxType = 1475 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1476 auto loc = op->getLoc(); 1477 Value indices = rewriter.create<LLVM::StepVectorOp>( 1478 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1479 /*isScalable=*/true)); 1480 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1481 op.getOperand(0)); 1482 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1483 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1484 indices, bounds); 1485 rewriter.replaceOp(op, comp); 1486 return success(); 1487 } 1488 1489 private: 1490 const bool force32BitVectorIndices; 1491 }; 1492 1493 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1494 public: 1495 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1496 1497 // Lowering implementation that relies on a small runtime support library, 1498 // which only needs to provide a few printing methods (single value for all 1499 // data types, opening/closing bracket, comma, newline). The lowering splits 1500 // the vector into elementary printing operations. The advantage of this 1501 // approach is that the library can remain unaware of all low-level 1502 // implementation details of vectors while still supporting output of any 1503 // shaped and dimensioned vector. 
1504 // 1505 // Note: This lowering only handles scalars, n-D vectors are broken into 1506 // printing scalars in loops in VectorToSCF. 1507 // 1508 // TODO: rely solely on libc in future? something else? 1509 // 1510 LogicalResult 1511 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1512 ConversionPatternRewriter &rewriter) const override { 1513 auto parent = printOp->getParentOfType<ModuleOp>(); 1514 if (!parent) 1515 return failure(); 1516 1517 auto loc = printOp->getLoc(); 1518 1519 if (auto value = adaptor.getSource()) { 1520 Type printType = printOp.getPrintType(); 1521 if (isa<VectorType>(printType)) { 1522 // Vectors should be broken into elementary print ops in VectorToSCF. 1523 return failure(); 1524 } 1525 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value))) 1526 return failure(); 1527 } 1528 1529 auto punct = printOp.getPunctuation(); 1530 if (auto stringLiteral = printOp.getStringLiteral()) { 1531 LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", 1532 *stringLiteral, *getTypeConverter()); 1533 } else if (punct != PrintPunctuation::NoPunctuation) { 1534 emitCall(rewriter, printOp->getLoc(), [&] { 1535 switch (punct) { 1536 case PrintPunctuation::Close: 1537 return LLVM::lookupOrCreatePrintCloseFn(parent); 1538 case PrintPunctuation::Open: 1539 return LLVM::lookupOrCreatePrintOpenFn(parent); 1540 case PrintPunctuation::Comma: 1541 return LLVM::lookupOrCreatePrintCommaFn(parent); 1542 case PrintPunctuation::NewLine: 1543 return LLVM::lookupOrCreatePrintNewlineFn(parent); 1544 default: 1545 llvm_unreachable("unexpected punctuation"); 1546 } 1547 }()); 1548 } 1549 1550 rewriter.eraseOp(printOp); 1551 return success(); 1552 } 1553 1554 private: 1555 enum class PrintConversion { 1556 // clang-format off 1557 None, 1558 ZeroExt64, 1559 SignExt64, 1560 Bitcast16 1561 // clang-format on 1562 }; 1563 1564 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter, 1565 ModuleOp parent, Location loc, Type 
printType, 1566 Value value) const { 1567 if (typeConverter->convertType(printType) == nullptr) 1568 return failure(); 1569 1570 // Make sure element type has runtime support. 1571 PrintConversion conversion = PrintConversion::None; 1572 Operation *printer; 1573 if (printType.isF32()) { 1574 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1575 } else if (printType.isF64()) { 1576 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1577 } else if (printType.isF16()) { 1578 conversion = PrintConversion::Bitcast16; // bits! 1579 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1580 } else if (printType.isBF16()) { 1581 conversion = PrintConversion::Bitcast16; // bits! 1582 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1583 } else if (printType.isIndex()) { 1584 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1585 } else if (auto intTy = dyn_cast<IntegerType>(printType)) { 1586 // Integers need a zero or sign extension on the operand 1587 // (depending on the source type) as well as a signed or 1588 // unsigned print method. Up to 64-bit is supported. 1589 unsigned width = intTy.getWidth(); 1590 if (intTy.isUnsigned()) { 1591 if (width <= 64) { 1592 if (width < 64) 1593 conversion = PrintConversion::ZeroExt64; 1594 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1595 } else { 1596 return failure(); 1597 } 1598 } else { 1599 assert(intTy.isSignless() || intTy.isSigned()); 1600 if (width <= 64) { 1601 // Note that we *always* zero extend booleans (1-bit integers), 1602 // so that true/false is printed as 1/0 rather than -1/0. 
// i1 is zero-extended and every other width below 64 is sign-extended; the
// i64 print function is used regardless.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Widen or reinterpret the operand as selected above (nothing to do for
    // PrintConversion::None), then call the chosen print function with it.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      // Zero-extend to i64 via arith.extui.
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      // Sign-extend to i64 via arith.extsi.
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      // Reinterpret the bits as an i16 via llvm.bitcast.
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  /// Emits an `llvm.call` that produces no results, targeting the symbol of
  /// `ref` and passing `params` as the call operands.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// Lowers `vector.splat` whose result is 0-d or 1-d.
///
/// The scalar input is inserted at index 0 of an `llvm.mlir.undef` vector of
/// the converted result type. For a 0-d result that `llvm.insertelement` is the
/// replacement. For a 1-d result it is followed by an `llvm.shufflevector`
/// whose all-zero mask broadcasts lane 0 across every element.
/// Results of rank > 1 are rejected and left to VectorSplatNdOpLowering.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    // N-D (rank >= 2) splats are handled by VectorSplatNdOpLowering.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    // Insertion position 0, materialized as a converted i32 constant.
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    // One mask entry per result element, all selecting lane 0 of `v`.
    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// Lowers `vector.splat` whose result has rank >= 2.
///
/// A 1-D splat of the innermost dimension is built with insertelement +
/// shufflevector. It is then written into every position of the N-D LLVM
/// result value with `llvm.insertvalue`.
/// Results of rank 0 and 1 are rejected here; VectorSplatOpLowering handles
/// them.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // Compute the LLVM types for the N-D result and for its innermost 1-D
    // vector; bail out if either one cannot be converted.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor. The scalar is first inserted at
    // position 0 of an undef 1-D vector so that it can be shuffled.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements (the size of
    // the innermost dimension); every mask entry selects lane 0.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over every position of the outer dimensions and insert the
    // splatted 1-D vector there.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
1737 void mlir::populateVectorToLLVMConversionPatterns( 1738 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1739 bool reassociateFPReductions, bool force32BitVectorIndices) { 1740 MLIRContext *ctx = converter.getDialect()->getContext(); 1741 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1742 populateVectorInsertExtractStridedSliceTransforms(patterns); 1743 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1744 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1745 patterns 1746 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1747 VectorExtractElementOpConversion, VectorExtractOpConversion, 1748 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1749 VectorInsertOpConversion, VectorPrintOpConversion, 1750 VectorTypeCastOpConversion, VectorScaleOpConversion, 1751 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1752 VectorLoadStoreConversion<vector::MaskedLoadOp, 1753 vector::MaskedLoadOpAdaptor>, 1754 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1755 VectorLoadStoreConversion<vector::MaskedStoreOp, 1756 vector::MaskedStoreOpAdaptor>, 1757 VectorGatherOpConversion, VectorScatterOpConversion, 1758 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1759 VectorSplatOpLowering, VectorSplatNdOpLowering, 1760 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1761 MaskedReductionOpConversion>(converter); 1762 // Transfer ops with rank > 1 are handled by VectorToSCF. 1763 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1764 } 1765 1766 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1767 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1768 patterns.add<VectorMatmulOpConversion>(converter); 1769 patterns.add<VectorFlatTransposeOpConversion>(converter); 1770 } 1771