1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/Support/Casting.h" 26 #include <optional> 27 28 using namespace mlir; 29 using namespace mlir::vector; 30 31 // Helper to reduce vector type by one rank at front. 32 static VectorType reducedVectorTypeFront(VectorType tp) { 33 assert((tp.getRank() > 1) && "unlowerable vector type"); 34 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 35 tp.getScalableDims().drop_front()); 36 } 37 38 // Helper to reduce vector type by *all* but one rank at back. 39 static VectorType reducedVectorTypeBack(VectorType tp) { 40 assert((tp.getRank() > 1) && "unlowerable vector type"); 41 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 42 tp.getScalableDims().take_back()); 43 } 44 45 // Helper that picks the proper sequence for inserting. 
46 static Value insertOne(ConversionPatternRewriter &rewriter, 47 LLVMTypeConverter &typeConverter, Location loc, 48 Value val1, Value val2, Type llvmType, int64_t rank, 49 int64_t pos) { 50 assert(rank > 0 && "0-D vector corner case should have been handled already"); 51 if (rank == 1) { 52 auto idxType = rewriter.getIndexType(); 53 auto constant = rewriter.create<LLVM::ConstantOp>( 54 loc, typeConverter.convertType(idxType), 55 rewriter.getIntegerAttr(idxType, pos)); 56 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 57 constant); 58 } 59 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 60 } 61 62 // Helper that picks the proper sequence for extracting. 63 static Value extractOne(ConversionPatternRewriter &rewriter, 64 LLVMTypeConverter &typeConverter, Location loc, 65 Value val, Type llvmType, int64_t rank, int64_t pos) { 66 if (rank <= 1) { 67 auto idxType = rewriter.getIndexType(); 68 auto constant = rewriter.create<LLVM::ConstantOp>( 69 loc, typeConverter.convertType(idxType), 70 rewriter.getIntegerAttr(idxType, pos)); 71 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 72 constant); 73 } 74 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 75 } 76 77 // Helper that returns data layout alignment of a memref. 78 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 79 MemRefType memrefType, unsigned &align) { 80 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 81 if (!elementTy) 82 return failure(); 83 84 // TODO: this should use the MLIR data layout when it becomes available and 85 // stop depending on translation. 86 llvm::LLVMContext llvmContext; 87 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 88 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 89 return success(); 90 } 91 92 // Check if the last stride is non-unit or the memory space is not zero. 
93 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 94 LLVMTypeConverter &converter) { 95 if (!isLastMemrefDimUnitStride(memRefType)) 96 return failure(); 97 FailureOr<unsigned> addressSpace = 98 converter.getMemRefAddressSpace(memRefType); 99 if (failed(addressSpace) || *addressSpace != 0) 100 return failure(); 101 return success(); 102 } 103 104 // Add an index vector component to a base pointer. 105 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 106 LLVMTypeConverter &typeConverter, 107 MemRefType memRefType, Value llvmMemref, Value base, 108 Value index, uint64_t vLen) { 109 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 110 "unsupported memref type"); 111 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 112 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 113 return rewriter.create<LLVM::GEPOp>( 114 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 115 base, index); 116 } 117 118 // Casts a strided element pointer to a vector pointer. The vector pointer 119 // will be in the same address space as the incoming memref type. 120 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 121 Value ptr, MemRefType memRefType, Type vt, 122 LLVMTypeConverter &converter) { 123 if (converter.useOpaquePointers()) 124 return ptr; 125 126 unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); 127 auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); 128 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 129 } 130 131 namespace { 132 133 /// Trivial Vector to LLVM conversions 134 using VectorScaleOpConversion = 135 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 136 137 /// Conversion pattern for a vector.bitcast. 
138 class VectorBitCastOpConversion 139 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 140 public: 141 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 142 143 LogicalResult 144 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 145 ConversionPatternRewriter &rewriter) const override { 146 // Only 0-D and 1-D vectors can be lowered to LLVM. 147 VectorType resultTy = bitCastOp.getResultVectorType(); 148 if (resultTy.getRank() > 1) 149 return failure(); 150 Type newResultTy = typeConverter->convertType(resultTy); 151 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 152 adaptor.getOperands()[0]); 153 return success(); 154 } 155 }; 156 157 /// Conversion pattern for a vector.matrix_multiply. 158 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 159 class VectorMatmulOpConversion 160 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 161 public: 162 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 163 164 LogicalResult 165 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 166 ConversionPatternRewriter &rewriter) const override { 167 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 168 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 169 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 170 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 171 return success(); 172 } 173 }; 174 175 /// Conversion pattern for a vector.flat_transpose. 176 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 
177 class VectorFlatTransposeOpConversion 178 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 179 public: 180 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 181 182 LogicalResult 183 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 184 ConversionPatternRewriter &rewriter) const override { 185 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 186 transOp, typeConverter->convertType(transOp.getRes().getType()), 187 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 188 return success(); 189 } 190 }; 191 192 /// Overloaded utility that replaces a vector.load, vector.store, 193 /// vector.maskedload and vector.maskedstore with their respective LLVM 194 /// couterparts. 195 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 196 vector::LoadOpAdaptor adaptor, 197 VectorType vectorTy, Value ptr, unsigned align, 198 ConversionPatternRewriter &rewriter) { 199 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align); 200 } 201 202 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 203 vector::MaskedLoadOpAdaptor adaptor, 204 VectorType vectorTy, Value ptr, unsigned align, 205 ConversionPatternRewriter &rewriter) { 206 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 207 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 208 } 209 210 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 211 vector::StoreOpAdaptor adaptor, 212 VectorType vectorTy, Value ptr, unsigned align, 213 ConversionPatternRewriter &rewriter) { 214 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 215 ptr, align); 216 } 217 218 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 219 vector::MaskedStoreOpAdaptor adaptor, 220 VectorType vectorTy, Value ptr, unsigned align, 221 ConversionPatternRewriter &rewriter) { 222 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 223 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), 
align); 224 } 225 226 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 227 /// vector.maskedstore. 228 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 229 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 230 public: 231 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 232 233 LogicalResult 234 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 235 typename LoadOrStoreOp::Adaptor adaptor, 236 ConversionPatternRewriter &rewriter) const override { 237 // Only 1-D vectors can be lowered to LLVM. 238 VectorType vectorTy = loadOrStoreOp.getVectorType(); 239 if (vectorTy.getRank() > 1) 240 return failure(); 241 242 auto loc = loadOrStoreOp->getLoc(); 243 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 244 245 // Resolve alignment. 246 unsigned align; 247 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 248 return failure(); 249 250 // Resolve address. 251 auto vtype = cast<VectorType>( 252 this->typeConverter->convertType(loadOrStoreOp.getVectorType())); 253 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 254 adaptor.getIndices(), rewriter); 255 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, 256 *this->getTypeConverter()); 257 258 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 259 return success(); 260 } 261 }; 262 263 /// Conversion pattern for a vector.gather. 
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The base must already be a memref at this point (bufferization has run).
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // Bail out on non-unit innermost stride or non-zero memory space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Scalar pointer to the gather's base element inside the memref.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector. (An n-D index vector lowers to an
    // LLVM array-of-vectors type, handled in the generic path below.)
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address: vector of pointers, base offset by the index vector.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: unroll into 1-D gathers. The callback receives the 1-D slice
    // operands (index vector, mask, pass-through) and emits one intrinsic.
    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address for this 1-D slice.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Bail out on non-unit innermost stride or non-zero memory space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: scalar base pointer, then a vector of pointers offset
    // by the index vector. Note: unlike gather, only the 1-D form is handled.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
362 class VectorExpandLoadOpConversion 363 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 364 public: 365 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 366 367 LogicalResult 368 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 369 ConversionPatternRewriter &rewriter) const override { 370 auto loc = expand->getLoc(); 371 MemRefType memRefType = expand.getMemRefType(); 372 373 // Resolve address. 374 auto vtype = typeConverter->convertType(expand.getVectorType()); 375 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 376 adaptor.getIndices(), rewriter); 377 378 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 379 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 380 return success(); 381 } 382 }; 383 384 /// Conversion pattern for a vector.compressstore. 385 class VectorCompressStoreOpConversion 386 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 387 public: 388 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 389 390 LogicalResult 391 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 392 ConversionPatternRewriter &rewriter) const override { 393 auto loc = compress->getLoc(); 394 MemRefType memRefType = compress.getMemRefType(); 395 396 // Resolve address. 397 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 398 adaptor.getIndices(), rewriter); 399 400 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 401 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 402 return success(); 403 } 404 }; 405 406 /// Reduction neutral classes for overloading. 
407 class ReductionNeutralZero {}; 408 class ReductionNeutralIntOne {}; 409 class ReductionNeutralFPOne {}; 410 class ReductionNeutralAllOnes {}; 411 class ReductionNeutralSIntMin {}; 412 class ReductionNeutralUIntMin {}; 413 class ReductionNeutralSIntMax {}; 414 class ReductionNeutralUIntMax {}; 415 class ReductionNeutralFPMin {}; 416 class ReductionNeutralFPMax {}; 417 418 /// Create the reduction neutral zero value. 419 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 420 ConversionPatternRewriter &rewriter, 421 Location loc, Type llvmType) { 422 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 423 rewriter.getZeroAttr(llvmType)); 424 } 425 426 /// Create the reduction neutral integer one value. 427 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 428 ConversionPatternRewriter &rewriter, 429 Location loc, Type llvmType) { 430 return rewriter.create<LLVM::ConstantOp>( 431 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 432 } 433 434 /// Create the reduction neutral fp one value. 435 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 436 ConversionPatternRewriter &rewriter, 437 Location loc, Type llvmType) { 438 return rewriter.create<LLVM::ConstantOp>( 439 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 440 } 441 442 /// Create the reduction neutral all-ones value. 443 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 444 ConversionPatternRewriter &rewriter, 445 Location loc, Type llvmType) { 446 return rewriter.create<LLVM::ConstantOp>( 447 loc, llvmType, 448 rewriter.getIntegerAttr( 449 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 450 } 451 452 /// Create the reduction neutral signed int minimum value. 
453 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 454 ConversionPatternRewriter &rewriter, 455 Location loc, Type llvmType) { 456 return rewriter.create<LLVM::ConstantOp>( 457 loc, llvmType, 458 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 459 llvmType.getIntOrFloatBitWidth()))); 460 } 461 462 /// Create the reduction neutral unsigned int minimum value. 463 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 464 ConversionPatternRewriter &rewriter, 465 Location loc, Type llvmType) { 466 return rewriter.create<LLVM::ConstantOp>( 467 loc, llvmType, 468 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 469 llvmType.getIntOrFloatBitWidth()))); 470 } 471 472 /// Create the reduction neutral signed int maximum value. 473 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 474 ConversionPatternRewriter &rewriter, 475 Location loc, Type llvmType) { 476 return rewriter.create<LLVM::ConstantOp>( 477 loc, llvmType, 478 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 479 llvmType.getIntOrFloatBitWidth()))); 480 } 481 482 /// Create the reduction neutral unsigned int maximum value. 483 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 484 ConversionPatternRewriter &rewriter, 485 Location loc, Type llvmType) { 486 return rewriter.create<LLVM::ConstantOp>( 487 loc, llvmType, 488 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 489 llvmType.getIntOrFloatBitWidth()))); 490 } 491 492 /// Create the reduction neutral fp minimum value. 
493 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 494 ConversionPatternRewriter &rewriter, 495 Location loc, Type llvmType) { 496 auto floatType = cast<FloatType>(llvmType); 497 return rewriter.create<LLVM::ConstantOp>( 498 loc, llvmType, 499 rewriter.getFloatAttr( 500 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 501 /*Negative=*/false))); 502 } 503 504 /// Create the reduction neutral fp maximum value. 505 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 506 ConversionPatternRewriter &rewriter, 507 Location loc, Type llvmType) { 508 auto floatType = cast<FloatType>(llvmType); 509 return rewriter.create<LLVM::ConstantOp>( 510 loc, llvmType, 511 rewriter.getFloatAttr( 512 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 513 /*Negative=*/true))); 514 } 515 516 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 517 /// returns a new accumulator value using `ReductionNeutral`. 518 template <class ReductionNeutral> 519 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 520 Location loc, Type llvmType, 521 Value accumulator) { 522 if (accumulator) 523 return accumulator; 524 525 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 526 llvmType); 527 } 528 529 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 530 /// This is used as effective vector length by some intrinsics supporting 531 /// dynamic vector lengths at runtime. 
532 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 533 Location loc, Type llvmType) { 534 VectorType vType = cast<VectorType>(llvmType); 535 auto vShape = vType.getShape(); 536 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 537 538 return rewriter.create<LLVM::ConstantOp>( 539 loc, rewriter.getI32Type(), 540 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 541 } 542 543 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 544 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 545 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 546 /// non-null. 547 template <class LLVMRedIntrinOp, class ScalarOp> 548 static Value createIntegerReductionArithmeticOpLowering( 549 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 550 Value vectorOperand, Value accumulator) { 551 552 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 553 554 if (accumulator) 555 result = rewriter.create<ScalarOp>(loc, accumulator, result); 556 return result; 557 } 558 559 /// Helper method to lower a `vector.reduction` operation that performs 560 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 561 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 562 /// the accumulator value if non-null. 
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    // Combine with the accumulator: keep the accumulator when
    // `accumulator <predicate> result` holds, otherwise the reduced value.
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}

/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
/// with vector types.
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  // The compare result type must match the (possibly vector) operand shape.
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  // Ordered compare (olt/ogt) is false on NaN inputs; the NaN case is patched
  // up below with an unordered compare.
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  // `uno` is true iff either operand is NaN; propagate a quiet NaN then.
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, bool isMin) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  // Combine with the accumulator using NaN-propagating min/max.
  if (accumulator)
    result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);

  return result;
}

/// Overloaded methods to lower a reduction to an llvm intrinsic that requires
/// a start value. This start value format spans across fp reductions without
/// mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

/// Overload taking a reassociation flag (fp add/mul reductions).
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, reassociateFPReds);
}

/// Overload taking a mask (vector-predication intrinsics). The effective
/// vector length is derived from the static shape of `vectorOperand`.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

/// Overload taking both a mask and a reassociation flag.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask,
                                          bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength,
                                            reassociateFPReds);
}

/// Dispatch overload: picks the integer or fp intrinsic/neutral pair based on
/// the scalar result type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                        IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds (e.g. fp-only ones) are invalid for int elements.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MINF) {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/true);
    } else if (kind == vector::CombiningKind::MAXF) {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/false);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  // Whether fp add/mul reductions may be reassociated (faster, not IEEE
  // order-preserving); forwarded to the LLVM reduction intrinsics.
  const bool reassociateFPReductions;
};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes. This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
807 template <class MaskedOp> 808 class VectorMaskOpConversionBase 809 : public ConvertOpToLLVMPattern<vector::MaskOp> { 810 public: 811 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern; 812 813 LogicalResult 814 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, 815 ConversionPatternRewriter &rewriter) const final { 816 // Match against the maskable operation kind. 817 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp()); 818 if (!maskedOp) 819 return failure(); 820 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter); 821 } 822 823 protected: 824 virtual LogicalResult 825 matchAndRewriteMaskableOp(vector::MaskOp maskOp, 826 vector::MaskableOpInterface maskableOp, 827 ConversionPatternRewriter &rewriter) const = 0; 828 }; 829 830 class MaskedReductionOpConversion 831 : public VectorMaskOpConversionBase<vector::ReductionOp> { 832 833 public: 834 using VectorMaskOpConversionBase< 835 vector::ReductionOp>::VectorMaskOpConversionBase; 836 837 LogicalResult matchAndRewriteMaskableOp( 838 vector::MaskOp maskOp, MaskableOpInterface maskableOp, 839 ConversionPatternRewriter &rewriter) const override { 840 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation()); 841 auto kind = reductionOp.getKind(); 842 Type eltType = reductionOp.getDest().getType(); 843 Type llvmType = typeConverter->convertType(eltType); 844 Value operand = reductionOp.getVector(); 845 Value acc = reductionOp.getAcc(); 846 Location loc = reductionOp.getLoc(); 847 848 Value result; 849 switch (kind) { 850 case vector::CombiningKind::ADD: 851 result = lowerReductionWithStartValue< 852 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, 853 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, 854 maskOp.getMask()); 855 break; 856 case vector::CombiningKind::MUL: 857 result = lowerReductionWithStartValue< 858 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, 859 ReductionNeutralFPOne>(rewriter, loc, 
llvmType, operand, acc, 860 maskOp.getMask()); 861 break; 862 case vector::CombiningKind::MINUI: 863 result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp, 864 ReductionNeutralUIntMax>( 865 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 866 break; 867 case vector::CombiningKind::MINSI: 868 result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp, 869 ReductionNeutralSIntMax>( 870 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 871 break; 872 case vector::CombiningKind::MAXUI: 873 result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp, 874 ReductionNeutralUIntMin>( 875 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 876 break; 877 case vector::CombiningKind::MAXSI: 878 result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp, 879 ReductionNeutralSIntMin>( 880 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 881 break; 882 case vector::CombiningKind::AND: 883 result = lowerReductionWithStartValue<LLVM::VPReduceAndOp, 884 ReductionNeutralAllOnes>( 885 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 886 break; 887 case vector::CombiningKind::OR: 888 result = lowerReductionWithStartValue<LLVM::VPReduceOrOp, 889 ReductionNeutralZero>( 890 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 891 break; 892 case vector::CombiningKind::XOR: 893 result = lowerReductionWithStartValue<LLVM::VPReduceXorOp, 894 ReductionNeutralZero>( 895 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 896 break; 897 case vector::CombiningKind::MINF: 898 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 899 // NaNs/-0.0/+0.0 in the same way. 900 result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp, 901 ReductionNeutralFPMax>( 902 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 903 break; 904 case vector::CombiningKind::MAXF: 905 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 906 // NaNs/-0.0/+0.0 in the same way. 
907 result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp, 908 ReductionNeutralFPMin>( 909 rewriter, loc, llvmType, operand, acc, maskOp.getMask()); 910 break; 911 } 912 913 // Replace `vector.mask` operation altogether. 914 rewriter.replaceOp(maskOp, result); 915 return success(); 916 } 917 }; 918 919 class VectorShuffleOpConversion 920 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 921 public: 922 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 923 924 LogicalResult 925 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 926 ConversionPatternRewriter &rewriter) const override { 927 auto loc = shuffleOp->getLoc(); 928 auto v1Type = shuffleOp.getV1VectorType(); 929 auto v2Type = shuffleOp.getV2VectorType(); 930 auto vectorType = shuffleOp.getResultVectorType(); 931 Type llvmType = typeConverter->convertType(vectorType); 932 auto maskArrayAttr = shuffleOp.getMask(); 933 934 // Bail if result type cannot be lowered. 935 if (!llvmType) 936 return failure(); 937 938 // Get rank and dimension sizes. 939 int64_t rank = vectorType.getRank(); 940 #ifndef NDEBUG 941 bool wellFormed0DCase = 942 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1; 943 bool wellFormedNDCase = 944 v1Type.getRank() == rank && v2Type.getRank() == rank; 945 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed"); 946 #endif 947 948 // For rank 0 and 1, where both operands have *exactly* the same vector 949 // type, there is direct shuffle support in LLVM. Use it! 950 if (rank <= 1 && v1Type == v2Type) { 951 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 952 loc, adaptor.getV1(), adaptor.getV2(), 953 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 954 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 955 return success(); 956 } 957 958 // For all other cases, insert the individual values individually. 
959 int64_t v1Dim = v1Type.getDimSize(0); 960 Type eltType; 961 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType)) 962 eltType = arrayType.getElementType(); 963 else 964 eltType = cast<VectorType>(llvmType).getElementType(); 965 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 966 int64_t insPos = 0; 967 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 968 int64_t extPos = cast<IntegerAttr>(en.value()).getInt(); 969 Value value = adaptor.getV1(); 970 if (extPos >= v1Dim) { 971 extPos -= v1Dim; 972 value = adaptor.getV2(); 973 } 974 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 975 eltType, rank, extPos); 976 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 977 llvmType, rank, insPos++); 978 } 979 rewriter.replaceOp(shuffleOp, insert); 980 return success(); 981 } 982 }; 983 984 class VectorExtractElementOpConversion 985 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 986 public: 987 using ConvertOpToLLVMPattern< 988 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 989 990 LogicalResult 991 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 992 ConversionPatternRewriter &rewriter) const override { 993 auto vectorType = extractEltOp.getSourceVectorType(); 994 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 995 996 // Bail if result type cannot be lowered. 
997 if (!llvmType) 998 return failure(); 999 1000 if (vectorType.getRank() == 0) { 1001 Location loc = extractEltOp.getLoc(); 1002 auto idxType = rewriter.getIndexType(); 1003 auto zero = rewriter.create<LLVM::ConstantOp>( 1004 loc, typeConverter->convertType(idxType), 1005 rewriter.getIntegerAttr(idxType, 0)); 1006 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1007 extractEltOp, llvmType, adaptor.getVector(), zero); 1008 return success(); 1009 } 1010 1011 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 1012 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 1013 return success(); 1014 } 1015 }; 1016 1017 class VectorExtractOpConversion 1018 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 1019 public: 1020 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 1021 1022 LogicalResult 1023 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 1024 ConversionPatternRewriter &rewriter) const override { 1025 auto loc = extractOp->getLoc(); 1026 auto resultType = extractOp.getResult().getType(); 1027 auto llvmResultType = typeConverter->convertType(resultType); 1028 auto positionArrayAttr = extractOp.getPosition(); 1029 1030 // Bail if result type cannot be lowered. 1031 if (!llvmResultType) 1032 return failure(); 1033 1034 // Extract entire vector. Should be handled by folder, but just to be safe. 1035 if (positionArrayAttr.empty()) { 1036 rewriter.replaceOp(extractOp, adaptor.getVector()); 1037 return success(); 1038 } 1039 1040 // One-shot extraction of vector from array (only requires extractvalue). 
1041 if (isa<VectorType>(resultType)) { 1042 SmallVector<int64_t> indices; 1043 for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>()) 1044 indices.push_back(idx.getInt()); 1045 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 1046 loc, adaptor.getVector(), indices); 1047 rewriter.replaceOp(extractOp, extracted); 1048 return success(); 1049 } 1050 1051 // Potential extraction of 1-D vector from array. 1052 Value extracted = adaptor.getVector(); 1053 auto positionAttrs = positionArrayAttr.getValue(); 1054 if (positionAttrs.size() > 1) { 1055 SmallVector<int64_t> nMinusOnePosition; 1056 for (auto idx : positionAttrs.drop_back()) 1057 nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt()); 1058 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted, 1059 nMinusOnePosition); 1060 } 1061 1062 // Remaining extraction of element from 1-D LLVM vector 1063 auto position = cast<IntegerAttr>(positionAttrs.back()); 1064 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 1065 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 1066 extracted = 1067 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 1068 rewriter.replaceOp(extractOp, extracted); 1069 1070 return success(); 1071 } 1072 }; 1073 1074 /// Conversion pattern that turns a vector.fma on a 1-D vector 1075 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 1076 /// This does not match vectors of n >= 2 rank. 
1077 /// 1078 /// Example: 1079 /// ``` 1080 /// vector.fma %a, %a, %a : vector<8xf32> 1081 /// ``` 1082 /// is converted to: 1083 /// ``` 1084 /// llvm.intr.fmuladd %va, %va, %va: 1085 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 1086 /// -> !llvm."<8 x f32>"> 1087 /// ``` 1088 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 1089 public: 1090 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 1091 1092 LogicalResult 1093 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 1094 ConversionPatternRewriter &rewriter) const override { 1095 VectorType vType = fmaOp.getVectorType(); 1096 if (vType.getRank() > 1) 1097 return failure(); 1098 1099 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( 1100 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); 1101 return success(); 1102 } 1103 }; 1104 1105 class VectorInsertElementOpConversion 1106 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 1107 public: 1108 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 1109 1110 LogicalResult 1111 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 1112 ConversionPatternRewriter &rewriter) const override { 1113 auto vectorType = insertEltOp.getDestVectorType(); 1114 auto llvmType = typeConverter->convertType(vectorType); 1115 1116 // Bail if result type cannot be lowered. 
1117 if (!llvmType) 1118 return failure(); 1119 1120 if (vectorType.getRank() == 0) { 1121 Location loc = insertEltOp.getLoc(); 1122 auto idxType = rewriter.getIndexType(); 1123 auto zero = rewriter.create<LLVM::ConstantOp>( 1124 loc, typeConverter->convertType(idxType), 1125 rewriter.getIntegerAttr(idxType, 0)); 1126 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1127 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); 1128 return success(); 1129 } 1130 1131 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1132 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), 1133 adaptor.getPosition()); 1134 return success(); 1135 } 1136 }; 1137 1138 class VectorInsertOpConversion 1139 : public ConvertOpToLLVMPattern<vector::InsertOp> { 1140 public: 1141 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 1142 1143 LogicalResult 1144 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 1145 ConversionPatternRewriter &rewriter) const override { 1146 auto loc = insertOp->getLoc(); 1147 auto sourceType = insertOp.getSourceType(); 1148 auto destVectorType = insertOp.getDestVectorType(); 1149 auto llvmResultType = typeConverter->convertType(destVectorType); 1150 auto positionArrayAttr = insertOp.getPosition(); 1151 1152 // Bail if result type cannot be lowered. 1153 if (!llvmResultType) 1154 return failure(); 1155 1156 // Overwrite entire vector with value. Should be handled by folder, but 1157 // just to be safe. 1158 if (positionArrayAttr.empty()) { 1159 rewriter.replaceOp(insertOp, adaptor.getSource()); 1160 return success(); 1161 } 1162 1163 // One-shot insertion of a vector into an array (only requires insertvalue). 
1164 if (isa<VectorType>(sourceType)) { 1165 Value inserted = rewriter.create<LLVM::InsertValueOp>( 1166 loc, adaptor.getDest(), adaptor.getSource(), 1167 LLVM::convertArrayToIndices(positionArrayAttr)); 1168 rewriter.replaceOp(insertOp, inserted); 1169 return success(); 1170 } 1171 1172 // Potential extraction of 1-D vector from array. 1173 Value extracted = adaptor.getDest(); 1174 auto positionAttrs = positionArrayAttr.getValue(); 1175 auto position = cast<IntegerAttr>(positionAttrs.back()); 1176 auto oneDVectorType = destVectorType; 1177 if (positionAttrs.size() > 1) { 1178 oneDVectorType = reducedVectorTypeBack(destVectorType); 1179 extracted = rewriter.create<LLVM::ExtractValueOp>( 1180 loc, extracted, 1181 LLVM::convertArrayToIndices(positionAttrs.drop_back())); 1182 } 1183 1184 // Insertion of an element into a 1-D LLVM vector. 1185 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 1186 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 1187 Value inserted = rewriter.create<LLVM::InsertElementOp>( 1188 loc, typeConverter->convertType(oneDVectorType), extracted, 1189 adaptor.getSource(), constant); 1190 1191 // Potential insertion of resulting 1-D vector into array. 
1192 if (positionAttrs.size() > 1) { 1193 inserted = rewriter.create<LLVM::InsertValueOp>( 1194 loc, adaptor.getDest(), inserted, 1195 LLVM::convertArrayToIndices(positionAttrs.drop_back())); 1196 } 1197 1198 rewriter.replaceOp(insertOp, inserted); 1199 return success(); 1200 } 1201 }; 1202 1203 /// Lower vector.scalable.insert ops to LLVM vector.insert 1204 struct VectorScalableInsertOpLowering 1205 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> { 1206 using ConvertOpToLLVMPattern< 1207 vector::ScalableInsertOp>::ConvertOpToLLVMPattern; 1208 1209 LogicalResult 1210 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor, 1211 ConversionPatternRewriter &rewriter) const override { 1212 rewriter.replaceOpWithNewOp<LLVM::vector_insert>( 1213 insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos()); 1214 return success(); 1215 } 1216 }; 1217 1218 /// Lower vector.scalable.extract ops to LLVM vector.extract 1219 struct VectorScalableExtractOpLowering 1220 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> { 1221 using ConvertOpToLLVMPattern< 1222 vector::ScalableExtractOp>::ConvertOpToLLVMPattern; 1223 1224 LogicalResult 1225 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor, 1226 ConversionPatternRewriter &rewriter) const override { 1227 rewriter.replaceOpWithNewOp<LLVM::vector_extract>( 1228 extOp, typeConverter->convertType(extOp.getResultVectorType()), 1229 adaptor.getSource(), adaptor.getPos()); 1230 return success(); 1231 } 1232 }; 1233 1234 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 
1235 /// 1236 /// Example: 1237 /// ``` 1238 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 1239 /// ``` 1240 /// is rewritten into: 1241 /// ``` 1242 /// %r = splat %f0: vector<2x4xf32> 1243 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 1244 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 1245 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 1246 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 1247 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 1248 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 1249 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 1250 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 1251 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 1252 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 1253 /// // %r3 holds the final value. 1254 /// ``` 1255 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 1256 public: 1257 using OpRewritePattern<FMAOp>::OpRewritePattern; 1258 1259 void initialize() { 1260 // This pattern recursively unpacks one dimension at a time. The recursion 1261 // bounded as the rank is strictly decreasing. 
1262 setHasBoundedRewriteRecursion(); 1263 } 1264 1265 LogicalResult matchAndRewrite(FMAOp op, 1266 PatternRewriter &rewriter) const override { 1267 auto vType = op.getVectorType(); 1268 if (vType.getRank() < 2) 1269 return failure(); 1270 1271 auto loc = op.getLoc(); 1272 auto elemType = vType.getElementType(); 1273 Value zero = rewriter.create<arith::ConstantOp>( 1274 loc, elemType, rewriter.getZeroAttr(elemType)); 1275 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 1276 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 1277 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 1278 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 1279 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 1280 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 1281 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 1282 } 1283 rewriter.replaceOp(op, desc); 1284 return success(); 1285 } 1286 }; 1287 1288 /// Returns the strides if the memory underlying `memRefType` has a contiguous 1289 /// static layout. 1290 static std::optional<SmallVector<int64_t, 4>> 1291 computeContiguousStrides(MemRefType memRefType) { 1292 int64_t offset; 1293 SmallVector<int64_t, 4> strides; 1294 if (failed(getStridesAndOffset(memRefType, strides, offset))) 1295 return std::nullopt; 1296 if (!strides.empty() && strides.back() != 1) 1297 return std::nullopt; 1298 // If no layout or identity layout, this is contiguous by definition. 1299 if (memRefType.getLayout().isIdentity()) 1300 return strides; 1301 1302 // Otherwise, we must determine contiguity form shapes. This can only ever 1303 // work in static cases because MemRefType is underspecified to represent 1304 // contiguous dynamic shapes in other ways than with just empty/identity 1305 // layout. 
1306 auto sizes = memRefType.getShape(); 1307 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 1308 if (ShapedType::isDynamic(sizes[index + 1]) || 1309 ShapedType::isDynamic(strides[index]) || 1310 ShapedType::isDynamic(strides[index + 1])) 1311 return std::nullopt; 1312 if (strides[index] != strides[index + 1] * sizes[index + 1]) 1313 return std::nullopt; 1314 } 1315 return strides; 1316 } 1317 1318 class VectorTypeCastOpConversion 1319 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 1320 public: 1321 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 1322 1323 LogicalResult 1324 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 1325 ConversionPatternRewriter &rewriter) const override { 1326 auto loc = castOp->getLoc(); 1327 MemRefType sourceMemRefType = 1328 cast<MemRefType>(castOp.getOperand().getType()); 1329 MemRefType targetMemRefType = castOp.getType(); 1330 1331 // Only static shape casts supported atm. 1332 if (!sourceMemRefType.hasStaticShape() || 1333 !targetMemRefType.hasStaticShape()) 1334 return failure(); 1335 1336 auto llvmSourceDescriptorTy = 1337 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType()); 1338 if (!llvmSourceDescriptorTy) 1339 return failure(); 1340 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 1341 1342 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( 1343 typeConverter->convertType(targetMemRefType)); 1344 if (!llvmTargetDescriptorTy) 1345 return failure(); 1346 1347 // Only contiguous source buffers supported atm. 1348 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1349 if (!sourceStrides) 1350 return failure(); 1351 auto targetStrides = computeContiguousStrides(targetMemRefType); 1352 if (!targetStrides) 1353 return failure(); 1354 // Only support static strides for now, regardless of contiguity. 
1355 if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) 1356 return failure(); 1357 1358 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1359 1360 // Create descriptor. 1361 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1362 Type llvmTargetElementTy = desc.getElementPtrType(); 1363 // Set allocated ptr. 1364 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1365 if (!getTypeConverter()->useOpaquePointers()) 1366 allocated = 1367 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 1368 desc.setAllocatedPtr(rewriter, loc, allocated); 1369 1370 // Set aligned ptr. 1371 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1372 if (!getTypeConverter()->useOpaquePointers()) 1373 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 1374 1375 desc.setAlignedPtr(rewriter, loc, ptr); 1376 // Fill offset 0. 1377 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1378 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1379 desc.setOffset(rewriter, loc, zero); 1380 1381 // Fill size and stride descriptors in memref. 1382 for (const auto &indexedSize : 1383 llvm::enumerate(targetMemRefType.getShape())) { 1384 int64_t index = indexedSize.index(); 1385 auto sizeAttr = 1386 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1387 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1388 desc.setSize(rewriter, loc, index, size); 1389 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1390 (*targetStrides)[index]); 1391 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1392 desc.setStride(rewriter, loc, index, stride); 1393 } 1394 1395 rewriter.replaceOp(castOp, {desc}); 1396 return success(); 1397 } 1398 }; 1399 1400 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 1401 /// Non-scalable versions of this operation are handled in Vector Transforms. 
1402 class VectorCreateMaskOpRewritePattern 1403 : public OpRewritePattern<vector::CreateMaskOp> { 1404 public: 1405 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1406 bool enableIndexOpt) 1407 : OpRewritePattern<vector::CreateMaskOp>(context), 1408 force32BitVectorIndices(enableIndexOpt) {} 1409 1410 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1411 PatternRewriter &rewriter) const override { 1412 auto dstType = op.getType(); 1413 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable()) 1414 return failure(); 1415 IntegerType idxType = 1416 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1417 auto loc = op->getLoc(); 1418 Value indices = rewriter.create<LLVM::StepVectorOp>( 1419 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1420 /*isScalable=*/true)); 1421 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1422 op.getOperand(0)); 1423 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1424 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1425 indices, bounds); 1426 rewriter.replaceOp(op, comp); 1427 return success(); 1428 } 1429 1430 private: 1431 const bool force32BitVectorIndices; 1432 }; 1433 1434 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1435 public: 1436 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1437 1438 // Proof-of-concept lowering implementation that relies on a small 1439 // runtime support library, which only needs to provide a few 1440 // printing methods (single value for all data types, opening/closing 1441 // bracket, comma, newline). The lowering fully unrolls a vector 1442 // in terms of these elementary printing operations. The advantage 1443 // of this approach is that the library can remain unaware of all 1444 // low-level implementation details of vectors while still supporting 1445 // output of any shaped and dimensioned vector. 
Due to full unrolling, 1446 // this approach is less suited for very large vectors though. 1447 // 1448 // TODO: rely solely on libc in future? something else? 1449 // 1450 LogicalResult 1451 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1452 ConversionPatternRewriter &rewriter) const override { 1453 Type printType = printOp.getPrintType(); 1454 1455 if (typeConverter->convertType(printType) == nullptr) 1456 return failure(); 1457 1458 // Make sure element type has runtime support. 1459 PrintConversion conversion = PrintConversion::None; 1460 VectorType vectorType = dyn_cast<VectorType>(printType); 1461 Type eltType = vectorType ? vectorType.getElementType() : printType; 1462 auto parent = printOp->getParentOfType<ModuleOp>(); 1463 Operation *printer; 1464 if (eltType.isF32()) { 1465 printer = LLVM::lookupOrCreatePrintF32Fn(parent); 1466 } else if (eltType.isF64()) { 1467 printer = LLVM::lookupOrCreatePrintF64Fn(parent); 1468 } else if (eltType.isF16()) { 1469 conversion = PrintConversion::Bitcast16; // bits! 1470 printer = LLVM::lookupOrCreatePrintF16Fn(parent); 1471 } else if (eltType.isBF16()) { 1472 conversion = PrintConversion::Bitcast16; // bits! 1473 printer = LLVM::lookupOrCreatePrintBF16Fn(parent); 1474 } else if (eltType.isIndex()) { 1475 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1476 } else if (auto intTy = dyn_cast<IntegerType>(eltType)) { 1477 // Integers need a zero or sign extension on the operand 1478 // (depending on the source type) as well as a signed or 1479 // unsigned print method. Up to 64-bit is supported. 
1480 unsigned width = intTy.getWidth(); 1481 if (intTy.isUnsigned()) { 1482 if (width <= 64) { 1483 if (width < 64) 1484 conversion = PrintConversion::ZeroExt64; 1485 printer = LLVM::lookupOrCreatePrintU64Fn(parent); 1486 } else { 1487 return failure(); 1488 } 1489 } else { 1490 assert(intTy.isSignless() || intTy.isSigned()); 1491 if (width <= 64) { 1492 // Note that we *always* zero extend booleans (1-bit integers), 1493 // so that true/false is printed as 1/0 rather than -1/0. 1494 if (width == 1) 1495 conversion = PrintConversion::ZeroExt64; 1496 else if (width < 64) 1497 conversion = PrintConversion::SignExt64; 1498 printer = LLVM::lookupOrCreatePrintI64Fn(parent); 1499 } else { 1500 return failure(); 1501 } 1502 } 1503 } else { 1504 return failure(); 1505 } 1506 1507 // Unroll vector into elementary print calls. 1508 int64_t rank = vectorType ? vectorType.getRank() : 0; 1509 Type type = vectorType ? vectorType : eltType; 1510 emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, 1511 conversion); 1512 emitCall(rewriter, printOp->getLoc(), 1513 LLVM::lookupOrCreatePrintNewlineFn(parent)); 1514 rewriter.eraseOp(printOp); 1515 return success(); 1516 } 1517 1518 private: 1519 enum class PrintConversion { 1520 // clang-format off 1521 None, 1522 ZeroExt64, 1523 SignExt64, 1524 Bitcast16 1525 // clang-format on 1526 }; 1527 1528 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1529 Value value, Type type, Operation *printer, int64_t rank, 1530 PrintConversion conversion) const { 1531 VectorType vectorType = dyn_cast<VectorType>(type); 1532 Location loc = op->getLoc(); 1533 if (!vectorType) { 1534 assert(rank == 0 && "The scalar case expects rank == 0"); 1535 switch (conversion) { 1536 case PrintConversion::ZeroExt64: 1537 value = rewriter.create<arith::ExtUIOp>( 1538 loc, IntegerType::get(rewriter.getContext(), 64), value); 1539 break; 1540 case PrintConversion::SignExt64: 1541 value = rewriter.create<arith::ExtSIOp>( 1542 loc, 
IntegerType::get(rewriter.getContext(), 64), value); 1543 break; 1544 case PrintConversion::Bitcast16: 1545 value = rewriter.create<LLVM::BitcastOp>( 1546 loc, IntegerType::get(rewriter.getContext(), 16), value); 1547 break; 1548 case PrintConversion::None: 1549 break; 1550 } 1551 emitCall(rewriter, loc, printer, value); 1552 return; 1553 } 1554 1555 auto parent = op->getParentOfType<ModuleOp>(); 1556 emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent)); 1557 Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent); 1558 1559 if (rank <= 1) { 1560 auto reducedType = vectorType.getElementType(); 1561 auto llvmType = typeConverter->convertType(reducedType); 1562 int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0); 1563 for (int64_t d = 0; d < dim; ++d) { 1564 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1565 llvmType, /*rank=*/0, /*pos=*/d); 1566 emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0, 1567 conversion); 1568 if (d != dim - 1) 1569 emitCall(rewriter, loc, printComma); 1570 } 1571 emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent)); 1572 return; 1573 } 1574 1575 int64_t dim = vectorType.getDimSize(0); 1576 for (int64_t d = 0; d < dim; ++d) { 1577 auto reducedType = reducedVectorTypeFront(vectorType); 1578 auto llvmType = typeConverter->convertType(reducedType); 1579 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1580 llvmType, rank, d); 1581 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1582 conversion); 1583 if (d != dim - 1) 1584 emitCall(rewriter, loc, printComma); 1585 } 1586 emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent)); 1587 } 1588 1589 // Helper to emit a call. 
1590 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1591 Operation *ref, ValueRange params = ValueRange()) { 1592 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1593 params); 1594 } 1595 }; 1596 1597 /// The Splat operation is lowered to an insertelement + a shufflevector 1598 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 1599 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1600 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1601 1602 LogicalResult 1603 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1604 ConversionPatternRewriter &rewriter) const override { 1605 VectorType resultType = cast<VectorType>(splatOp.getType()); 1606 if (resultType.getRank() > 1) 1607 return failure(); 1608 1609 // First insert it into an undef vector so we can shuffle it. 1610 auto vectorType = typeConverter->convertType(splatOp.getType()); 1611 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1612 auto zero = rewriter.create<LLVM::ConstantOp>( 1613 splatOp.getLoc(), 1614 typeConverter->convertType(rewriter.getIntegerType(32)), 1615 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1616 1617 // For 0-d vector, we simply do `insertelement`. 1618 if (resultType.getRank() == 0) { 1619 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1620 splatOp, vectorType, undef, adaptor.getInput(), zero); 1621 return success(); 1622 } 1623 1624 // For 1-d vector, we additionally do a `vectorshuffle`. 1625 auto v = rewriter.create<LLVM::InsertElementOp>( 1626 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1627 1628 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0); 1629 SmallVector<int32_t> zeroValues(width, 0); 1630 1631 // Shuffle the value across the desired number of elements. 
1632 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1633 zeroValues); 1634 return success(); 1635 } 1636 }; 1637 1638 /// The Splat operation is lowered to an insertelement + a shufflevector 1639 /// operation. Splat to only 2+-d vector result types are lowered by the 1640 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 1641 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { 1642 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; 1643 1644 LogicalResult 1645 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, 1646 ConversionPatternRewriter &rewriter) const override { 1647 VectorType resultType = splatOp.getType(); 1648 if (resultType.getRank() <= 1) 1649 return failure(); 1650 1651 // First insert it into an undef vector so we can shuffle it. 1652 auto loc = splatOp.getLoc(); 1653 auto vectorTypeInfo = 1654 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); 1655 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; 1656 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; 1657 if (!llvmNDVectorTy || !llvm1DVectorTy) 1658 return failure(); 1659 1660 // Construct returned value. 1661 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); 1662 1663 // Construct a 1-D vector with the splatted value that we insert in all the 1664 // places within the returned descriptor. 1665 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); 1666 auto zero = rewriter.create<LLVM::ConstantOp>( 1667 loc, typeConverter->convertType(rewriter.getIntegerType(32)), 1668 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1669 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, 1670 adaptor.getInput(), zero); 1671 1672 // Shuffle the value across the desired number of elements. 
1673 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1674 SmallVector<int32_t> zeroValues(width, 0); 1675 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1676 1677 // Iterate of linear index, convert to coords space and insert splatted 1-D 1678 // vector in each position. 1679 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1680 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1681 }); 1682 rewriter.replaceOp(splatOp, desc); 1683 return success(); 1684 } 1685 }; 1686 1687 } // namespace 1688 1689 /// Populate the given list with patterns that convert from Vector to LLVM. 1690 void mlir::populateVectorToLLVMConversionPatterns( 1691 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1692 bool reassociateFPReductions, bool force32BitVectorIndices) { 1693 MLIRContext *ctx = converter.getDialect()->getContext(); 1694 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1695 populateVectorInsertExtractStridedSliceTransforms(patterns); 1696 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1697 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1698 patterns 1699 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1700 VectorExtractElementOpConversion, VectorExtractOpConversion, 1701 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1702 VectorInsertOpConversion, VectorPrintOpConversion, 1703 VectorTypeCastOpConversion, VectorScaleOpConversion, 1704 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1705 VectorLoadStoreConversion<vector::MaskedLoadOp, 1706 vector::MaskedLoadOpAdaptor>, 1707 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1708 VectorLoadStoreConversion<vector::MaskedStoreOp, 1709 vector::MaskedStoreOpAdaptor>, 1710 VectorGatherOpConversion, VectorScatterOpConversion, 1711 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1712 VectorSplatOpLowering, 
VectorSplatNdOpLowering, 1713 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1714 MaskedReductionOpConversion>(converter); 1715 // Transfer ops with rank > 1 are handled by VectorToSCF. 1716 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1717 } 1718 1719 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1720 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1721 patterns.add<VectorMatmulOpConversion>(converter); 1722 patterns.add<VectorFlatTransposeOpConversion>(converter); 1723 } 1724