1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" 12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Arith/Utils/Utils.h" 16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/Vector/IR/VectorOps.h" 20 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 21 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 23 #include "mlir/IR/BuiltinAttributes.h" 24 #include "mlir/IR/BuiltinTypeInterfaces.h" 25 #include "mlir/IR/BuiltinTypes.h" 26 #include "mlir/IR/TypeUtilities.h" 27 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "llvm/ADT/APFloat.h" 30 #include "llvm/Support/Casting.h" 31 #include <optional> 32 33 using namespace mlir; 34 using namespace mlir::vector; 35 36 // Helper to reduce vector type by *all* but one rank at back. 37 static VectorType reducedVectorTypeBack(VectorType tp) { 38 assert((tp.getRank() > 1) && "unlowerable vector type"); 39 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 40 tp.getScalableDims().take_back()); 41 } 42 43 // Helper that picks the proper sequence for inserting. 44 static Value insertOne(ConversionPatternRewriter &rewriter, 45 const LLVMTypeConverter &typeConverter, Location loc, 46 Value val1, Value val2, Type llvmType, int64_t rank, 47 int64_t pos) { 48 assert(rank > 0 && "0-D vector corner case should have been handled already"); 49 if (rank == 1) { 50 auto idxType = rewriter.getIndexType(); 51 auto constant = rewriter.create<LLVM::ConstantOp>( 52 loc, typeConverter.convertType(idxType), 53 rewriter.getIntegerAttr(idxType, pos)); 54 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 55 constant); 56 } 57 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 58 } 59 60 // Helper that picks the proper sequence for extracting. 61 static Value extractOne(ConversionPatternRewriter &rewriter, 62 const LLVMTypeConverter &typeConverter, Location loc, 63 Value val, Type llvmType, int64_t rank, int64_t pos) { 64 if (rank <= 1) { 65 auto idxType = rewriter.getIndexType(); 66 auto constant = rewriter.create<LLVM::ConstantOp>( 67 loc, typeConverter.convertType(idxType), 68 rewriter.getIntegerAttr(idxType, pos)); 69 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 70 constant); 71 } 72 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 73 } 74 75 // Helper that returns data layout alignment of a memref. 76 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, 77 MemRefType memrefType, unsigned &align) { 78 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 79 if (!elementTy) 80 return failure(); 81 82 // TODO: this should use the MLIR data layout when it becomes available and 83 // stop depending on translation. 84 llvm::LLVMContext llvmContext; 85 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 86 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 87 return success(); 88 } 89 90 // Check if the last stride is non-unit or the memory space is not zero. 91 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 92 const LLVMTypeConverter &converter) { 93 if (!isLastMemrefDimUnitStride(memRefType)) 94 return failure(); 95 FailureOr<unsigned> addressSpace = 96 converter.getMemRefAddressSpace(memRefType); 97 if (failed(addressSpace) || *addressSpace != 0) 98 return failure(); 99 return success(); 100 } 101 102 // Add an index vector component to a base pointer. 103 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 104 const LLVMTypeConverter &typeConverter, 105 MemRefType memRefType, Value llvmMemref, Value base, 106 Value index, uint64_t vLen) { 107 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 108 "unsupported memref type"); 109 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 110 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 111 return rewriter.create<LLVM::GEPOp>( 112 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 113 base, index); 114 } 115 116 // Casts a strided element pointer to a vector pointer. The vector pointer 117 // will be in the same address space as the incoming memref type. 118 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 119 Value ptr, MemRefType memRefType, Type vt, 120 const LLVMTypeConverter &converter) { 121 if (converter.useOpaquePointers()) 122 return ptr; 123 124 unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); 125 auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); 126 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 127 } 128 129 namespace { 130 131 /// Trivial Vector to LLVM conversions 132 using VectorScaleOpConversion = 133 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 134 135 /// Conversion pattern for a vector.bitcast. 136 class VectorBitCastOpConversion 137 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 138 public: 139 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 140 141 LogicalResult 142 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 143 ConversionPatternRewriter &rewriter) const override { 144 // Only 0-D and 1-D vectors can be lowered to LLVM. 145 VectorType resultTy = bitCastOp.getResultVectorType(); 146 if (resultTy.getRank() > 1) 147 return failure(); 148 Type newResultTy = typeConverter->convertType(resultTy); 149 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 150 adaptor.getOperands()[0]); 151 return success(); 152 } 153 }; 154 155 /// Conversion pattern for a vector.matrix_multiply. 156 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 157 class VectorMatmulOpConversion 158 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 159 public: 160 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 161 162 LogicalResult 163 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 164 ConversionPatternRewriter &rewriter) const override { 165 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 166 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 167 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 168 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 169 return success(); 170 } 171 }; 172 173 /// Conversion pattern for a vector.flat_transpose. 174 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 175 class VectorFlatTransposeOpConversion 176 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 177 public: 178 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 179 180 LogicalResult 181 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 182 ConversionPatternRewriter &rewriter) const override { 183 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 184 transOp, typeConverter->convertType(transOp.getRes().getType()), 185 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 186 return success(); 187 } 188 }; 189 190 /// Overloaded utility that replaces a vector.load, vector.store, 191 /// vector.maskedload and vector.maskedstore with their respective LLVM 192 /// couterparts. 193 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 194 vector::LoadOpAdaptor adaptor, 195 VectorType vectorTy, Value ptr, unsigned align, 196 ConversionPatternRewriter &rewriter) { 197 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align); 198 } 199 200 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 201 vector::MaskedLoadOpAdaptor adaptor, 202 VectorType vectorTy, Value ptr, unsigned align, 203 ConversionPatternRewriter &rewriter) { 204 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 205 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 206 } 207 208 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 209 vector::StoreOpAdaptor adaptor, 210 VectorType vectorTy, Value ptr, unsigned align, 211 ConversionPatternRewriter &rewriter) { 212 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 213 ptr, align); 214 } 215 216 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 217 vector::MaskedStoreOpAdaptor adaptor, 218 VectorType vectorTy, Value ptr, unsigned align, 219 ConversionPatternRewriter &rewriter) { 220 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 221 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); 222 } 223 224 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 225 /// vector.maskedstore. 226 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 227 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 228 public: 229 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 230 231 LogicalResult 232 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 233 typename LoadOrStoreOp::Adaptor adaptor, 234 ConversionPatternRewriter &rewriter) const override { 235 // Only 1-D vectors can be lowered to LLVM. 236 VectorType vectorTy = loadOrStoreOp.getVectorType(); 237 if (vectorTy.getRank() > 1) 238 return failure(); 239 240 auto loc = loadOrStoreOp->getLoc(); 241 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 242 243 // Resolve alignment. 244 unsigned align; 245 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 246 return failure(); 247 248 // Resolve address. 249 auto vtype = cast<VectorType>( 250 this->typeConverter->convertType(loadOrStoreOp.getVectorType())); 251 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 252 adaptor.getIndices(), rewriter); 253 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, 254 *this->getTypeConverter()); 255 256 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 257 return success(); 258 } 259 }; 260 261 /// Conversion pattern for a vector.gather. 262 class VectorGatherOpConversion 263 : public ConvertOpToLLVMPattern<vector::GatherOp> { 264 public: 265 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 266 267 LogicalResult 268 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 269 ConversionPatternRewriter &rewriter) const override { 270 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType()); 271 assert(memRefType && "The base should be bufferized"); 272 273 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 274 return failure(); 275 276 auto loc = gather->getLoc(); 277 278 // Resolve alignment. 279 unsigned align; 280 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 281 return failure(); 282 283 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 284 adaptor.getIndices(), rewriter); 285 Value base = adaptor.getBase(); 286 287 auto llvmNDVectorTy = adaptor.getIndexVec().getType(); 288 // Handle the simple case of 1-D vector. 289 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) { 290 auto vType = gather.getVectorType(); 291 // Resolve address. 292 Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), 293 memRefType, base, ptr, adaptor.getIndexVec(), 294 /*vLen=*/vType.getDimSize(0)); 295 // Replace with the gather intrinsic. 296 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 297 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 298 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 299 return success(); 300 } 301 302 const LLVMTypeConverter &typeConverter = *this->getTypeConverter(); 303 auto callback = [align, memRefType, base, ptr, loc, &rewriter, 304 &typeConverter](Type llvm1DVectorTy, 305 ValueRange vectorOperands) { 306 // Resolve address. 307 Value ptrs = getIndexedPtrs( 308 rewriter, loc, typeConverter, memRefType, base, ptr, 309 /*index=*/vectorOperands[0], 310 LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); 311 // Create the gather intrinsic. 312 return rewriter.create<LLVM::masked_gather>( 313 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1], 314 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align)); 315 }; 316 SmallVector<Value> vectorOperands = { 317 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()}; 318 return LLVM::detail::handleMultidimensionalVectors( 319 gather, vectorOperands, *getTypeConverter(), callback, rewriter); 320 } 321 }; 322 323 /// Conversion pattern for a vector.scatter. 324 class VectorScatterOpConversion 325 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 326 public: 327 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 328 329 LogicalResult 330 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 331 ConversionPatternRewriter &rewriter) const override { 332 auto loc = scatter->getLoc(); 333 MemRefType memRefType = scatter.getMemRefType(); 334 335 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 336 return failure(); 337 338 // Resolve alignment. 339 unsigned align; 340 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 341 return failure(); 342 343 // Resolve address. 344 VectorType vType = scatter.getVectorType(); 345 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 346 adaptor.getIndices(), rewriter); 347 Value ptrs = getIndexedPtrs( 348 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(), 349 ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); 350 351 // Replace with the scatter intrinsic. 352 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 353 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 354 rewriter.getI32IntegerAttr(align)); 355 return success(); 356 } 357 }; 358 359 /// Conversion pattern for a vector.expandload. 360 class VectorExpandLoadOpConversion 361 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 362 public: 363 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 364 365 LogicalResult 366 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 367 ConversionPatternRewriter &rewriter) const override { 368 auto loc = expand->getLoc(); 369 MemRefType memRefType = expand.getMemRefType(); 370 371 // Resolve address. 372 auto vtype = typeConverter->convertType(expand.getVectorType()); 373 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 374 adaptor.getIndices(), rewriter); 375 376 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 377 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 378 return success(); 379 } 380 }; 381 382 /// Conversion pattern for a vector.compressstore. 383 class VectorCompressStoreOpConversion 384 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 385 public: 386 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 387 388 LogicalResult 389 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 390 ConversionPatternRewriter &rewriter) const override { 391 auto loc = compress->getLoc(); 392 MemRefType memRefType = compress.getMemRefType(); 393 394 // Resolve address. 395 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 396 adaptor.getIndices(), rewriter); 397 398 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 399 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 400 return success(); 401 } 402 }; 403 404 /// Reduction neutral classes for overloading. 405 class ReductionNeutralZero {}; 406 class ReductionNeutralIntOne {}; 407 class ReductionNeutralFPOne {}; 408 class ReductionNeutralAllOnes {}; 409 class ReductionNeutralSIntMin {}; 410 class ReductionNeutralUIntMin {}; 411 class ReductionNeutralSIntMax {}; 412 class ReductionNeutralUIntMax {}; 413 class ReductionNeutralFPMin {}; 414 class ReductionNeutralFPMax {}; 415 416 /// Create the reduction neutral zero value. 417 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 418 ConversionPatternRewriter &rewriter, 419 Location loc, Type llvmType) { 420 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 421 rewriter.getZeroAttr(llvmType)); 422 } 423 424 /// Create the reduction neutral integer one value. 425 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 426 ConversionPatternRewriter &rewriter, 427 Location loc, Type llvmType) { 428 return rewriter.create<LLVM::ConstantOp>( 429 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 430 } 431 432 /// Create the reduction neutral fp one value. 433 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 434 ConversionPatternRewriter &rewriter, 435 Location loc, Type llvmType) { 436 return rewriter.create<LLVM::ConstantOp>( 437 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 438 } 439 440 /// Create the reduction neutral all-ones value. 441 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 442 ConversionPatternRewriter &rewriter, 443 Location loc, Type llvmType) { 444 return rewriter.create<LLVM::ConstantOp>( 445 loc, llvmType, 446 rewriter.getIntegerAttr( 447 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 448 } 449 450 /// Create the reduction neutral signed int minimum value. 451 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 452 ConversionPatternRewriter &rewriter, 453 Location loc, Type llvmType) { 454 return rewriter.create<LLVM::ConstantOp>( 455 loc, llvmType, 456 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 457 llvmType.getIntOrFloatBitWidth()))); 458 } 459 460 /// Create the reduction neutral unsigned int minimum value. 461 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 462 ConversionPatternRewriter &rewriter, 463 Location loc, Type llvmType) { 464 return rewriter.create<LLVM::ConstantOp>( 465 loc, llvmType, 466 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 467 llvmType.getIntOrFloatBitWidth()))); 468 } 469 470 /// Create the reduction neutral signed int maximum value. 471 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 472 ConversionPatternRewriter &rewriter, 473 Location loc, Type llvmType) { 474 return rewriter.create<LLVM::ConstantOp>( 475 loc, llvmType, 476 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 477 llvmType.getIntOrFloatBitWidth()))); 478 } 479 480 /// Create the reduction neutral unsigned int maximum value. 481 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 482 ConversionPatternRewriter &rewriter, 483 Location loc, Type llvmType) { 484 return rewriter.create<LLVM::ConstantOp>( 485 loc, llvmType, 486 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 487 llvmType.getIntOrFloatBitWidth()))); 488 } 489 490 /// Create the reduction neutral fp minimum value. 491 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 492 ConversionPatternRewriter &rewriter, 493 Location loc, Type llvmType) { 494 auto floatType = cast<FloatType>(llvmType); 495 return rewriter.create<LLVM::ConstantOp>( 496 loc, llvmType, 497 rewriter.getFloatAttr( 498 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 499 /*Negative=*/false))); 500 } 501 502 /// Create the reduction neutral fp maximum value. 503 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 504 ConversionPatternRewriter &rewriter, 505 Location loc, Type llvmType) { 506 auto floatType = cast<FloatType>(llvmType); 507 return rewriter.create<LLVM::ConstantOp>( 508 loc, llvmType, 509 rewriter.getFloatAttr( 510 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 511 /*Negative=*/true))); 512 } 513 514 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 515 /// returns a new accumulator value using `ReductionNeutral`. 516 template <class ReductionNeutral> 517 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 518 Location loc, Type llvmType, 519 Value accumulator) { 520 if (accumulator) 521 return accumulator; 522 523 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 524 llvmType); 525 } 526 527 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 528 /// This is used as effective vector length by some intrinsics supporting 529 /// dynamic vector lengths at runtime. 530 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 531 Location loc, Type llvmType) { 532 VectorType vType = cast<VectorType>(llvmType); 533 auto vShape = vType.getShape(); 534 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 535 536 return rewriter.create<LLVM::ConstantOp>( 537 loc, rewriter.getI32Type(), 538 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 539 } 540 541 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 542 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 543 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 544 /// non-null. 545 template <class LLVMRedIntrinOp, class ScalarOp> 546 static Value createIntegerReductionArithmeticOpLowering( 547 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 548 Value vectorOperand, Value accumulator) { 549 550 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 551 552 if (accumulator) 553 result = rewriter.create<ScalarOp>(loc, accumulator, result); 554 return result; 555 } 556 557 /// Helper method to lower a `vector.reduction` operation that performs 558 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 559 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 560 /// the accumulator value if non-null. 561 template <class LLVMRedIntrinOp> 562 static Value createIntegerReductionComparisonOpLowering( 563 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 564 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 565 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 566 if (accumulator) { 567 Value cmp = 568 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 569 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 570 } 571 return result; 572 } 573 574 namespace { 575 template <typename Source> 576 struct VectorToScalarMapper; 577 template <> 578 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { 579 using Type = LLVM::MaximumOp; 580 }; 581 template <> 582 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { 583 using Type = LLVM::MinimumOp; 584 }; 585 template <> 586 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> { 587 using Type = LLVM::MaxNumOp; 588 }; 589 template <> 590 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> { 591 using Type = LLVM::MinNumOp; 592 }; 593 } // namespace 594 595 template <class LLVMRedIntrinOp> 596 static Value createFPReductionComparisonOpLowering( 597 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 598 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { 599 Value result = 600 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf); 601 602 if (accumulator) { 603 result = 604 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( 605 loc, result, accumulator); 606 } 607 608 return result; 609 } 610 611 /// Reduction neutral classes for overloading 612 class MaskNeutralFMaximum {}; 613 class MaskNeutralFMinimum {}; 614 615 /// Get the mask neutral floating point maximum value 616 static llvm::APFloat 617 getMaskNeutralValue(MaskNeutralFMaximum, 618 const llvm::fltSemantics &floatSemantics) { 619 return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true); 620 } 621 /// Get the mask neutral floating point minimum value 622 static llvm::APFloat 623 getMaskNeutralValue(MaskNeutralFMinimum, 624 const llvm::fltSemantics &floatSemantics) { 625 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false); 626 } 627 628 /// Create the mask neutral floating point MLIR vector constant 629 template <typename MaskNeutral> 630 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, 631 Location loc, Type llvmType, 632 Type vectorType) { 633 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); 634 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); 635 auto denseValue = 636 DenseElementsAttr::get(vectorType.cast<ShapedType>(), value); 637 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); 638 } 639 640 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked 641 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for 642 /// `fmaximum`/`fminimum`. 643 /// More information: https://github.com/llvm/llvm-project/issues/64940 644 template <class LLVMRedIntrinOp, class MaskNeutral> 645 static Value 646 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, 647 Location loc, Type llvmType, 648 Value vectorOperand, Value accumulator, 649 Value mask, LLVM::FastmathFlagsAttr fmf) { 650 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>( 651 rewriter, loc, llvmType, vectorOperand.getType()); 652 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>( 653 loc, mask, vectorOperand, vectorMaskNeutral); 654 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>( 655 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); 656 } 657 658 template <class LLVMRedIntrinOp, class ReductionNeutral> 659 static Value 660 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, 661 Type llvmType, Value vectorOperand, 662 Value accumulator, LLVM::FastmathFlagsAttr fmf) { 663 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 664 llvmType, accumulator); 665 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType, 666 /*startValue=*/accumulator, 667 vectorOperand, fmf); 668 } 669 670 /// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic 671 /// that requires a start value. This start value format spans across fp 672 /// reductions without mask and all the masked reduction intrinsics. 673 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 674 static Value 675 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter, 676 Location loc, Type llvmType, 677 Value vectorOperand, Value accumulator) { 678 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 679 llvmType, accumulator); 680 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 681 /*startValue=*/accumulator, 682 vectorOperand); 683 } 684 685 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 686 static Value lowerPredicatedReductionWithStartValue( 687 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 688 Value vectorOperand, Value accumulator, Value mask) { 689 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 690 llvmType, accumulator); 691 Value vectorLength = 692 createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 693 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 694 /*startValue=*/accumulator, 695 vectorOperand, mask, vectorLength); 696 } 697 698 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral, 699 class LLVMFPVPRedIntrinOp, class FPReductionNeutral> 700 static Value lowerPredicatedReductionWithStartValue( 701 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 702 Value vectorOperand, Value accumulator, Value mask) { 703 if (llvmType.isIntOrIndex()) 704 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp, 705 IntReductionNeutral>( 706 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 707 708 // FP dispatch. 709 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp, 710 FPReductionNeutral>( 711 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 712 } 713 714 /// Conversion pattern for all vector reductions. 715 class VectorReductionOpConversion 716 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 717 public: 718 explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv, 719 bool reassociateFPRed) 720 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 721 reassociateFPReductions(reassociateFPRed) {} 722 723 LogicalResult 724 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 725 ConversionPatternRewriter &rewriter) const override { 726 auto kind = reductionOp.getKind(); 727 Type eltType = reductionOp.getDest().getType(); 728 Type llvmType = typeConverter->convertType(eltType); 729 Value operand = adaptor.getVector(); 730 Value acc = adaptor.getAcc(); 731 Location loc = reductionOp.getLoc(); 732 733 if (eltType.isIntOrIndex()) { 734 // Integer reductions: add/mul/min/max/and/or/xor. 735 Value result; 736 switch (kind) { 737 case vector::CombiningKind::ADD: 738 result = 739 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 740 LLVM::AddOp>( 741 rewriter, loc, llvmType, operand, acc); 742 break; 743 case vector::CombiningKind::MUL: 744 result = 745 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 746 LLVM::MulOp>( 747 rewriter, loc, llvmType, operand, acc); 748 break; 749 case vector::CombiningKind::MINUI: 750 result = createIntegerReductionComparisonOpLowering< 751 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 752 LLVM::ICmpPredicate::ule); 753 break; 754 case vector::CombiningKind::MINSI: 755 result = createIntegerReductionComparisonOpLowering< 756 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 757 LLVM::ICmpPredicate::sle); 758 break; 759 case vector::CombiningKind::MAXUI: 760 result = createIntegerReductionComparisonOpLowering< 761 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 762 LLVM::ICmpPredicate::uge); 763 break; 764 case vector::CombiningKind::MAXSI: 765 result = createIntegerReductionComparisonOpLowering< 766 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 767 LLVM::ICmpPredicate::sge); 768 break; 769 case vector::CombiningKind::AND: 770 result = 771 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 772 LLVM::AndOp>( 773 rewriter, loc, llvmType, operand, acc); 774 break; 775 case vector::CombiningKind::OR: 776 result = 777 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 778 LLVM::OrOp>( 779 rewriter, loc, llvmType, operand, acc); 780 break; 781 case vector::CombiningKind::XOR: 782 result = 783 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 784 LLVM::XOrOp>( 785 rewriter, loc, llvmType, operand, acc); 786 break; 787 default: 788 return failure(); 789 } 790 rewriter.replaceOp(reductionOp, result); 791 792 return success(); 793 } 794 795 if (!isa<FloatType>(eltType)) 796 return failure(); 797 798 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); 799 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 800 reductionOp.getContext(), 801 convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); 802 fmf = LLVM::FastmathFlagsAttr::get( 803 reductionOp.getContext(), 804 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc 805 : LLVM::FastmathFlags::none)); 806 807 // Floating-point reductions: add/mul/min/max 808 Value result; 809 if (kind == vector::CombiningKind::ADD) { 810 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd, 811 ReductionNeutralZero>( 812 rewriter, loc, llvmType, operand, acc, fmf); 813 } else if (kind == vector::CombiningKind::MUL) { 814 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul, 815 ReductionNeutralFPOne>( 816 rewriter, loc, llvmType, operand, acc, fmf); 817 } else if (kind == vector::CombiningKind::MINIMUMF) { 818 result = 819 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>( 820 rewriter, loc, llvmType, operand, acc, fmf); 821 } else if (kind == vector::CombiningKind::MAXIMUMF) { 822 result = 823 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>( 824 rewriter, loc, llvmType, operand, acc, fmf); 825 } else if (kind == vector::CombiningKind::MINF) { 826 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>( 827 rewriter, loc, llvmType, operand, acc, fmf); 828 } else if (kind == vector::CombiningKind::MAXF) { 829 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>( 830 rewriter, loc, llvmType, operand, acc, fmf); 831 } else 832 return failure(); 833 834 rewriter.replaceOp(reductionOp, result); 835 return success(); 836 } 837 838 private: 839 const bool reassociateFPReductions; 840 }; 841 842 /// Base class to convert a `vector.mask` operation while matching traits 843 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase` 844 /// instance matches against a `vector.mask` operation. The `matchAndRewrite` 845 /// method performs a second match against the maskable operation `MaskedOp`. 846 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be 847 /// implemented by the concrete conversion classes. This method can match 848 /// against specific traits of the `vector.mask` and the maskable operation. It 849 /// must replace the `vector.mask` operation. 850 template <class MaskedOp> 851 class VectorMaskOpConversionBase 852 : public ConvertOpToLLVMPattern<vector::MaskOp> { 853 public: 854 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern; 855 856 LogicalResult 857 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, 858 ConversionPatternRewriter &rewriter) const final { 859 // Match against the maskable operation kind. 860 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp()); 861 if (!maskedOp) 862 return failure(); 863 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter); 864 } 865 866 protected: 867 virtual LogicalResult 868 matchAndRewriteMaskableOp(vector::MaskOp maskOp, 869 vector::MaskableOpInterface maskableOp, 870 ConversionPatternRewriter &rewriter) const = 0; 871 }; 872 873 class MaskedReductionOpConversion 874 : public VectorMaskOpConversionBase<vector::ReductionOp> { 875 876 public: 877 using VectorMaskOpConversionBase< 878 vector::ReductionOp>::VectorMaskOpConversionBase; 879 880 LogicalResult matchAndRewriteMaskableOp( 881 vector::MaskOp maskOp, MaskableOpInterface maskableOp, 882 ConversionPatternRewriter &rewriter) const override { 883 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation()); 884 auto kind = reductionOp.getKind(); 885 Type eltType = reductionOp.getDest().getType(); 886 Type llvmType = typeConverter->convertType(eltType); 887 Value operand = reductionOp.getVector(); 888 Value acc = reductionOp.getAcc(); 889 Location loc = reductionOp.getLoc(); 890 891 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); 892 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( 893 reductionOp.getContext(), 894 convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); 895 896 Value result; 897 switch (kind) { 898 case vector::CombiningKind::ADD: 899 result = lowerPredicatedReductionWithStartValue< 900 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, 901 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, 902 maskOp.getMask()); 903 break; 904 case vector::CombiningKind::MUL: 905 result = lowerPredicatedReductionWithStartValue< 906 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, 907 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, 908 maskOp.getMask()); 909 break; 910 case vector::CombiningKind::MINUI: 911 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp, 912 ReductionNeutralUIntMax>( 913 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 914 break; 915 case vector::CombiningKind::MINSI: 916 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp, 917 ReductionNeutralSIntMax>( 918 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 919 break; 920 case vector::CombiningKind::MAXUI: 921 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp, 922 ReductionNeutralUIntMin>( 923 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 924 break; 925 case vector::CombiningKind::MAXSI: 926 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp, 927 ReductionNeutralSIntMin>( 928 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 929 break; 930 case vector::CombiningKind::AND: 931 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp, 932 ReductionNeutralAllOnes>( 933 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 934 break; 935 case vector::CombiningKind::OR: 936 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp, 937 ReductionNeutralZero>( 938 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 939 break; 940 case vector::CombiningKind::XOR: 941 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp, 942 ReductionNeutralZero>( 943 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 944 break; 945 case vector::CombiningKind::MINF: 946 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp, 947 ReductionNeutralFPMax>( 948 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 949 break; 950 case vector::CombiningKind::MAXF: 951 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp, 952 ReductionNeutralFPMin>( 953 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 954 break; 955 case CombiningKind::MAXIMUMF: 956 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum, 957 MaskNeutralFMaximum>( 958 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 959 break; 960 case CombiningKind::MINIMUMF: 961 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum, 962 MaskNeutralFMinimum>( 963 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); 964 break; 965 } 966 967 // Replace `vector.mask` operation altogether. 968 rewriter.replaceOp(maskOp, result); 969 return success(); 970 } 971 }; 972 973 class VectorShuffleOpConversion 974 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 975 public: 976 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 977 978 LogicalResult 979 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 980 ConversionPatternRewriter &rewriter) const override { 981 auto loc = shuffleOp->getLoc(); 982 auto v1Type = shuffleOp.getV1VectorType(); 983 auto v2Type = shuffleOp.getV2VectorType(); 984 auto vectorType = shuffleOp.getResultVectorType(); 985 Type llvmType = typeConverter->convertType(vectorType); 986 auto maskArrayAttr = shuffleOp.getMask(); 987 988 // Bail if result type cannot be lowered. 989 if (!llvmType) 990 return failure(); 991 992 // Get rank and dimension sizes. 993 int64_t rank = vectorType.getRank(); 994 #ifndef NDEBUG 995 bool wellFormed0DCase = 996 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1; 997 bool wellFormedNDCase = 998 v1Type.getRank() == rank && v2Type.getRank() == rank; 999 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed"); 1000 #endif 1001 1002 // For rank 0 and 1, where both operands have *exactly* the same vector 1003 // type, there is direct shuffle support in LLVM. Use it! 1004 if (rank <= 1 && v1Type == v2Type) { 1005 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 1006 loc, adaptor.getV1(), adaptor.getV2(), 1007 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 1008 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 1009 return success(); 1010 } 1011 1012 // For all other cases, insert the individual values individually. 1013 int64_t v1Dim = v1Type.getDimSize(0); 1014 Type eltType; 1015 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType)) 1016 eltType = arrayType.getElementType(); 1017 else 1018 eltType = cast<VectorType>(llvmType).getElementType(); 1019 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 1020 int64_t insPos = 0; 1021 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 1022 int64_t extPos = cast<IntegerAttr>(en.value()).getInt(); 1023 Value value = adaptor.getV1(); 1024 if (extPos >= v1Dim) { 1025 extPos -= v1Dim; 1026 value = adaptor.getV2(); 1027 } 1028 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 1029 eltType, rank, extPos); 1030 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 1031 llvmType, rank, insPos++); 1032 } 1033 rewriter.replaceOp(shuffleOp, insert); 1034 return success(); 1035 } 1036 }; 1037 1038 class VectorExtractElementOpConversion 1039 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 1040 public: 1041 using ConvertOpToLLVMPattern< 1042 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 1043 1044 LogicalResult 1045 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 1046 ConversionPatternRewriter &rewriter) const override { 1047 auto vectorType = extractEltOp.getSourceVectorType(); 1048 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 1049 1050 // Bail if result type cannot be lowered. 1051 if (!llvmType) 1052 return failure(); 1053 1054 if (vectorType.getRank() == 0) { 1055 Location loc = extractEltOp.getLoc(); 1056 auto idxType = rewriter.getIndexType(); 1057 auto zero = rewriter.create<LLVM::ConstantOp>( 1058 loc, typeConverter->convertType(idxType), 1059 rewriter.getIntegerAttr(idxType, 0)); 1060 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1061 extractEltOp, llvmType, adaptor.getVector(), zero); 1062 return success(); 1063 } 1064 1065 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1066 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 1067 return success(); 1068 } 1069 }; 1070 1071 class VectorExtractOpConversion 1072 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 1073 public: 1074 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 1075 1076 LogicalResult 1077 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 1078 ConversionPatternRewriter &rewriter) const override { 1079 auto loc = extractOp->getLoc(); 1080 auto resultType = extractOp.getResult().getType(); 1081 auto llvmResultType = typeConverter->convertType(resultType); 1082 ArrayRef<int64_t> positionArray = extractOp.getPosition(); 1083 1084 // Bail if result type cannot be lowered. 1085 if (!llvmResultType) 1086 return failure(); 1087 1088 // Extract entire vector. Should be handled by folder, but just to be safe. 1089 if (positionArray.empty()) { 1090 rewriter.replaceOp(extractOp, adaptor.getVector()); 1091 return success(); 1092 } 1093 1094 // One-shot extraction of vector from array (only requires extractvalue). 1095 if (isa<VectorType>(resultType)) { 1096 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 1097 loc, adaptor.getVector(), positionArray); 1098 rewriter.replaceOp(extractOp, extracted); 1099 return success(); 1100 } 1101 1102 // Potential extraction of 1-D vector from array. 1103 Value extracted = adaptor.getVector(); 1104 if (positionArray.size() > 1) { 1105 extracted = rewriter.create<LLVM::ExtractValueOp>( 1106 loc, extracted, positionArray.drop_back()); 1107 } 1108 1109 // Remaining extraction of element from 1-D LLVM vector 1110 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 1111 auto constant = 1112 rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back()); 1113 extracted = 1114 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 1115 rewriter.replaceOp(extractOp, extracted); 1116 1117 return success(); 1118 } 1119 }; 1120 1121 /// Conversion pattern that turns a vector.fma on a 1-D vector 1122 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 1123 /// This does not match vectors of n >= 2 rank. 1124 /// 1125 /// Example: 1126 /// ``` 1127 /// vector.fma %a, %a, %a : vector<8xf32> 1128 /// ``` 1129 /// is converted to: 1130 /// ``` 1131 /// llvm.intr.fmuladd %va, %va, %va: 1132 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 1133 /// -> !llvm."<8 x f32>"> 1134 /// ``` 1135 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 1136 public: 1137 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 1138 1139 LogicalResult 1140 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 1141 ConversionPatternRewriter &rewriter) const override { 1142 VectorType vType = fmaOp.getVectorType(); 1143 if (vType.getRank() > 1) 1144 return failure(); 1145 1146 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( 1147 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); 1148 return success(); 1149 } 1150 }; 1151 1152 class VectorInsertElementOpConversion 1153 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 1154 public: 1155 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 1156 1157 LogicalResult 1158 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 1159 ConversionPatternRewriter &rewriter) const override { 1160 auto vectorType = insertEltOp.getDestVectorType(); 1161 auto llvmType = typeConverter->convertType(vectorType); 1162 1163 // Bail if result type cannot be lowered. 1164 if (!llvmType) 1165 return failure(); 1166 1167 if (vectorType.getRank() == 0) { 1168 Location loc = insertEltOp.getLoc(); 1169 auto idxType = rewriter.getIndexType(); 1170 auto zero = rewriter.create<LLVM::ConstantOp>( 1171 loc, typeConverter->convertType(idxType), 1172 rewriter.getIntegerAttr(idxType, 0)); 1173 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1174 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); 1175 return success(); 1176 } 1177 1178 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1179 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), 1180 adaptor.getPosition()); 1181 return success(); 1182 } 1183 }; 1184 1185 class VectorInsertOpConversion 1186 : public ConvertOpToLLVMPattern<vector::InsertOp> { 1187 public: 1188 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 1189 1190 LogicalResult 1191 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 1192 ConversionPatternRewriter &rewriter) const override { 1193 auto loc = insertOp->getLoc(); 1194 auto sourceType = insertOp.getSourceType(); 1195 auto destVectorType = insertOp.getDestVectorType(); 1196 auto llvmResultType = typeConverter->convertType(destVectorType); 1197 ArrayRef<int64_t> positionArray = insertOp.getPosition(); 1198 1199 // Bail if result type cannot be lowered. 1200 if (!llvmResultType) 1201 return failure(); 1202 1203 // Overwrite entire vector with value. Should be handled by folder, but 1204 // just to be safe. 1205 if (positionArray.empty()) { 1206 rewriter.replaceOp(insertOp, adaptor.getSource()); 1207 return success(); 1208 } 1209 1210 // One-shot insertion of a vector into an array (only requires insertvalue). 1211 if (isa<VectorType>(sourceType)) { 1212 Value inserted = rewriter.create<LLVM::InsertValueOp>( 1213 loc, adaptor.getDest(), adaptor.getSource(), positionArray); 1214 rewriter.replaceOp(insertOp, inserted); 1215 return success(); 1216 } 1217 1218 // Potential extraction of 1-D vector from array. 1219 Value extracted = adaptor.getDest(); 1220 auto oneDVectorType = destVectorType; 1221 if (positionArray.size() > 1) { 1222 oneDVectorType = reducedVectorTypeBack(destVectorType); 1223 extracted = rewriter.create<LLVM::ExtractValueOp>( 1224 loc, extracted, positionArray.drop_back()); 1225 } 1226 1227 // Insertion of an element into a 1-D LLVM vector. 1228 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 1229 auto constant = 1230 rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back()); 1231 Value inserted = rewriter.create<LLVM::InsertElementOp>( 1232 loc, typeConverter->convertType(oneDVectorType), extracted, 1233 adaptor.getSource(), constant); 1234 1235 // Potential insertion of resulting 1-D vector into array. 1236 if (positionArray.size() > 1) { 1237 inserted = rewriter.create<LLVM::InsertValueOp>( 1238 loc, adaptor.getDest(), inserted, positionArray.drop_back()); 1239 } 1240 1241 rewriter.replaceOp(insertOp, inserted); 1242 return success(); 1243 } 1244 }; 1245 1246 /// Lower vector.scalable.insert ops to LLVM vector.insert 1247 struct VectorScalableInsertOpLowering 1248 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> { 1249 using ConvertOpToLLVMPattern< 1250 vector::ScalableInsertOp>::ConvertOpToLLVMPattern; 1251 1252 LogicalResult 1253 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor, 1254 ConversionPatternRewriter &rewriter) const override { 1255 rewriter.replaceOpWithNewOp<LLVM::vector_insert>( 1256 insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos()); 1257 return success(); 1258 } 1259 }; 1260 1261 /// Lower vector.scalable.extract ops to LLVM vector.extract 1262 struct VectorScalableExtractOpLowering 1263 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> { 1264 using ConvertOpToLLVMPattern< 1265 vector::ScalableExtractOp>::ConvertOpToLLVMPattern; 1266 1267 LogicalResult 1268 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor, 1269 ConversionPatternRewriter &rewriter) const override { 1270 rewriter.replaceOpWithNewOp<LLVM::vector_extract>( 1271 extOp, typeConverter->convertType(extOp.getResultVectorType()), 1272 adaptor.getSource(), adaptor.getPos()); 1273 return success(); 1274 } 1275 }; 1276 1277 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 1278 /// 1279 /// Example: 1280 /// ``` 1281 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 1282 /// ``` 1283 /// is rewritten into: 1284 /// ``` 1285 /// %r = splat %f0: vector<2x4xf32> 1286 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 1287 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 1288 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 1289 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 1290 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 1291 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 1292 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 1293 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 1294 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 1295 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 1296 /// // %r3 holds the final value. 1297 /// ``` 1298 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 1299 public: 1300 using OpRewritePattern<FMAOp>::OpRewritePattern; 1301 1302 void initialize() { 1303 // This pattern recursively unpacks one dimension at a time. The recursion 1304 // bounded as the rank is strictly decreasing. 1305 setHasBoundedRewriteRecursion(); 1306 } 1307 1308 LogicalResult matchAndRewrite(FMAOp op, 1309 PatternRewriter &rewriter) const override { 1310 auto vType = op.getVectorType(); 1311 if (vType.getRank() < 2) 1312 return failure(); 1313 1314 auto loc = op.getLoc(); 1315 auto elemType = vType.getElementType(); 1316 Value zero = rewriter.create<arith::ConstantOp>( 1317 loc, elemType, rewriter.getZeroAttr(elemType)); 1318 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 1319 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 1320 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 1321 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 1322 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 1323 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 1324 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 1325 } 1326 rewriter.replaceOp(op, desc); 1327 return success(); 1328 } 1329 }; 1330 1331 /// Returns the strides if the memory underlying `memRefType` has a contiguous 1332 /// static layout. 1333 static std::optional<SmallVector<int64_t, 4>> 1334 computeContiguousStrides(MemRefType memRefType) { 1335 int64_t offset; 1336 SmallVector<int64_t, 4> strides; 1337 if (failed(getStridesAndOffset(memRefType, strides, offset))) 1338 return std::nullopt; 1339 if (!strides.empty() && strides.back() != 1) 1340 return std::nullopt; 1341 // If no layout or identity layout, this is contiguous by definition. 1342 if (memRefType.getLayout().isIdentity()) 1343 return strides; 1344 1345 // Otherwise, we must determine contiguity form shapes. This can only ever 1346 // work in static cases because MemRefType is underspecified to represent 1347 // contiguous dynamic shapes in other ways than with just empty/identity 1348 // layout. 1349 auto sizes = memRefType.getShape(); 1350 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 1351 if (ShapedType::isDynamic(sizes[index + 1]) || 1352 ShapedType::isDynamic(strides[index]) || 1353 ShapedType::isDynamic(strides[index + 1])) 1354 return std::nullopt; 1355 if (strides[index] != strides[index + 1] * sizes[index + 1]) 1356 return std::nullopt; 1357 } 1358 return strides; 1359 } 1360 1361 class VectorTypeCastOpConversion 1362 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 1363 public: 1364 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 1365 1366 LogicalResult 1367 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 1368 ConversionPatternRewriter &rewriter) const override { 1369 auto loc = castOp->getLoc(); 1370 MemRefType sourceMemRefType = 1371 cast<MemRefType>(castOp.getOperand().getType()); 1372 MemRefType targetMemRefType = castOp.getType(); 1373 1374 // Only static shape casts supported atm. 1375 if (!sourceMemRefType.hasStaticShape() || 1376 !targetMemRefType.hasStaticShape()) 1377 return failure(); 1378 1379 auto llvmSourceDescriptorTy = 1380 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType()); 1381 if (!llvmSourceDescriptorTy) 1382 return failure(); 1383 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 1384 1385 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( 1386 typeConverter->convertType(targetMemRefType)); 1387 if (!llvmTargetDescriptorTy) 1388 return failure(); 1389 1390 // Only contiguous source buffers supported atm. 1391 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1392 if (!sourceStrides) 1393 return failure(); 1394 auto targetStrides = computeContiguousStrides(targetMemRefType); 1395 if (!targetStrides) 1396 return failure(); 1397 // Only support static strides for now, regardless of contiguity. 1398 if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) 1399 return failure(); 1400 1401 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1402 1403 // Create descriptor. 1404 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1405 Type llvmTargetElementTy = desc.getElementPtrType(); 1406 // Set allocated ptr. 1407 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1408 if (!getTypeConverter()->useOpaquePointers()) 1409 allocated = 1410 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 1411 desc.setAllocatedPtr(rewriter, loc, allocated); 1412 1413 // Set aligned ptr. 1414 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1415 if (!getTypeConverter()->useOpaquePointers()) 1416 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 1417 1418 desc.setAlignedPtr(rewriter, loc, ptr); 1419 // Fill offset 0. 1420 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1421 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1422 desc.setOffset(rewriter, loc, zero); 1423 1424 // Fill size and stride descriptors in memref. 1425 for (const auto &indexedSize : 1426 llvm::enumerate(targetMemRefType.getShape())) { 1427 int64_t index = indexedSize.index(); 1428 auto sizeAttr = 1429 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1430 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1431 desc.setSize(rewriter, loc, index, size); 1432 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1433 (*targetStrides)[index]); 1434 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1435 desc.setStride(rewriter, loc, index, stride); 1436 } 1437 1438 rewriter.replaceOp(castOp, {desc}); 1439 return success(); 1440 } 1441 }; 1442 1443 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 1444 /// Non-scalable versions of this operation are handled in Vector Transforms. 1445 class VectorCreateMaskOpRewritePattern 1446 : public OpRewritePattern<vector::CreateMaskOp> { 1447 public: 1448 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1449 bool enableIndexOpt) 1450 : OpRewritePattern<vector::CreateMaskOp>(context), 1451 force32BitVectorIndices(enableIndexOpt) {} 1452 1453 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1454 PatternRewriter &rewriter) const override { 1455 auto dstType = op.getType(); 1456 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1457 return failure(); 1458 IntegerType idxType = 1459 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1460 auto loc = op->getLoc(); 1461 Value indices = rewriter.create<LLVM::StepVectorOp>( 1462 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1463 /*isScalable=*/true)); 1464 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1465 op.getOperand(0)); 1466 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1467 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1468 indices, bounds); 1469 rewriter.replaceOp(op, comp); 1470 return success(); 1471 } 1472 1473 private: 1474 const bool force32BitVectorIndices; 1475 }; 1476 1477 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1478 public: 1479 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1480 1481 // Lowering implementation that relies on a small runtime support library, 1482 // which only needs to provide a few printing methods (single value for all 1483 // data types, opening/closing bracket, comma, newline). The lowering splits 1484 // the vector into elementary printing operations. The advantage of this 1485 // approach is that the library can remain unaware of all low-level 1486 // implementation details of vectors while still supporting output of any 1487 // shaped and dimensioned vector. 1488 // 1489 // Note: This lowering only handles scalars, n-D vectors are broken into 1490 // printing scalars in loops in VectorToSCF. 1491 // 1492 // TODO: rely solely on libc in future? something else? 1493 // 1494 LogicalResult 1495 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1496 ConversionPatternRewriter &rewriter) const override { 1497 auto parent = printOp->getParentOfType<ModuleOp>(); 1498 if (!parent) 1499 return failure(); 1500 1501 auto loc = printOp->getLoc(); 1502 1503 if (auto value = adaptor.getSource()) { 1504 Type printType = printOp.getPrintType(); 1505 if (isa<VectorType>(printType)) { 1506 // Vectors should be broken into elementary print ops in VectorToSCF. 1507 return failure(); 1508 } 1509 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value))) 1510 return failure(); 1511 } 1512 1513 auto punct = printOp.getPunctuation(); 1514 if (punct != PrintPunctuation::NoPunctuation) { 1515 emitCall(rewriter, printOp->getLoc(), [&] { 1516 switch (punct) { 1517 case PrintPunctuation::Close: 1518 return LLVM::lookupOrCreatePrintCloseFn(parent); 1519 case PrintPunctuation::Open: 1520 return LLVM::lookupOrCreatePrintOpenFn(parent); 1521 case PrintPunctuation::Comma: 1522 return LLVM::lookupOrCreatePrintCommaFn(parent); 1523 case PrintPunctuation::NewLine: 1524 return LLVM::lookupOrCreatePrintNewlineFn(parent); 1525 default: 1526 llvm_unreachable("unexpected punctuation"); 1527 } 1528 }()); 1529 } 1530 1531 rewriter.eraseOp(printOp); 1532 return success(); 1533 } 1534 1535 private: 1536 enum class PrintConversion { 1537 // clang-format off 1538 None, 1539 ZeroExt64, 1540 SignExt64, 1541 Bitcast16 1542 // clang-format on 1543 }; 1544 1545 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter, 1546 ModuleOp parent, Location loc, Type printType, 1547 Value value) const { 1548 if (typeConverter->convertType(printType) == nullptr) 1549 return failure(); 1550 1551 // Make sure element type has runtime support. 1552 PrintConversion conversion = PrintConversion::None; 1553 Operation *printer; 1554 if (printType.isF32()) { 1555 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1556 } else if (printType.isF64()) { 1557 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1558 } else if (printType.isF16()) { 1559 conversion = PrintConversion::Bitcast16; // bits! 1560 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1561 } else if (printType.isBF16()) { 1562 conversion = PrintConversion::Bitcast16; // bits! 1563 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1564 } else if (printType.isIndex()) { 1565 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1566 } else if (auto intTy = dyn_cast<IntegerType>(printType)) { 1567 // Integers need a zero or sign extension on the operand 1568 // (depending on the source type) as well as a signed or 1569 // unsigned print method. Up to 64-bit is supported. 1570 unsigned width = intTy.getWidth(); 1571 if (intTy.isUnsigned()) { 1572 if (width <= 64) { 1573 if (width < 64) 1574 conversion = PrintConversion::ZeroExt64; 1575 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1576 } else { 1577 return failure(); 1578 } 1579 } else { 1580 assert(intTy.isSignless() || intTy.isSigned()); 1581 if (width <= 64) { 1582 // Note that we *always* zero extend booleans (1-bit integers), 1583 // so that true/false is printed as 1/0 rather than -1/0. 1584 if (width == 1) 1585 conversion = PrintConversion::ZeroExt64; 1586 else if (width < 64) 1587 conversion = PrintConversion::SignExt64; 1588 printer = LLVM::lookupOrCreatePrintI64Fn(parent); 1589 } else { 1590 return failure(); 1591 } 1592 } 1593 } else { 1594 return failure(); 1595 } 1596 1597 switch (conversion) { 1598 case PrintConversion::ZeroExt64: 1599 value = rewriter.create<arith::ExtUIOp>( 1600 loc, IntegerType::get(rewriter.getContext(), 64), value); 1601 break; 1602 case PrintConversion::SignExt64: 1603 value = rewriter.create<arith::ExtSIOp>( 1604 loc, IntegerType::get(rewriter.getContext(), 64), value); 1605 break; 1606 case PrintConversion::Bitcast16: 1607 value = rewriter.create<LLVM::BitcastOp>( 1608 loc, IntegerType::get(rewriter.getContext(), 16), value); 1609 break; 1610 case PrintConversion::None: 1611 break; 1612 } 1613 emitCall(rewriter, loc, printer, value); 1614 return success(); 1615 } 1616 1617 // Helper to emit a call. 1618 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1619 Operation *ref, ValueRange params = ValueRange()) { 1620 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1621 params); 1622 } 1623 }; 1624 1625 /// The Splat operation is lowered to an insertelement + a shufflevector 1626 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 1627 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1628 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1629 1630 LogicalResult 1631 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1632 ConversionPatternRewriter &rewriter) const override { 1633 VectorType resultType = cast<VectorType>(splatOp.getType()); 1634 if (resultType.getRank() > 1) 1635 return failure(); 1636 1637 // First insert it into an undef vector so we can shuffle it. 1638 auto vectorType = typeConverter->convertType(splatOp.getType()); 1639 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1640 auto zero = rewriter.create<LLVM::ConstantOp>( 1641 splatOp.getLoc(), 1642 typeConverter->convertType(rewriter.getIntegerType(32)), 1643 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1644 1645 // For 0-d vector, we simply do `insertelement`. 1646 if (resultType.getRank() == 0) { 1647 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1648 splatOp, vectorType, undef, adaptor.getInput(), zero); 1649 return success(); 1650 } 1651 1652 // For 1-d vector, we additionally do a `vectorshuffle`. 1653 auto v = rewriter.create<LLVM::InsertElementOp>( 1654 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1655 1656 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0); 1657 SmallVector<int32_t> zeroValues(width, 0); 1658 1659 // Shuffle the value across the desired number of elements. 1660 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1661 zeroValues); 1662 return success(); 1663 } 1664 }; 1665 1666 /// The Splat operation is lowered to an insertelement + a shufflevector 1667 /// operation. Splat to only 2+-d vector result types are lowered by the 1668 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 1669 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { 1670 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; 1671 1672 LogicalResult 1673 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, 1674 ConversionPatternRewriter &rewriter) const override { 1675 VectorType resultType = splatOp.getType(); 1676 if (resultType.getRank() <= 1) 1677 return failure(); 1678 1679 // First insert it into an undef vector so we can shuffle it. 1680 auto loc = splatOp.getLoc(); 1681 auto vectorTypeInfo = 1682 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); 1683 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; 1684 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; 1685 if (!llvmNDVectorTy || !llvm1DVectorTy) 1686 return failure(); 1687 1688 // Construct returned value. 1689 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); 1690 1691 // Construct a 1-D vector with the splatted value that we insert in all the 1692 // places within the returned descriptor. 1693 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); 1694 auto zero = rewriter.create<LLVM::ConstantOp>( 1695 loc, typeConverter->convertType(rewriter.getIntegerType(32)), 1696 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1697 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, 1698 adaptor.getInput(), zero); 1699 1700 // Shuffle the value across the desired number of elements. 1701 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1702 SmallVector<int32_t> zeroValues(width, 0); 1703 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1704 1705 // Iterate of linear index, convert to coords space and insert splatted 1-D 1706 // vector in each position. 1707 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1708 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1709 }); 1710 rewriter.replaceOp(splatOp, desc); 1711 return success(); 1712 } 1713 }; 1714 1715 } // namespace 1716 1717 /// Populate the given list with patterns that convert from Vector to LLVM. 1718 void mlir::populateVectorToLLVMConversionPatterns( 1719 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1720 bool reassociateFPReductions, bool force32BitVectorIndices) { 1721 MLIRContext *ctx = converter.getDialect()->getContext(); 1722 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1723 populateVectorInsertExtractStridedSliceTransforms(patterns); 1724 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1725 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1726 patterns 1727 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1728 VectorExtractElementOpConversion, VectorExtractOpConversion, 1729 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1730 VectorInsertOpConversion, VectorPrintOpConversion, 1731 VectorTypeCastOpConversion, VectorScaleOpConversion, 1732 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1733 VectorLoadStoreConversion<vector::MaskedLoadOp, 1734 vector::MaskedLoadOpAdaptor>, 1735 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1736 VectorLoadStoreConversion<vector::MaskedStoreOp, 1737 vector::MaskedStoreOpAdaptor>, 1738 VectorGatherOpConversion, VectorScatterOpConversion, 1739 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1740 VectorSplatOpLowering, VectorSplatNdOpLowering, 1741 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1742 MaskedReductionOpConversion>(converter); 1743 // Transfer ops with rank > 1 are handled by VectorToSCF. 1744 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1745 } 1746 1747 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1748 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1749 patterns.add<VectorMatmulOpConversion>(converter); 1750 patterns.add<VectorFlatTransposeOpConversion>(converter); 1751 } 1752