//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (tp.getShape().size() == numScalableDims)
    --numScalableDims;
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
                         numScalableDims);
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (numScalableDims > 0)
    --numScalableDims;
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         numScalableDims);
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}
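// As an illustrative sketch (values and types assumed, not taken from a
// test): extracting position 2 from a rank-1 vector with extractOne yields
//   %c2 = llvm.mlir.constant(2 : index) : i64
//   %0  = llvm.extractelement %v[%c2 : i64] : vector<4xf32>
// whereas for rank > 1, where the value is an LLVM array aggregate, a single
// llvm.extractvalue is emitted instead.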
// Helper that returns data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Check that the memref has a unit last stride and a zero memory space; these
// are the only memrefs supported by the patterns below.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           LLVMTypeConverter &converter) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  FailureOr<unsigned> addressSpace =
      converter.getMemRefAddressSpace(memRefType);
  if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) ||
      *addressSpace != 0)
    return failure();
  return success();
}

// Add an index vector component to a base pointer.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, uint64_t vLen) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
  return rewriter.create<LLVM::GEPOp>(
      loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
      base, index);
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt,
                         LLVMTypeConverter &converter) {
  if (converter.useOpaquePointers())
    return ptr;

  unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
  auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};
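// As an illustrative example (same total bit width on both sides), the
// pattern above rewrites
//   %0 = vector.bitcast %v : vector<4xf32> to vector<8xf16>
// into a single llvm.bitcast on the converted operand.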
/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
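// For example (a sketch, with assumed types), the vector.maskedload overload
// above produces
//   %0 = llvm.intr.masked.load %ptr, %mask, %pass_thru {alignment = 4 : i32}
//       : (!llvm.ptr, vector<8xi1>, vector<8xf32>) -> vector<8xf32>
// where the alignment is the data-layout preferred alignment of the element
// type, as computed by getMemRefAlignment.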
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};
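// Illustrative 1-D example (types assumed): for a gather from memref<?xf32>
// with a vector<4xi64> index vector, getIndexedPtrs first materializes a
// vector of four pointers with a single llvm.getelementptr, and the op is
// then replaced by llvm.intr.masked.gather on that pointer vector, forwarding
// the mask and pass-through values unchanged. Higher-rank gathers are peeled
// one dimension at a time by handleMultidimensionalVectors.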
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
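// The empty classes above serve purely as tags for overload resolution. For
// example:
//   Value zero = createReductionNeutralValue(ReductionNeutralZero(),
//                                            rewriter, loc, llvmType);
// selects the overload below that materializes a zero constant.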
/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value.
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
}

/// Create the reduction neutral fp one value.
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
}

/// Create the reduction neutral all-ones value.
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(
          llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int minimum value.
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int minimum value.
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int maximum value.
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int maximum value.
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral fp minimum value.
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = llvmType.cast<FloatType>();
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/false)));
}

/// Create the reduction neutral fp maximum value.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = llvmType.cast<FloatType>();
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}

/// Returns `accumulator` if it has a valid value. Otherwise, creates and
/// returns a new accumulator value using `ReductionNeutral`.
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Value accumulator) {
  if (accumulator)
    return accumulator;

  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
                                     llvmType);
}

/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  return rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector
/// reduction intrinsic to use and `ScalarOp` is the scalar operation used to
/// combine the accumulator value, if present.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector reduction intrinsic to use and `predicate` is the predicate used to
/// compare and combine the accumulator value, if present.
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}
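// For example (an illustrative sketch), a signed-min reduction over
// vector<8xi32> with an accumulator lowers to:
//   %r   = llvm.intr.vector.reduce.smin(%v) : (vector<8xi32>) -> i32
//   %cmp = llvm.icmp "sle" %acc, %r : i32
//   %res = llvm.select %cmp, %acc, %r : i1, i32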
/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
/// with vector types.
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
  Type i1Type = builder.getI1Type();
  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, bool isMin) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  if (accumulator)
    result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);

  return result;
}

/// Overloaded methods to lower a reduction to an llvm intrinsic that requires
/// a start value. This start-value form is shared by the unmasked fp
/// reductions and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, reassociateFPReds);
}

template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask,
                                          bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength,
                                            reassociateFPReds);
}
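// The VP ("vector predication") reduction intrinsics built by the masked
// overloads above take the mask and an explicit i32 vector length in addition
// to the start value. A sketch in generic MLIR form, with assumed values:
//   %evl = llvm.mlir.constant(8 : i32) : i32
//   %r = "llvm.intr.vp.reduce.add"(%acc, %v, %mask, %evl)
//       : (i32, vector<8xi32>, vector<8xi1>, i32) -> i32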
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                        IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    // Masked reductions are lowered separately.
    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
    if (maskableOp.isMasked())
      return failure();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MINF) {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/true);
    } else if (kind == vector::CombiningKind::MAXF) {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/false);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  const bool reassociateFPReductions;
};
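// Illustrative example: without masking,
//   %r = vector.reduction <add>, %v, %acc : vector<16xf32> into f32
// becomes a single llvm.intr.vector.reduce.fadd taking %acc as the start
// value; the `reassociateFPReductions` flag controls whether that intrinsic
// may reassociate the summation for speed.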
/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes. This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override final {
    // Match against the maskable operation kind.
    Operation *maskableOp = maskOp.getMaskableOp();
    if (!isa<MaskedOp>(maskableOp))
      return failure();
    return matchAndRewriteMaskableOp(
        maskOp, cast<MaskedOp>(maskOp.getMaskableOp()), rewriter);
  }

protected:
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  virtual LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
                                            ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
                                            ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                            ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                            ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
                                            ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};
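// Illustrative example: a masked reduction such as
//   %r = vector.mask %m { vector.reduction <add>, %v, %acc
//       : vector<8xf32> into f32 } : vector<8xi1> -> f32
// is replaced by the predicated intrinsic llvm.intr.vp.reduce.fadd, with %m
// as the mask and a constant vector length of 8.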
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract and insert the values one at a time.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};
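// Illustrative example: the 0-D case
//   %0 = vector.extractelement %v[] : vector<f32>
// is rewritten to an llvm.extractelement with a constant index of 0, since
// 0-D vectors are converted to 1-element LLVM vectors.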
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va
///    : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() > 1)
      return failure();

    // Masked fmas are lowered separately.
    auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
    if (maskableOp.isMasked())
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Conversion pattern that turns a masked vector.fma on a 1-D vector into its
/// LLVM counterpart representation. Non side effecting VP intrinsics are not
/// fully supported by some backends, including x86, and they don't support
/// pass-through values either. For these reasons, we generate an unmasked
/// fma followed by a select instruction to emulate the masking behavior.
/// This pattern is peepholed by some backends with support for masked fma
/// instructions. This pattern does not match vectors of n >= 2 rank.
class MaskedFMAOp1DConversion
    : public VectorMaskOpConversionBase<vector::FMAOp> {
public:
  using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;

  MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
      : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}

  virtual LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
    Type llvmType = typeConverter->convertType(fmaOp.getVectorType());

    Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
        fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
        fmaOp.getAcc());
    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
        maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};
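// Illustrative example:
//   %0 = vector.scalable.insert %sub, %dst[0]
//       : vector<4xf32> into vector<[4]xf32>
// maps 1-1 onto llvm.intr.vector.insert, and vector.scalable.extract likewise
// maps onto llvm.intr.vector.extract.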
/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///  %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = vector.splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    // Masked fmas are lowered separately.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}
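// For example, memref<4x8xf32> with the identity layout has strides [8, 1]
// and is contiguous, whereas memref<4x8xf32, strided<[16, 1]>> leaves a gap
// after each row and fails the stride check above.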
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
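// Illustrative example: the pattern above rewrites
//   %0 = vector.type_cast %buf : memref<8x8xf32> to memref<vector<8x8xf32>>
// by rebuilding the target memref descriptor around the source buffer: same
// allocated/aligned pointers, offset 0, and the target's static sizes and
// strides.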
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

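/// Conversion pattern for a vector.print. A hedged illustration of the fully
/// unrolled code for a small vector; the `print*` names stand in for whatever
/// entry points the runtime support library actually provides:
///
/// ```
///    vector.print %v : vector<2xf32>
/// ```
///
/// expands to a sequence of calls equivalent to
/// `printOpen(); printF32(v[0]); printComma(); printF32(v[1]); printClose();
/// printNewline();`.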
class VectorPrintOpConversion
    : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to the full
  // unrolling, however, this approach is less suited for very large
  // vectors.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to 0-D and 1-D vector result types are lowered
/// here.
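///
/// A minimal sketch of the 1-D case (illustrative LLVM dialect IR for a splat
/// of `%x : f32` into `vector<4xf32>`; exact assembly may vary by version):
///
/// ```
///    %undef = llvm.mlir.undef : vector<4xf32>
///    %c0 = llvm.mlir.constant(0 : i32) : i32
///    %e = llvm.insertelement %x, %undef[%c0 : i32] : vector<4xf32>
///    %r = llvm.shufflevector %e, %undef [0, 0, 0, 0] : vector<4xf32>
/// ```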
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType().cast<VectorType>();
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For a 0-D vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For a 1-D vector, we additionally do a `shufflevector`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to 2-D or higher vector result types are lowered
/// by VectorSplatNdOpLowering; the 0-D and 1-D cases are handled by
/// VectorSplatOpLowering.
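///
/// A hedged sketch of the 2-D case: the 1-D broadcast above is materialized
/// once and then inserted at every aggregate position. For example,
///
/// ```
///    %v = vector.splat %x : vector<2x4xf32>
/// ```
///
/// produces an `!llvm.array<2 x vector<4xf32>>` whose two array positions
/// both hold the same splatted `vector<4xf32>` value.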
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all
    // the places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index space, convert each linear index to a
    // multi-dimensional position, and insert the splatted 1-D vector at each
    // position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
           VectorInsertElementOpConversion, VectorInsertOpConversion,
           VectorPrintOpConversion, VectorTypeCastOpConversion,
           VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering, VectorSplatNdOpLowering,
           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
           MaskedReductionOpConversion>(converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}
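// A minimal sketch of how these populate functions are typically wired into
// a conversion pass (hypothetical driver snippet, assuming the standard
// DialectConversion APIs; pass registration and options are omitted):
//
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
//   populateVectorToLLVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();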