1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 19 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "llvm/Support/Casting.h" 25 #include <optional> 26 27 using namespace mlir; 28 using namespace mlir::vector; 29 30 // Helper to reduce vector type by one rank at front. 31 static VectorType reducedVectorTypeFront(VectorType tp) { 32 assert((tp.getRank() > 1) && "unlowerable vector type"); 33 unsigned numScalableDims = tp.getNumScalableDims(); 34 if (tp.getShape().size() == numScalableDims) 35 --numScalableDims; 36 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 37 numScalableDims); 38 } 39 40 // Helper to reduce vector type by *all* but one rank at back. 
41 static VectorType reducedVectorTypeBack(VectorType tp) { 42 assert((tp.getRank() > 1) && "unlowerable vector type"); 43 unsigned numScalableDims = tp.getNumScalableDims(); 44 if (numScalableDims > 0) 45 --numScalableDims; 46 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 47 numScalableDims); 48 } 49 50 // Helper that picks the proper sequence for inserting. 51 static Value insertOne(ConversionPatternRewriter &rewriter, 52 LLVMTypeConverter &typeConverter, Location loc, 53 Value val1, Value val2, Type llvmType, int64_t rank, 54 int64_t pos) { 55 assert(rank > 0 && "0-D vector corner case should have been handled already"); 56 if (rank == 1) { 57 auto idxType = rewriter.getIndexType(); 58 auto constant = rewriter.create<LLVM::ConstantOp>( 59 loc, typeConverter.convertType(idxType), 60 rewriter.getIntegerAttr(idxType, pos)); 61 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 62 constant); 63 } 64 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 65 } 66 67 // Helper that picks the proper sequence for extracting. 68 static Value extractOne(ConversionPatternRewriter &rewriter, 69 LLVMTypeConverter &typeConverter, Location loc, 70 Value val, Type llvmType, int64_t rank, int64_t pos) { 71 if (rank <= 1) { 72 auto idxType = rewriter.getIndexType(); 73 auto constant = rewriter.create<LLVM::ConstantOp>( 74 loc, typeConverter.convertType(idxType), 75 rewriter.getIntegerAttr(idxType, pos)); 76 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 77 constant); 78 } 79 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 80 } 81 82 // Helper that returns data layout alignment of a memref. 
83 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 84 MemRefType memrefType, unsigned &align) { 85 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 86 if (!elementTy) 87 return failure(); 88 89 // TODO: this should use the MLIR data layout when it becomes available and 90 // stop depending on translation. 91 llvm::LLVMContext llvmContext; 92 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 93 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 94 return success(); 95 } 96 97 // Check if the last stride is non-unit or the memory space is not zero. 98 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 99 LLVMTypeConverter &converter) { 100 int64_t offset; 101 SmallVector<int64_t, 4> strides; 102 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 103 FailureOr<unsigned> addressSpace = 104 converter.getMemRefAddressSpace(memRefType); 105 if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) || 106 *addressSpace != 0) 107 return failure(); 108 return success(); 109 } 110 111 // Add an index vector component to a base pointer. 112 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 113 LLVMTypeConverter &typeConverter, 114 MemRefType memRefType, Value llvmMemref, Value base, 115 Value index, uint64_t vLen) { 116 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 117 "unsupported memref type"); 118 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 119 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 120 return rewriter.create<LLVM::GEPOp>( 121 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 122 base, index); 123 } 124 125 // Casts a strided element pointer to a vector pointer. The vector pointer 126 // will be in the same address space as the incoming memref type. 
// Casts `ptr` (a strided element pointer into `memRefType`) to a pointer to
// the vector type `vt`. A no-op under opaque pointers, where all pointers are
// untyped.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt,
                         LLVMTypeConverter &converter) {
  if (converter.useOpaquePointers())
    return ptr;

  // Typed-pointer mode: bitcast to a vector pointer in the memref's address
  // space.
  unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
  auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
166 class VectorMatmulOpConversion 167 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 168 public: 169 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 170 171 LogicalResult 172 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 173 ConversionPatternRewriter &rewriter) const override { 174 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 175 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 176 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 177 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 178 return success(); 179 } 180 }; 181 182 /// Conversion pattern for a vector.flat_transpose. 183 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 184 class VectorFlatTransposeOpConversion 185 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 186 public: 187 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 188 189 LogicalResult 190 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 191 ConversionPatternRewriter &rewriter) const override { 192 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 193 transOp, typeConverter->convertType(transOp.getRes().getType()), 194 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 195 return success(); 196 } 197 }; 198 199 /// Overloaded utility that replaces a vector.load, vector.store, 200 /// vector.maskedload and vector.maskedstore with their respective LLVM 201 /// couterparts. 
// Replaces a vector.load with an equivalent llvm.load on the already-resolved
// element pointer, carrying over the computed alignment.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

// Replaces a vector.maskedload with the LLVM masked-load intrinsic,
// forwarding the mask and pass-through operands.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

// Replaces a vector.store with an equivalent llvm.store.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

// Replaces a vector.maskedstore with the LLVM masked-store intrinsic.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    // Under typed pointers the element pointer must be recast to a vector
    // pointer before emitting the (masked) load/store.
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    // Dispatch to the matching overload above via overload resolution on the
    // op/adaptor types.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
    assert(memRefType && "The base should be bufferized");

    // Bail out on memrefs with non-unit innermost stride or a non-zero
    // address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: delegate to handleMultidimensionalVectors, which invokes the
    // callback once per innermost 1-D vector with the unrolled operands
    // (index vector, mask, pass-through — in that order).
    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Bail out on memrefs with non-unit innermost stride or a non-zero
    // address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading. Each empty tag type selects the
/// createReductionNeutralValue overload that materializes the start value for
/// the corresponding combining kind.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value.
434 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 435 ConversionPatternRewriter &rewriter, 436 Location loc, Type llvmType) { 437 return rewriter.create<LLVM::ConstantOp>( 438 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 439 } 440 441 /// Create the reduction neutral fp one value. 442 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 443 ConversionPatternRewriter &rewriter, 444 Location loc, Type llvmType) { 445 return rewriter.create<LLVM::ConstantOp>( 446 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 447 } 448 449 /// Create the reduction neutral all-ones value. 450 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 451 ConversionPatternRewriter &rewriter, 452 Location loc, Type llvmType) { 453 return rewriter.create<LLVM::ConstantOp>( 454 loc, llvmType, 455 rewriter.getIntegerAttr( 456 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 457 } 458 459 /// Create the reduction neutral signed int minimum value. 460 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 461 ConversionPatternRewriter &rewriter, 462 Location loc, Type llvmType) { 463 return rewriter.create<LLVM::ConstantOp>( 464 loc, llvmType, 465 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 466 llvmType.getIntOrFloatBitWidth()))); 467 } 468 469 /// Create the reduction neutral unsigned int minimum value. 470 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 471 ConversionPatternRewriter &rewriter, 472 Location loc, Type llvmType) { 473 return rewriter.create<LLVM::ConstantOp>( 474 loc, llvmType, 475 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 476 llvmType.getIntOrFloatBitWidth()))); 477 } 478 479 /// Create the reduction neutral signed int maximum value. 
480 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 481 ConversionPatternRewriter &rewriter, 482 Location loc, Type llvmType) { 483 return rewriter.create<LLVM::ConstantOp>( 484 loc, llvmType, 485 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 486 llvmType.getIntOrFloatBitWidth()))); 487 } 488 489 /// Create the reduction neutral unsigned int maximum value. 490 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 491 ConversionPatternRewriter &rewriter, 492 Location loc, Type llvmType) { 493 return rewriter.create<LLVM::ConstantOp>( 494 loc, llvmType, 495 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 496 llvmType.getIntOrFloatBitWidth()))); 497 } 498 499 /// Create the reduction neutral fp minimum value. 500 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 501 ConversionPatternRewriter &rewriter, 502 Location loc, Type llvmType) { 503 auto floatType = llvmType.cast<FloatType>(); 504 return rewriter.create<LLVM::ConstantOp>( 505 loc, llvmType, 506 rewriter.getFloatAttr( 507 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 508 /*Negative=*/false))); 509 } 510 511 /// Create the reduction neutral fp maximum value. 512 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 513 ConversionPatternRewriter &rewriter, 514 Location loc, Type llvmType) { 515 auto floatType = llvmType.cast<FloatType>(); 516 return rewriter.create<LLVM::ConstantOp>( 517 loc, llvmType, 518 rewriter.getFloatAttr( 519 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 520 /*Negative=*/true))); 521 } 522 523 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 524 /// returns a new accumulator value using `ReductionNeutral`. 
525 template <class ReductionNeutral> 526 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 527 Location loc, Type llvmType, 528 Value accumulator) { 529 if (accumulator) 530 return accumulator; 531 532 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 533 llvmType); 534 } 535 536 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 537 /// This is used as effective vector length by some intrinsics supporting 538 /// dynamic vector lengths at runtime. 539 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 540 Location loc, Type llvmType) { 541 VectorType vType = cast<VectorType>(llvmType); 542 auto vShape = vType.getShape(); 543 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 544 545 return rewriter.create<LLVM::ConstantOp>( 546 loc, rewriter.getI32Type(), 547 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 548 } 549 550 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 551 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 552 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 553 /// non-null. 554 template <class LLVMRedIntrinOp, class ScalarOp> 555 static Value createIntegerReductionArithmeticOpLowering( 556 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 557 Value vectorOperand, Value accumulator) { 558 559 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 560 561 if (accumulator) 562 result = rewriter.create<ScalarOp>(loc, accumulator, result); 563 return result; 564 } 565 566 /// Helper method to lower a `vector.reduction` operation that performs 567 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 568 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 569 /// the accumulator value if non-null. 
570 template <class LLVMRedIntrinOp> 571 static Value createIntegerReductionComparisonOpLowering( 572 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 573 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 574 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 575 if (accumulator) { 576 Value cmp = 577 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 578 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 579 } 580 return result; 581 } 582 583 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum 584 /// with vector types. 585 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, 586 Value rhs, bool isMin) { 587 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); 588 Type i1Type = builder.getI1Type(); 589 if (auto vecType = lhs.getType().dyn_cast<VectorType>()) 590 i1Type = VectorType::get(vecType.getShape(), i1Type); 591 Value cmp = builder.create<LLVM::FCmpOp>( 592 loc, i1Type, isMin ? 
LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, 593 lhs, rhs); 594 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); 595 Value isNan = builder.create<LLVM::FCmpOp>( 596 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); 597 Value nan = builder.create<LLVM::ConstantOp>( 598 loc, lhs.getType(), 599 builder.getFloatAttr(floatType, 600 APFloat::getQNaN(floatType.getFloatSemantics()))); 601 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); 602 } 603 604 template <class LLVMRedIntrinOp> 605 static Value createFPReductionComparisonOpLowering( 606 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 607 Value vectorOperand, Value accumulator, bool isMin) { 608 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 609 610 if (accumulator) 611 result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin); 612 613 return result; 614 } 615 616 /// Overloaded methods to lower a reduction to an llvm instrinsic that requires 617 /// a start value. This start value format spans across fp reductions without 618 /// mask and all the masked reduction intrinsics. 
// Reduction with an explicit start value: plain (unmasked) form.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

// Unmasked form that additionally forwards the fp-reassociation flag.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, reassociateFPReds);
}

// Masked (VP intrinsic) form: also materializes the effective vector length.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

// Masked form with the fp-reassociation flag.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask,
                                          bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength,
                                            reassociateFPReds);
}

// Dispatches between the integer and fp masked intrinsics based on the
// element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                        IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining combining kinds are fp-only.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MINF) {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/true);
    } else if (kind == vector::CombiningKind::MAXF) {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/false);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  // Whether fp add/mul reductions may be reassociated (faster but not
  // bit-exact); forwarded to the LLVM reduction intrinsics.
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes. This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Match against the maskable operation kind. Bail out if the op nested
    // inside the `vector.mask` is not the `MaskedOp` this pattern handles.
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return failure();
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  }

protected:
  /// Conversion hook implemented by concrete subclasses. Implementations must
  /// replace the `vector.mask` operation itself (not just the nested op).
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

/// Lowers a masked `vector.reduction` (a `vector.reduction` nested inside a
/// `vector.mask`) to the corresponding LLVM VP (vector-predicated) reduction
/// intrinsic. When the reduction carries no accumulator, the combining kind's
/// neutral element is materialized as the start value (see the
/// ReductionNeutral* tags passed to `lowerReductionWithStartValue`).
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      // ADD/MUL are overloaded for integer and FP element types: the first
      // (intrinsic, neutral) pair below is the integer flavor, the second the
      // FP flavor.
      result = lowerReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
                                            ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
                                            ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                            ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                            ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
                                            ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

/// Lowers `vector.shuffle` either to a single `llvm.shufflevector` (rank <= 1
/// with identical operand types) or, for aggregates of 1-D LLVM vectors, to a
/// per-element extract/insert sequence driven by the shuffle mask.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    // Two well-formed shapes: two 0-D inputs yielding a 1-D result, or all
    // three vectors sharing the same rank.
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      // Mask entries >= v1Dim select from the second operand (LLVM-style
      // concatenated-input shuffle mask numbering).
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Lowers `vector.extractelement` to `llvm.extractelement`. For 0-D vectors,
/// which carry no position operand, a constant position 0 is materialized.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.extract` to a combination of `llvm.extractvalue` (to peel
/// the leading aggregate dimensions) and `llvm.extractelement` (to pick a
/// scalar out of the innermost 1-D vector).
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array: all position indices but
    // the last address the enclosing aggregate dimensions.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Rank >= 2 is first unrolled by VectorFMAOpNDRewritePattern.
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Lowers `vector.insertelement` to `llvm.insertelement`. For 0-D vectors,
/// which carry no position operand, a constant position 0 is materialized.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Lowers `vector.insert` to a combination of `llvm.extractvalue` /
/// `llvm.insertelement` / `llvm.insertvalue`: the innermost 1-D vector is
/// extracted from the aggregate, the scalar is inserted into it, and the
/// updated 1-D vector is put back into the aggregate.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array: all position indices but
    // the last address the enclosing aggregate dimensions.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///  %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // 0-D and 1-D cases are handled directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    // Peel the leading dimension: emit one (n-1)-D FMA per outer index and
    // reassemble the results into the destination vector.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  // The innermost stride must be 1 for the buffer to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Contiguity invariant: stride[i] == stride[i+1] * size[i+1].
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Lowers `vector.type_cast` by building a new memref descriptor for the
/// target type that reuses the source's allocated/aligned pointers and fills
/// in the (static) sizes and strides of the target memref type.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr. A bitcast to the target element pointer type is only
    // needed under typed (non-opaque) pointers.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    // Build the mask as `step_vector < splat(bound)`: lane i is active iff
    // i < bound.
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    auto parent = printOp->getParentOfType<ModuleOp>();
    Operation *printer;
    if (eltType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (eltType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (eltType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (eltType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (eltType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(parent));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // How a scalar element must be adapted before it can be passed to the
  // runtime print function (see matchAndRewrite for the selection logic).
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  /// Recursively emits print calls for `value` of (possibly vector) `type`,
  /// descending one vector dimension per recursion step and emitting
  /// open/comma/close punctuation around the elements.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::Bitcast16:
        value = rewriter.create<LLVM::BitcastOp>(
            loc, IntegerType::get(rewriter.getContext(), 16), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    auto parent = op->getParentOfType<ModuleOp>();
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
    Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);

    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
      return;
    }

    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType().cast<VectorType>();
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate of linear index, convert to coords space and insert splatted 1-D
    // vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion, VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering,
VectorSplatNdOpLowering, 1720 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1721 MaskedReductionOpConversion>(converter); 1722 // Transfer ops with rank > 1 are handled by VectorToSCF. 1723 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1724 } 1725 1726 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1727 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1728 patterns.add<VectorMatmulOpConversion>(converter); 1729 patterns.add<VectorFlatTransposeOpConversion>(converter); 1730 } 1731