1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/Support/Casting.h" 26 #include <optional> 27 28 using namespace mlir; 29 using namespace mlir::vector; 30 31 // Helper to reduce vector type by one rank at front. 32 static VectorType reducedVectorTypeFront(VectorType tp) { 33 assert((tp.getRank() > 1) && "unlowerable vector type"); 34 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 35 tp.getScalableDims().drop_front()); 36 } 37 38 // Helper to reduce vector type by *all* but one rank at back. 39 static VectorType reducedVectorTypeBack(VectorType tp) { 40 assert((tp.getRank() > 1) && "unlowerable vector type"); 41 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 42 tp.getScalableDims().take_back()); 43 } 44 45 // Helper that picks the proper sequence for inserting. 
46 static Value insertOne(ConversionPatternRewriter &rewriter, 47 LLVMTypeConverter &typeConverter, Location loc, 48 Value val1, Value val2, Type llvmType, int64_t rank, 49 int64_t pos) { 50 assert(rank > 0 && "0-D vector corner case should have been handled already"); 51 if (rank == 1) { 52 auto idxType = rewriter.getIndexType(); 53 auto constant = rewriter.create<LLVM::ConstantOp>( 54 loc, typeConverter.convertType(idxType), 55 rewriter.getIntegerAttr(idxType, pos)); 56 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 57 constant); 58 } 59 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 60 } 61 62 // Helper that picks the proper sequence for extracting. 63 static Value extractOne(ConversionPatternRewriter &rewriter, 64 LLVMTypeConverter &typeConverter, Location loc, 65 Value val, Type llvmType, int64_t rank, int64_t pos) { 66 if (rank <= 1) { 67 auto idxType = rewriter.getIndexType(); 68 auto constant = rewriter.create<LLVM::ConstantOp>( 69 loc, typeConverter.convertType(idxType), 70 rewriter.getIntegerAttr(idxType, pos)); 71 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 72 constant); 73 } 74 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 75 } 76 77 // Helper that returns data layout alignment of a memref. 78 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 79 MemRefType memrefType, unsigned &align) { 80 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 81 if (!elementTy) 82 return failure(); 83 84 // TODO: this should use the MLIR data layout when it becomes available and 85 // stop depending on translation. 86 llvm::LLVMContext llvmContext; 87 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 88 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 89 return success(); 90 } 91 92 // Check if the last stride is non-unit or the memory space is not zero. 
93 static LogicalResult isMemRefTypeSupported(MemRefType memRefType, 94 LLVMTypeConverter &converter) { 95 int64_t offset; 96 SmallVector<int64_t, 4> strides; 97 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 98 FailureOr<unsigned> addressSpace = 99 converter.getMemRefAddressSpace(memRefType); 100 if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) || 101 *addressSpace != 0) 102 return failure(); 103 return success(); 104 } 105 106 // Add an index vector component to a base pointer. 107 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 108 LLVMTypeConverter &typeConverter, 109 MemRefType memRefType, Value llvmMemref, Value base, 110 Value index, uint64_t vLen) { 111 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && 112 "unsupported memref type"); 113 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 114 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 115 return rewriter.create<LLVM::GEPOp>( 116 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), 117 base, index); 118 } 119 120 // Casts a strided element pointer to a vector pointer. The vector pointer 121 // will be in the same address space as the incoming memref type. 122 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 123 Value ptr, MemRefType memRefType, Type vt, 124 LLVMTypeConverter &converter) { 125 if (converter.useOpaquePointers()) 126 return ptr; 127 128 unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType); 129 auto pType = LLVM::LLVMPointerType::get(vt, addressSpace); 130 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 131 } 132 133 namespace { 134 135 /// Trivial Vector to LLVM conversions 136 using VectorScaleOpConversion = 137 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 138 139 /// Conversion pattern for a vector.bitcast. 
140 class VectorBitCastOpConversion 141 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 142 public: 143 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 144 145 LogicalResult 146 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 147 ConversionPatternRewriter &rewriter) const override { 148 // Only 0-D and 1-D vectors can be lowered to LLVM. 149 VectorType resultTy = bitCastOp.getResultVectorType(); 150 if (resultTy.getRank() > 1) 151 return failure(); 152 Type newResultTy = typeConverter->convertType(resultTy); 153 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 154 adaptor.getOperands()[0]); 155 return success(); 156 } 157 }; 158 159 /// Conversion pattern for a vector.matrix_multiply. 160 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 161 class VectorMatmulOpConversion 162 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 163 public: 164 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 165 166 LogicalResult 167 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 168 ConversionPatternRewriter &rewriter) const override { 169 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 170 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 171 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 172 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 173 return success(); 174 } 175 }; 176 177 /// Conversion pattern for a vector.flat_transpose. 178 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 
179 class VectorFlatTransposeOpConversion 180 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 181 public: 182 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 183 184 LogicalResult 185 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 186 ConversionPatternRewriter &rewriter) const override { 187 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 188 transOp, typeConverter->convertType(transOp.getRes().getType()), 189 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 190 return success(); 191 } 192 }; 193 194 /// Overloaded utility that replaces a vector.load, vector.store, 195 /// vector.maskedload and vector.maskedstore with their respective LLVM 196 /// couterparts. 197 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 198 vector::LoadOpAdaptor adaptor, 199 VectorType vectorTy, Value ptr, unsigned align, 200 ConversionPatternRewriter &rewriter) { 201 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align); 202 } 203 204 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 205 vector::MaskedLoadOpAdaptor adaptor, 206 VectorType vectorTy, Value ptr, unsigned align, 207 ConversionPatternRewriter &rewriter) { 208 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 209 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 210 } 211 212 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 213 vector::StoreOpAdaptor adaptor, 214 VectorType vectorTy, Value ptr, unsigned align, 215 ConversionPatternRewriter &rewriter) { 216 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 217 ptr, align); 218 } 219 220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 221 vector::MaskedStoreOpAdaptor adaptor, 222 VectorType vectorTy, Value ptr, unsigned align, 223 ConversionPatternRewriter &rewriter) { 224 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 225 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), 
align); 226 } 227 228 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 229 /// vector.maskedstore. 230 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 231 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 232 public: 233 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 234 235 LogicalResult 236 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 237 typename LoadOrStoreOp::Adaptor adaptor, 238 ConversionPatternRewriter &rewriter) const override { 239 // Only 1-D vectors can be lowered to LLVM. 240 VectorType vectorTy = loadOrStoreOp.getVectorType(); 241 if (vectorTy.getRank() > 1) 242 return failure(); 243 244 auto loc = loadOrStoreOp->getLoc(); 245 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 246 247 // Resolve alignment. 248 unsigned align; 249 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 250 return failure(); 251 252 // Resolve address. 253 auto vtype = cast<VectorType>( 254 this->typeConverter->convertType(loadOrStoreOp.getVectorType())); 255 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 256 adaptor.getIndices(), rewriter); 257 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, 258 *this->getTypeConverter()); 259 260 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 261 return success(); 262 } 263 }; 264 265 /// Conversion pattern for a vector.gather. 
266 class VectorGatherOpConversion 267 : public ConvertOpToLLVMPattern<vector::GatherOp> { 268 public: 269 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 270 271 LogicalResult 272 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 273 ConversionPatternRewriter &rewriter) const override { 274 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType()); 275 assert(memRefType && "The base should be bufferized"); 276 277 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 278 return failure(); 279 280 auto loc = gather->getLoc(); 281 282 // Resolve alignment. 283 unsigned align; 284 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 285 return failure(); 286 287 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 288 adaptor.getIndices(), rewriter); 289 Value base = adaptor.getBase(); 290 291 auto llvmNDVectorTy = adaptor.getIndexVec().getType(); 292 // Handle the simple case of 1-D vector. 293 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) { 294 auto vType = gather.getVectorType(); 295 // Resolve address. 296 Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), 297 memRefType, base, ptr, adaptor.getIndexVec(), 298 /*vLen=*/vType.getDimSize(0)); 299 // Replace with the gather intrinsic. 300 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 301 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 302 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 303 return success(); 304 } 305 306 LLVMTypeConverter &typeConverter = *this->getTypeConverter(); 307 auto callback = [align, memRefType, base, ptr, loc, &rewriter, 308 &typeConverter](Type llvm1DVectorTy, 309 ValueRange vectorOperands) { 310 // Resolve address. 311 Value ptrs = getIndexedPtrs( 312 rewriter, loc, typeConverter, memRefType, base, ptr, 313 /*index=*/vectorOperands[0], 314 LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); 315 // Create the gather intrinsic. 
316 return rewriter.create<LLVM::masked_gather>( 317 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1], 318 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align)); 319 }; 320 SmallVector<Value> vectorOperands = { 321 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()}; 322 return LLVM::detail::handleMultidimensionalVectors( 323 gather, vectorOperands, *getTypeConverter(), callback, rewriter); 324 } 325 }; 326 327 /// Conversion pattern for a vector.scatter. 328 class VectorScatterOpConversion 329 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 330 public: 331 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 332 333 LogicalResult 334 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 335 ConversionPatternRewriter &rewriter) const override { 336 auto loc = scatter->getLoc(); 337 MemRefType memRefType = scatter.getMemRefType(); 338 339 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) 340 return failure(); 341 342 // Resolve alignment. 343 unsigned align; 344 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 345 return failure(); 346 347 // Resolve address. 348 VectorType vType = scatter.getVectorType(); 349 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 350 adaptor.getIndices(), rewriter); 351 Value ptrs = getIndexedPtrs( 352 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(), 353 ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); 354 355 // Replace with the scatter intrinsic. 356 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 357 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 358 rewriter.getI32IntegerAttr(align)); 359 return success(); 360 } 361 }; 362 363 /// Conversion pattern for a vector.expandload. 
364 class VectorExpandLoadOpConversion 365 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 366 public: 367 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 368 369 LogicalResult 370 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 371 ConversionPatternRewriter &rewriter) const override { 372 auto loc = expand->getLoc(); 373 MemRefType memRefType = expand.getMemRefType(); 374 375 // Resolve address. 376 auto vtype = typeConverter->convertType(expand.getVectorType()); 377 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 378 adaptor.getIndices(), rewriter); 379 380 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 381 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 382 return success(); 383 } 384 }; 385 386 /// Conversion pattern for a vector.compressstore. 387 class VectorCompressStoreOpConversion 388 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 389 public: 390 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 391 392 LogicalResult 393 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 394 ConversionPatternRewriter &rewriter) const override { 395 auto loc = compress->getLoc(); 396 MemRefType memRefType = compress.getMemRefType(); 397 398 // Resolve address. 399 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 400 adaptor.getIndices(), rewriter); 401 402 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 403 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 404 return success(); 405 } 406 }; 407 408 /// Reduction neutral classes for overloading. 
409 class ReductionNeutralZero {}; 410 class ReductionNeutralIntOne {}; 411 class ReductionNeutralFPOne {}; 412 class ReductionNeutralAllOnes {}; 413 class ReductionNeutralSIntMin {}; 414 class ReductionNeutralUIntMin {}; 415 class ReductionNeutralSIntMax {}; 416 class ReductionNeutralUIntMax {}; 417 class ReductionNeutralFPMin {}; 418 class ReductionNeutralFPMax {}; 419 420 /// Create the reduction neutral zero value. 421 static Value createReductionNeutralValue(ReductionNeutralZero neutral, 422 ConversionPatternRewriter &rewriter, 423 Location loc, Type llvmType) { 424 return rewriter.create<LLVM::ConstantOp>(loc, llvmType, 425 rewriter.getZeroAttr(llvmType)); 426 } 427 428 /// Create the reduction neutral integer one value. 429 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, 430 ConversionPatternRewriter &rewriter, 431 Location loc, Type llvmType) { 432 return rewriter.create<LLVM::ConstantOp>( 433 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); 434 } 435 436 /// Create the reduction neutral fp one value. 437 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, 438 ConversionPatternRewriter &rewriter, 439 Location loc, Type llvmType) { 440 return rewriter.create<LLVM::ConstantOp>( 441 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); 442 } 443 444 /// Create the reduction neutral all-ones value. 445 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, 446 ConversionPatternRewriter &rewriter, 447 Location loc, Type llvmType) { 448 return rewriter.create<LLVM::ConstantOp>( 449 loc, llvmType, 450 rewriter.getIntegerAttr( 451 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); 452 } 453 454 /// Create the reduction neutral signed int minimum value. 
455 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, 456 ConversionPatternRewriter &rewriter, 457 Location loc, Type llvmType) { 458 return rewriter.create<LLVM::ConstantOp>( 459 loc, llvmType, 460 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( 461 llvmType.getIntOrFloatBitWidth()))); 462 } 463 464 /// Create the reduction neutral unsigned int minimum value. 465 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, 466 ConversionPatternRewriter &rewriter, 467 Location loc, Type llvmType) { 468 return rewriter.create<LLVM::ConstantOp>( 469 loc, llvmType, 470 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( 471 llvmType.getIntOrFloatBitWidth()))); 472 } 473 474 /// Create the reduction neutral signed int maximum value. 475 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, 476 ConversionPatternRewriter &rewriter, 477 Location loc, Type llvmType) { 478 return rewriter.create<LLVM::ConstantOp>( 479 loc, llvmType, 480 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( 481 llvmType.getIntOrFloatBitWidth()))); 482 } 483 484 /// Create the reduction neutral unsigned int maximum value. 485 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, 486 ConversionPatternRewriter &rewriter, 487 Location loc, Type llvmType) { 488 return rewriter.create<LLVM::ConstantOp>( 489 loc, llvmType, 490 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( 491 llvmType.getIntOrFloatBitWidth()))); 492 } 493 494 /// Create the reduction neutral fp minimum value. 
495 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, 496 ConversionPatternRewriter &rewriter, 497 Location loc, Type llvmType) { 498 auto floatType = cast<FloatType>(llvmType); 499 return rewriter.create<LLVM::ConstantOp>( 500 loc, llvmType, 501 rewriter.getFloatAttr( 502 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 503 /*Negative=*/false))); 504 } 505 506 /// Create the reduction neutral fp maximum value. 507 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, 508 ConversionPatternRewriter &rewriter, 509 Location loc, Type llvmType) { 510 auto floatType = cast<FloatType>(llvmType); 511 return rewriter.create<LLVM::ConstantOp>( 512 loc, llvmType, 513 rewriter.getFloatAttr( 514 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), 515 /*Negative=*/true))); 516 } 517 518 /// Returns `accumulator` if it has a valid value. Otherwise, creates and 519 /// returns a new accumulator value using `ReductionNeutral`. 520 template <class ReductionNeutral> 521 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, 522 Location loc, Type llvmType, 523 Value accumulator) { 524 if (accumulator) 525 return accumulator; 526 527 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc, 528 llvmType); 529 } 530 531 /// Creates a constant value with the 1-D vector shape provided in `llvmType`. 532 /// This is used as effective vector length by some intrinsics supporting 533 /// dynamic vector lengths at runtime. 
534 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, 535 Location loc, Type llvmType) { 536 VectorType vType = cast<VectorType>(llvmType); 537 auto vShape = vType.getShape(); 538 assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); 539 540 return rewriter.create<LLVM::ConstantOp>( 541 loc, rewriter.getI32Type(), 542 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); 543 } 544 545 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 546 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 547 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 548 /// non-null. 549 template <class LLVMRedIntrinOp, class ScalarOp> 550 static Value createIntegerReductionArithmeticOpLowering( 551 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 552 Value vectorOperand, Value accumulator) { 553 554 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 555 556 if (accumulator) 557 result = rewriter.create<ScalarOp>(loc, accumulator, result); 558 return result; 559 } 560 561 /// Helper method to lower a `vector.reduction` operation that performs 562 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 563 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 564 /// the accumulator value if non-null. 
565 template <class LLVMRedIntrinOp> 566 static Value createIntegerReductionComparisonOpLowering( 567 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 568 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 569 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 570 if (accumulator) { 571 Value cmp = 572 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 573 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 574 } 575 return result; 576 } 577 578 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum 579 /// with vector types. 580 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, 581 Value rhs, bool isMin) { 582 auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType())); 583 Type i1Type = builder.getI1Type(); 584 if (auto vecType = dyn_cast<VectorType>(lhs.getType())) 585 i1Type = VectorType::get(vecType.getShape(), i1Type); 586 Value cmp = builder.create<LLVM::FCmpOp>( 587 loc, i1Type, isMin ? 
LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, 588 lhs, rhs); 589 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); 590 Value isNan = builder.create<LLVM::FCmpOp>( 591 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); 592 Value nan = builder.create<LLVM::ConstantOp>( 593 loc, lhs.getType(), 594 builder.getFloatAttr(floatType, 595 APFloat::getQNaN(floatType.getFloatSemantics()))); 596 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); 597 } 598 599 template <class LLVMRedIntrinOp> 600 static Value createFPReductionComparisonOpLowering( 601 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 602 Value vectorOperand, Value accumulator, bool isMin) { 603 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); 604 605 if (accumulator) 606 result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin); 607 608 return result; 609 } 610 611 /// Overloaded methods to lower a reduction to an llvm instrinsic that requires 612 /// a start value. This start value format spans across fp reductions without 613 /// mask and all the masked reduction intrinsics. 
614 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 615 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 616 Location loc, Type llvmType, 617 Value vectorOperand, 618 Value accumulator) { 619 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 620 llvmType, accumulator); 621 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 622 /*startValue=*/accumulator, 623 vectorOperand); 624 } 625 626 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 627 static Value 628 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, 629 Type llvmType, Value vectorOperand, 630 Value accumulator, bool reassociateFPReds) { 631 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 632 llvmType, accumulator); 633 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 634 /*startValue=*/accumulator, 635 vectorOperand, reassociateFPReds); 636 } 637 638 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 639 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 640 Location loc, Type llvmType, 641 Value vectorOperand, 642 Value accumulator, Value mask) { 643 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 644 llvmType, accumulator); 645 Value vectorLength = 646 createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 647 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 648 /*startValue=*/accumulator, 649 vectorOperand, mask, vectorLength); 650 } 651 652 template <class LLVMVPRedIntrinOp, class ReductionNeutral> 653 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 654 Location loc, Type llvmType, 655 Value vectorOperand, 656 Value accumulator, Value mask, 657 bool reassociateFPReds) { 658 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, 659 llvmType, accumulator); 660 Value vectorLength = 661 createVectorLengthValue(rewriter, loc, vectorOperand.getType()); 662 
return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, 663 /*startValue=*/accumulator, 664 vectorOperand, mask, vectorLength, 665 reassociateFPReds); 666 } 667 668 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral, 669 class LLVMFPVPRedIntrinOp, class FPReductionNeutral> 670 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, 671 Location loc, Type llvmType, 672 Value vectorOperand, 673 Value accumulator, Value mask) { 674 if (llvmType.isIntOrIndex()) 675 return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp, 676 IntReductionNeutral>( 677 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 678 679 // FP dispatch. 680 return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>( 681 rewriter, loc, llvmType, vectorOperand, accumulator, mask); 682 } 683 684 /// Conversion pattern for all vector reductions. 685 class VectorReductionOpConversion 686 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 687 public: 688 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 689 bool reassociateFPRed) 690 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 691 reassociateFPReductions(reassociateFPRed) {} 692 693 LogicalResult 694 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 695 ConversionPatternRewriter &rewriter) const override { 696 auto kind = reductionOp.getKind(); 697 Type eltType = reductionOp.getDest().getType(); 698 Type llvmType = typeConverter->convertType(eltType); 699 Value operand = adaptor.getVector(); 700 Value acc = adaptor.getAcc(); 701 Location loc = reductionOp.getLoc(); 702 703 if (eltType.isIntOrIndex()) { 704 // Integer reductions: add/mul/min/max/and/or/xor. 
705 Value result; 706 switch (kind) { 707 case vector::CombiningKind::ADD: 708 result = 709 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 710 LLVM::AddOp>( 711 rewriter, loc, llvmType, operand, acc); 712 break; 713 case vector::CombiningKind::MUL: 714 result = 715 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 716 LLVM::MulOp>( 717 rewriter, loc, llvmType, operand, acc); 718 break; 719 case vector::CombiningKind::MINUI: 720 result = createIntegerReductionComparisonOpLowering< 721 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 722 LLVM::ICmpPredicate::ule); 723 break; 724 case vector::CombiningKind::MINSI: 725 result = createIntegerReductionComparisonOpLowering< 726 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 727 LLVM::ICmpPredicate::sle); 728 break; 729 case vector::CombiningKind::MAXUI: 730 result = createIntegerReductionComparisonOpLowering< 731 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 732 LLVM::ICmpPredicate::uge); 733 break; 734 case vector::CombiningKind::MAXSI: 735 result = createIntegerReductionComparisonOpLowering< 736 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 737 LLVM::ICmpPredicate::sge); 738 break; 739 case vector::CombiningKind::AND: 740 result = 741 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 742 LLVM::AndOp>( 743 rewriter, loc, llvmType, operand, acc); 744 break; 745 case vector::CombiningKind::OR: 746 result = 747 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 748 LLVM::OrOp>( 749 rewriter, loc, llvmType, operand, acc); 750 break; 751 case vector::CombiningKind::XOR: 752 result = 753 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 754 LLVM::XOrOp>( 755 rewriter, loc, llvmType, operand, acc); 756 break; 757 default: 758 return failure(); 759 } 760 rewriter.replaceOp(reductionOp, result); 761 762 return success(); 763 } 764 765 if (!isa<FloatType>(eltType)) 766 
return failure(); 767 768 // Floating-point reductions: add/mul/min/max 769 Value result; 770 if (kind == vector::CombiningKind::ADD) { 771 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd, 772 ReductionNeutralZero>( 773 rewriter, loc, llvmType, operand, acc, reassociateFPReductions); 774 } else if (kind == vector::CombiningKind::MUL) { 775 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul, 776 ReductionNeutralFPOne>( 777 rewriter, loc, llvmType, operand, acc, reassociateFPReductions); 778 } else if (kind == vector::CombiningKind::MINF) { 779 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 780 // NaNs/-0.0/+0.0 in the same way. 781 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>( 782 rewriter, loc, llvmType, operand, acc, 783 /*isMin=*/true); 784 } else if (kind == vector::CombiningKind::MAXF) { 785 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle 786 // NaNs/-0.0/+0.0 in the same way. 787 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>( 788 rewriter, loc, llvmType, operand, acc, 789 /*isMin=*/false); 790 } else 791 return failure(); 792 793 rewriter.replaceOp(reductionOp, result); 794 return success(); 795 } 796 797 private: 798 const bool reassociateFPReductions; 799 }; 800 801 /// Base class to convert a `vector.mask` operation while matching traits 802 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase` 803 /// instance matches against a `vector.mask` operation. The `matchAndRewrite` 804 /// method performs a second match against the maskable operation `MaskedOp`. 805 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be 806 /// implemented by the concrete conversion classes. This method can match 807 /// against specific traits of the `vector.mask` and the maskable operation. It 808 /// must replace the `vector.mask` operation. 
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Match against the maskable operation kind.
    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return failure();
    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
  }

protected:
  /// Hook implemented by concrete conversions. Implementations must replace
  /// the `vector.mask` operation itself (not just the nested maskable op).
  virtual LogicalResult
  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
                            vector::MaskableOpInterface maskableOp,
                            ConversionPatternRewriter &rewriter) const = 0;
};

/// Convert a `vector.reduction` nested inside a `vector.mask` to the
/// corresponding LLVM vector-predication (VP) reduction intrinsic, passing
/// the mask through to the intrinsic.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Each case picks the VP reduction intrinsic plus the neutral element used
    // when the accumulator is absent.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
                                            ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
                                            ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                            ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                            ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
                                            ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

/// Convert `vector.shuffle`: rank <= 1 with identical operand types maps
/// directly onto LLVM shufflevector; everything else is unrolled into an
/// extract/insert sequence.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
      eltType = arrayType.getElementType();
    else
      eltType = cast<VectorType>(llvmType).getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      // Mask indices >= v1Dim select from the second operand.
      int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Convert `vector.extractelement` to LLVM extractelement. A 0-D source is
/// handled as a single-element vector indexed by constant 0.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Convert `vector.extract` to LLVM extractvalue (for sub-vector results from
/// the array-of-vectors representation) and/or extractelement (for scalar
/// results out of a 1-D vector).
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (isa<VectorType>(resultType)) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = cast<IntegerAttr>(positionAttrs.back());
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
/// vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
/// llvm.intr.fmuladd %va, %va, %va:
///   (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///   -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Only 0-D/1-D vectors map directly; n-D is handled by the rank-reducing
    // rewrite pattern below.
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Convert `vector.insertelement` to LLVM insertelement. A 0-D destination is
/// handled as a single-element vector indexed by constant 0.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Convert `vector.insert` to LLVM insertvalue (when inserting a sub-vector
/// into the array-of-vectors representation) and/or an
/// extractvalue/insertelement/insertvalue sequence (when inserting a scalar
/// into a nested 1-D vector).
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (isa<VectorType>(sourceType)) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = cast<IntegerAttr>(positionAttrs.back());
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert.
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract.
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
/// %r = splat %f0: vector<2x4xf32>
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
/// // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // Rank 0/1 is handled directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    // Peel off the leading dimension: emit one (n-1)-D FMA per slice and
    // re-assemble the result.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return std::nullopt;
  // The innermost stride must be unit for the buffer to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return std::nullopt;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamic(strides[index]) ||
        ShapedType::isDynamic(strides[index + 1]))
      return std::nullopt;
    // Each outer stride must equal the product of the next inner stride and
    // size for the layout to be contiguous.
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return std::nullopt;
  }
  return strides;
}

/// Convert `vector.type_cast` between statically-shaped, contiguous memrefs by
/// rebuilding the target memref descriptor (pointers, offset, sizes, strides)
/// from the source descriptor.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    // Only 1-D scalable masks; fixed-size masks are lowered elsewhere.
    if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    // Build the mask as: step_vector [0, 1, 2, ...] < splat(bound).
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  // Use i32 (instead of i64) lane indices when requested by the caller.
  const bool force32BitVectorIndices;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = dyn_cast<VectorType>(printType);
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    auto parent = printOp->getParentOfType<ModuleOp>();
    Operation *printer;
    if (eltType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (eltType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (eltType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (eltType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (eltType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(parent));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // How a scalar must be adapted before being passed to the runtime printer.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  // Recursively emit elementary print calls for a (sub-)vector or scalar
  // `value` of the given `type`/`rank`, peeling one dimension per level.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = dyn_cast<VectorType>(type);
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::Bitcast16:
        value = rewriter.create<LLVM::BitcastOp>(
            loc, IntegerType::get(rewriter.getContext(), 16), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    auto parent = op->getParentOfType<ModuleOp>();
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
    Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);

    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      // A 0-D vector prints like a 1-element vector.
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
      return;
    }

    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    // n-D splats are handled by VectorSplatNdOpLowering.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `vectorshuffle`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // Rank 0/1 is handled by VectorSplatOpLowering.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
1675 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1676 SmallVector<int32_t> zeroValues(width, 0); 1677 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1678 1679 // Iterate of linear index, convert to coords space and insert splatted 1-D 1680 // vector in each position. 1681 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1682 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1683 }); 1684 rewriter.replaceOp(splatOp, desc); 1685 return success(); 1686 } 1687 }; 1688 1689 } // namespace 1690 1691 /// Populate the given list with patterns that convert from Vector to LLVM. 1692 void mlir::populateVectorToLLVMConversionPatterns( 1693 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1694 bool reassociateFPReductions, bool force32BitVectorIndices) { 1695 MLIRContext *ctx = converter.getDialect()->getContext(); 1696 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1697 populateVectorInsertExtractStridedSliceTransforms(patterns); 1698 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1699 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1700 patterns 1701 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1702 VectorExtractElementOpConversion, VectorExtractOpConversion, 1703 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1704 VectorInsertOpConversion, VectorPrintOpConversion, 1705 VectorTypeCastOpConversion, VectorScaleOpConversion, 1706 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1707 VectorLoadStoreConversion<vector::MaskedLoadOp, 1708 vector::MaskedLoadOpAdaptor>, 1709 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1710 VectorLoadStoreConversion<vector::MaskedStoreOp, 1711 vector::MaskedStoreOpAdaptor>, 1712 VectorGatherOpConversion, VectorScatterOpConversion, 1713 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1714 VectorSplatOpLowering, 
VectorSplatNdOpLowering, 1715 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, 1716 MaskedReductionOpConversion>(converter); 1717 // Transfer ops with rank > 1 are handled by VectorToSCF. 1718 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1719 } 1720 1721 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1722 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1723 patterns.add<VectorMatmulOpConversion>(converter); 1724 patterns.add<VectorFlatTransposeOpConversion>(converter); 1725 } 1726