1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Arith/Utils/Utils.h" 14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include <optional> 23 24 using namespace mlir; 25 using namespace mlir::vector; 26 27 // Helper to reduce vector type by one rank at front. 28 static VectorType reducedVectorTypeFront(VectorType tp) { 29 assert((tp.getRank() > 1) && "unlowerable vector type"); 30 unsigned numScalableDims = tp.getNumScalableDims(); 31 if (tp.getShape().size() == numScalableDims) 32 --numScalableDims; 33 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 34 numScalableDims); 35 } 36 37 // Helper to reduce vector type by *all* but one rank at back. 38 static VectorType reducedVectorTypeBack(VectorType tp) { 39 assert((tp.getRank() > 1) && "unlowerable vector type"); 40 unsigned numScalableDims = tp.getNumScalableDims(); 41 if (numScalableDims > 0) 42 --numScalableDims; 43 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 44 numScalableDims); 45 } 46 47 // Helper that picks the proper sequence for inserting. 
48 static Value insertOne(ConversionPatternRewriter &rewriter, 49 LLVMTypeConverter &typeConverter, Location loc, 50 Value val1, Value val2, Type llvmType, int64_t rank, 51 int64_t pos) { 52 assert(rank > 0 && "0-D vector corner case should have been handled already"); 53 if (rank == 1) { 54 auto idxType = rewriter.getIndexType(); 55 auto constant = rewriter.create<LLVM::ConstantOp>( 56 loc, typeConverter.convertType(idxType), 57 rewriter.getIntegerAttr(idxType, pos)); 58 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 59 constant); 60 } 61 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 62 } 63 64 // Helper that picks the proper sequence for extracting. 65 static Value extractOne(ConversionPatternRewriter &rewriter, 66 LLVMTypeConverter &typeConverter, Location loc, 67 Value val, Type llvmType, int64_t rank, int64_t pos) { 68 if (rank <= 1) { 69 auto idxType = rewriter.getIndexType(); 70 auto constant = rewriter.create<LLVM::ConstantOp>( 71 loc, typeConverter.convertType(idxType), 72 rewriter.getIntegerAttr(idxType, pos)); 73 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 74 constant); 75 } 76 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 77 } 78 79 // Helper that returns data layout alignment of a memref. 80 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 81 MemRefType memrefType, unsigned &align) { 82 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 83 if (!elementTy) 84 return failure(); 85 86 // TODO: this should use the MLIR data layout when it becomes available and 87 // stop depending on translation. 88 llvm::LLVMContext llvmContext; 89 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 90 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 91 return success(); 92 } 93 94 // Check if the last stride is non-unit or the memory space is not zero. 
95 static LogicalResult isMemRefTypeSupported(MemRefType memRefType) { 96 int64_t offset; 97 SmallVector<int64_t, 4> strides; 98 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 99 if (failed(successStrides) || strides.back() != 1 || 100 memRefType.getMemorySpaceAsInt() != 0) 101 return failure(); 102 return success(); 103 } 104 105 // Add an index vector component to a base pointer. 106 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 107 MemRefType memRefType, Value llvmMemref, Value base, 108 Value index, uint64_t vLen) { 109 assert(succeeded(isMemRefTypeSupported(memRefType)) && 110 "unsupported memref type"); 111 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 112 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 113 return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index); 114 } 115 116 // Casts a strided element pointer to a vector pointer. The vector pointer 117 // will be in the same address space as the incoming memref type. 118 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 119 Value ptr, MemRefType memRefType, Type vt) { 120 auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); 121 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 122 } 123 124 namespace { 125 126 /// Trivial Vector to LLVM conversions 127 using VectorScaleOpConversion = 128 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 129 130 /// Conversion pattern for a vector.bitcast. 131 class VectorBitCastOpConversion 132 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 133 public: 134 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 135 136 LogicalResult 137 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 138 ConversionPatternRewriter &rewriter) const override { 139 // Only 0-D and 1-D vectors can be lowered to LLVM. 
140 VectorType resultTy = bitCastOp.getResultVectorType(); 141 if (resultTy.getRank() > 1) 142 return failure(); 143 Type newResultTy = typeConverter->convertType(resultTy); 144 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 145 adaptor.getOperands()[0]); 146 return success(); 147 } 148 }; 149 150 /// Conversion pattern for a vector.matrix_multiply. 151 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 152 class VectorMatmulOpConversion 153 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 154 public: 155 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 156 157 LogicalResult 158 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 159 ConversionPatternRewriter &rewriter) const override { 160 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 161 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 162 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 163 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 164 return success(); 165 } 166 }; 167 168 /// Conversion pattern for a vector.flat_transpose. 169 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 170 class VectorFlatTransposeOpConversion 171 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 172 public: 173 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 174 175 LogicalResult 176 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 177 ConversionPatternRewriter &rewriter) const override { 178 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 179 transOp, typeConverter->convertType(transOp.getRes().getType()), 180 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 181 return success(); 182 } 183 }; 184 185 /// Overloaded utility that replaces a vector.load, vector.store, 186 /// vector.maskedload and vector.maskedstore with their respective LLVM 187 /// couterparts. 
188 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 189 vector::LoadOpAdaptor adaptor, 190 VectorType vectorTy, Value ptr, unsigned align, 191 ConversionPatternRewriter &rewriter) { 192 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align); 193 } 194 195 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 196 vector::MaskedLoadOpAdaptor adaptor, 197 VectorType vectorTy, Value ptr, unsigned align, 198 ConversionPatternRewriter &rewriter) { 199 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 200 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 201 } 202 203 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 204 vector::StoreOpAdaptor adaptor, 205 VectorType vectorTy, Value ptr, unsigned align, 206 ConversionPatternRewriter &rewriter) { 207 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 208 ptr, align); 209 } 210 211 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 212 vector::MaskedStoreOpAdaptor adaptor, 213 VectorType vectorTy, Value ptr, unsigned align, 214 ConversionPatternRewriter &rewriter) { 215 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 216 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); 217 } 218 219 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 220 /// vector.maskedstore. 221 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 222 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 223 public: 224 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 225 226 LogicalResult 227 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 228 typename LoadOrStoreOp::Adaptor adaptor, 229 ConversionPatternRewriter &rewriter) const override { 230 // Only 1-D vectors can be lowered to LLVM. 
231 VectorType vectorTy = loadOrStoreOp.getVectorType(); 232 if (vectorTy.getRank() > 1) 233 return failure(); 234 235 auto loc = loadOrStoreOp->getLoc(); 236 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 237 238 // Resolve alignment. 239 unsigned align; 240 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 241 return failure(); 242 243 // Resolve address. 244 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) 245 .template cast<VectorType>(); 246 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 247 adaptor.getIndices(), rewriter); 248 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); 249 250 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 251 return success(); 252 } 253 }; 254 255 /// Conversion pattern for a vector.gather. 256 class VectorGatherOpConversion 257 : public ConvertOpToLLVMPattern<vector::GatherOp> { 258 public: 259 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 260 261 LogicalResult 262 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 263 ConversionPatternRewriter &rewriter) const override { 264 MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>(); 265 assert(memRefType && "The base should be bufferized"); 266 267 if (failed(isMemRefTypeSupported(memRefType))) 268 return failure(); 269 270 auto loc = gather->getLoc(); 271 272 // Resolve alignment. 273 unsigned align; 274 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 275 return failure(); 276 277 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 278 adaptor.getIndices(), rewriter); 279 Value base = adaptor.getBase(); 280 281 auto llvmNDVectorTy = adaptor.getIndexVec().getType(); 282 // Handle the simple case of 1-D vector. 283 if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) { 284 auto vType = gather.getVectorType(); 285 // Resolve address. 
286 Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr, 287 adaptor.getIndexVec(), 288 /*vLen=*/vType.getDimSize(0)); 289 // Replace with the gather intrinsic. 290 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 291 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 292 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 293 return success(); 294 } 295 296 auto callback = [align, memRefType, base, ptr, loc, &rewriter]( 297 Type llvm1DVectorTy, ValueRange vectorOperands) { 298 // Resolve address. 299 Value ptrs = getIndexedPtrs( 300 rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0], 301 LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); 302 // Create the gather intrinsic. 303 return rewriter.create<LLVM::masked_gather>( 304 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1], 305 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align)); 306 }; 307 SmallVector<Value> vectorOperands = { 308 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()}; 309 return LLVM::detail::handleMultidimensionalVectors( 310 gather, vectorOperands, *getTypeConverter(), callback, rewriter); 311 } 312 }; 313 314 /// Conversion pattern for a vector.scatter. 315 class VectorScatterOpConversion 316 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 317 public: 318 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 319 320 LogicalResult 321 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 322 ConversionPatternRewriter &rewriter) const override { 323 auto loc = scatter->getLoc(); 324 MemRefType memRefType = scatter.getMemRefType(); 325 326 if (failed(isMemRefTypeSupported(memRefType))) 327 return failure(); 328 329 // Resolve alignment. 330 unsigned align; 331 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 332 return failure(); 333 334 // Resolve address. 
335 VectorType vType = scatter.getVectorType(); 336 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 337 adaptor.getIndices(), rewriter); 338 Value ptrs = 339 getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr, 340 adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); 341 342 // Replace with the scatter intrinsic. 343 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 344 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 345 rewriter.getI32IntegerAttr(align)); 346 return success(); 347 } 348 }; 349 350 /// Conversion pattern for a vector.expandload. 351 class VectorExpandLoadOpConversion 352 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 353 public: 354 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 355 356 LogicalResult 357 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 358 ConversionPatternRewriter &rewriter) const override { 359 auto loc = expand->getLoc(); 360 MemRefType memRefType = expand.getMemRefType(); 361 362 // Resolve address. 363 auto vtype = typeConverter->convertType(expand.getVectorType()); 364 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 365 adaptor.getIndices(), rewriter); 366 367 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 368 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 369 return success(); 370 } 371 }; 372 373 /// Conversion pattern for a vector.compressstore. 374 class VectorCompressStoreOpConversion 375 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 376 public: 377 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 378 379 LogicalResult 380 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 381 ConversionPatternRewriter &rewriter) const override { 382 auto loc = compress->getLoc(); 383 MemRefType memRefType = compress.getMemRefType(); 384 385 // Resolve address. 
386 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 387 adaptor.getIndices(), rewriter); 388 389 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 390 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 391 return success(); 392 } 393 }; 394 395 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 396 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 397 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 398 /// non-null. 399 template <class VectorOp, class ScalarOp> 400 static Value createIntegerReductionArithmeticOpLowering( 401 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 402 Value vectorOperand, Value accumulator) { 403 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 404 if (accumulator) 405 result = rewriter.create<ScalarOp>(loc, accumulator, result); 406 return result; 407 } 408 409 /// Helper method to lower a `vector.reduction` operation that performs 410 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 411 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 412 /// the accumulator value if non-null. 413 template <class VectorOp> 414 static Value createIntegerReductionComparisonOpLowering( 415 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 416 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 417 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 418 if (accumulator) { 419 Value cmp = 420 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 421 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 422 } 423 return result; 424 } 425 426 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum 427 /// with vector types. 
428 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, 429 Value rhs, bool isMin) { 430 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); 431 Type i1Type = builder.getI1Type(); 432 if (auto vecType = lhs.getType().dyn_cast<VectorType>()) 433 i1Type = VectorType::get(vecType.getShape(), i1Type); 434 Value cmp = builder.create<LLVM::FCmpOp>( 435 loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, 436 lhs, rhs); 437 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); 438 Value isNan = builder.create<LLVM::FCmpOp>( 439 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); 440 Value nan = builder.create<LLVM::ConstantOp>( 441 loc, lhs.getType(), 442 builder.getFloatAttr(floatType, 443 APFloat::getQNaN(floatType.getFloatSemantics()))); 444 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); 445 } 446 447 /// Conversion pattern for all vector reductions. 448 class VectorReductionOpConversion 449 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 450 public: 451 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 452 bool reassociateFPRed) 453 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 454 reassociateFPReductions(reassociateFPRed) {} 455 456 LogicalResult 457 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 458 ConversionPatternRewriter &rewriter) const override { 459 auto kind = reductionOp.getKind(); 460 Type eltType = reductionOp.getDest().getType(); 461 Type llvmType = typeConverter->convertType(eltType); 462 Value operand = adaptor.getVector(); 463 Value acc = adaptor.getAcc(); 464 Location loc = reductionOp.getLoc(); 465 if (eltType.isIntOrIndex()) { 466 // Integer reductions: add/mul/min/max/and/or/xor. 
467 Value result; 468 switch (kind) { 469 case vector::CombiningKind::ADD: 470 result = 471 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 472 LLVM::AddOp>( 473 rewriter, loc, llvmType, operand, acc); 474 break; 475 case vector::CombiningKind::MUL: 476 result = 477 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 478 LLVM::MulOp>( 479 rewriter, loc, llvmType, operand, acc); 480 break; 481 case vector::CombiningKind::MINUI: 482 result = createIntegerReductionComparisonOpLowering< 483 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 484 LLVM::ICmpPredicate::ule); 485 break; 486 case vector::CombiningKind::MINSI: 487 result = createIntegerReductionComparisonOpLowering< 488 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 489 LLVM::ICmpPredicate::sle); 490 break; 491 case vector::CombiningKind::MAXUI: 492 result = createIntegerReductionComparisonOpLowering< 493 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 494 LLVM::ICmpPredicate::uge); 495 break; 496 case vector::CombiningKind::MAXSI: 497 result = createIntegerReductionComparisonOpLowering< 498 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 499 LLVM::ICmpPredicate::sge); 500 break; 501 case vector::CombiningKind::AND: 502 result = 503 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 504 LLVM::AndOp>( 505 rewriter, loc, llvmType, operand, acc); 506 break; 507 case vector::CombiningKind::OR: 508 result = 509 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 510 LLVM::OrOp>( 511 rewriter, loc, llvmType, operand, acc); 512 break; 513 case vector::CombiningKind::XOR: 514 result = 515 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 516 LLVM::XOrOp>( 517 rewriter, loc, llvmType, operand, acc); 518 break; 519 default: 520 return failure(); 521 } 522 rewriter.replaceOp(reductionOp, result); 523 524 return success(); 525 } 526 527 if (!eltType.isa<FloatType>()) 528 
return failure(); 529 530 // Floating-point reductions: add/mul/min/max 531 if (kind == vector::CombiningKind::ADD) { 532 // Optional accumulator (or zero). 533 Value acc = adaptor.getOperands().size() > 1 534 ? adaptor.getOperands()[1] 535 : rewriter.create<LLVM::ConstantOp>( 536 reductionOp->getLoc(), llvmType, 537 rewriter.getZeroAttr(eltType)); 538 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 539 reductionOp, llvmType, acc, operand, 540 rewriter.getBoolAttr(reassociateFPReductions)); 541 } else if (kind == vector::CombiningKind::MUL) { 542 // Optional accumulator (or one). 543 Value acc = adaptor.getOperands().size() > 1 544 ? adaptor.getOperands()[1] 545 : rewriter.create<LLVM::ConstantOp>( 546 reductionOp->getLoc(), llvmType, 547 rewriter.getFloatAttr(eltType, 1.0)); 548 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 549 reductionOp, llvmType, acc, operand, 550 rewriter.getBoolAttr(reassociateFPReductions)); 551 } else if (kind == vector::CombiningKind::MINF) { 552 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 553 // NaNs/-0.0/+0.0 in the same way. 554 Value result = 555 rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand); 556 if (acc) 557 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true); 558 rewriter.replaceOp(reductionOp, result); 559 } else if (kind == vector::CombiningKind::MAXF) { 560 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle 561 // NaNs/-0.0/+0.0 in the same way. 
562 Value result = 563 rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand); 564 if (acc) 565 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false); 566 rewriter.replaceOp(reductionOp, result); 567 } else 568 return failure(); 569 570 return success(); 571 } 572 573 private: 574 const bool reassociateFPReductions; 575 }; 576 577 class VectorShuffleOpConversion 578 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 579 public: 580 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 581 582 LogicalResult 583 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 584 ConversionPatternRewriter &rewriter) const override { 585 auto loc = shuffleOp->getLoc(); 586 auto v1Type = shuffleOp.getV1VectorType(); 587 auto v2Type = shuffleOp.getV2VectorType(); 588 auto vectorType = shuffleOp.getVectorType(); 589 Type llvmType = typeConverter->convertType(vectorType); 590 auto maskArrayAttr = shuffleOp.getMask(); 591 592 // Bail if result type cannot be lowered. 593 if (!llvmType) 594 return failure(); 595 596 // Get rank and dimension sizes. 597 int64_t rank = vectorType.getRank(); 598 #ifndef NDEBUG 599 bool wellFormed0DCase = 600 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1; 601 bool wellFormedNDCase = 602 v1Type.getRank() == rank && v2Type.getRank() == rank; 603 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed"); 604 #endif 605 606 // For rank 0 and 1, where both operands have *exactly* the same vector 607 // type, there is direct shuffle support in LLVM. Use it! 608 if (rank <= 1 && v1Type == v2Type) { 609 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 610 loc, adaptor.getV1(), adaptor.getV2(), 611 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 612 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 613 return success(); 614 } 615 616 // For all other cases, insert the individual values individually. 
617 int64_t v1Dim = v1Type.getDimSize(0); 618 Type eltType; 619 if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>()) 620 eltType = arrayType.getElementType(); 621 else 622 eltType = llvmType.cast<VectorType>().getElementType(); 623 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 624 int64_t insPos = 0; 625 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 626 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 627 Value value = adaptor.getV1(); 628 if (extPos >= v1Dim) { 629 extPos -= v1Dim; 630 value = adaptor.getV2(); 631 } 632 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 633 eltType, rank, extPos); 634 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 635 llvmType, rank, insPos++); 636 } 637 rewriter.replaceOp(shuffleOp, insert); 638 return success(); 639 } 640 }; 641 642 class VectorExtractElementOpConversion 643 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 644 public: 645 using ConvertOpToLLVMPattern< 646 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 647 648 LogicalResult 649 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 650 ConversionPatternRewriter &rewriter) const override { 651 auto vectorType = extractEltOp.getVectorType(); 652 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 653 654 // Bail if result type cannot be lowered. 
655 if (!llvmType) 656 return failure(); 657 658 if (vectorType.getRank() == 0) { 659 Location loc = extractEltOp.getLoc(); 660 auto idxType = rewriter.getIndexType(); 661 auto zero = rewriter.create<LLVM::ConstantOp>( 662 loc, typeConverter->convertType(idxType), 663 rewriter.getIntegerAttr(idxType, 0)); 664 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 665 extractEltOp, llvmType, adaptor.getVector(), zero); 666 return success(); 667 } 668 669 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 670 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 671 return success(); 672 } 673 }; 674 675 class VectorExtractOpConversion 676 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 677 public: 678 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 679 680 LogicalResult 681 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 682 ConversionPatternRewriter &rewriter) const override { 683 auto loc = extractOp->getLoc(); 684 auto resultType = extractOp.getResult().getType(); 685 auto llvmResultType = typeConverter->convertType(resultType); 686 auto positionArrayAttr = extractOp.getPosition(); 687 688 // Bail if result type cannot be lowered. 689 if (!llvmResultType) 690 return failure(); 691 692 // Extract entire vector. Should be handled by folder, but just to be safe. 693 if (positionArrayAttr.empty()) { 694 rewriter.replaceOp(extractOp, adaptor.getVector()); 695 return success(); 696 } 697 698 // One-shot extraction of vector from array (only requires extractvalue). 699 if (resultType.isa<VectorType>()) { 700 SmallVector<int64_t> indices; 701 for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>()) 702 indices.push_back(idx.getInt()); 703 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 704 loc, adaptor.getVector(), indices); 705 rewriter.replaceOp(extractOp, extracted); 706 return success(); 707 } 708 709 // Potential extraction of 1-D vector from array. 
710 Value extracted = adaptor.getVector(); 711 auto positionAttrs = positionArrayAttr.getValue(); 712 if (positionAttrs.size() > 1) { 713 SmallVector<int64_t> nMinusOnePosition; 714 for (auto idx : positionAttrs.drop_back()) 715 nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt()); 716 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted, 717 nMinusOnePosition); 718 } 719 720 // Remaining extraction of element from 1-D LLVM vector 721 auto position = positionAttrs.back().cast<IntegerAttr>(); 722 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 723 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 724 extracted = 725 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 726 rewriter.replaceOp(extractOp, extracted); 727 728 return success(); 729 } 730 }; 731 732 /// Conversion pattern that turns a vector.fma on a 1-D vector 733 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 734 /// This does not match vectors of n >= 2 rank. 
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to (schematically; %va is the converted form of %a):
/// ```
///  llvm.intr.fmuladd %va, %va, %va :
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // n-D FMAs (n > 1) are left to the rank-reducing rewrite pattern.
    if (vType.getRank() > 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Conversion pattern for vector.insertelement. A 0-D destination is written
/// at a constant index 0; otherwise the dynamic position operand is used.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Conversion pattern for vector.insert with a static position.
/// - A vector source is a single llvm.insertvalue.
/// - A scalar source takes three steps:
///   1. extract the innermost 1-D vector with llvm.extractvalue (if the
///      position has more than one index);
///   2. llvm.insertelement at the last index (an i64 constant);
///   3. llvm.insertvalue to put the 1-D vector back.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      // The element is inserted into the innermost 1-D vector, whose type is
      // the trailing dimension of the destination.
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
919 setHasBoundedRewriteRecursion(); 920 } 921 922 LogicalResult matchAndRewrite(FMAOp op, 923 PatternRewriter &rewriter) const override { 924 auto vType = op.getVectorType(); 925 if (vType.getRank() < 2) 926 return failure(); 927 928 auto loc = op.getLoc(); 929 auto elemType = vType.getElementType(); 930 Value zero = rewriter.create<arith::ConstantOp>( 931 loc, elemType, rewriter.getZeroAttr(elemType)); 932 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 933 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 934 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 935 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 936 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 937 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 938 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 939 } 940 rewriter.replaceOp(op, desc); 941 return success(); 942 } 943 }; 944 945 /// Returns the strides if the memory underlying `memRefType` has a contiguous 946 /// static layout. 947 static std::optional<SmallVector<int64_t, 4>> 948 computeContiguousStrides(MemRefType memRefType) { 949 int64_t offset; 950 SmallVector<int64_t, 4> strides; 951 if (failed(getStridesAndOffset(memRefType, strides, offset))) 952 return std::nullopt; 953 if (!strides.empty() && strides.back() != 1) 954 return std::nullopt; 955 // If no layout or identity layout, this is contiguous by definition. 956 if (memRefType.getLayout().isIdentity()) 957 return strides; 958 959 // Otherwise, we must determine contiguity form shapes. This can only ever 960 // work in static cases because MemRefType is underspecified to represent 961 // contiguous dynamic shapes in other ways than with just empty/identity 962 // layout. 
963 auto sizes = memRefType.getShape(); 964 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 965 if (ShapedType::isDynamic(sizes[index + 1]) || 966 ShapedType::isDynamic(strides[index]) || 967 ShapedType::isDynamic(strides[index + 1])) 968 return std::nullopt; 969 if (strides[index] != strides[index + 1] * sizes[index + 1]) 970 return std::nullopt; 971 } 972 return strides; 973 } 974 975 class VectorTypeCastOpConversion 976 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 977 public: 978 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 979 980 LogicalResult 981 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 982 ConversionPatternRewriter &rewriter) const override { 983 auto loc = castOp->getLoc(); 984 MemRefType sourceMemRefType = 985 castOp.getOperand().getType().cast<MemRefType>(); 986 MemRefType targetMemRefType = castOp.getType(); 987 988 // Only static shape casts supported atm. 989 if (!sourceMemRefType.hasStaticShape() || 990 !targetMemRefType.hasStaticShape()) 991 return failure(); 992 993 auto llvmSourceDescriptorTy = 994 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 995 if (!llvmSourceDescriptorTy) 996 return failure(); 997 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 998 999 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 1000 .dyn_cast_or_null<LLVM::LLVMStructType>(); 1001 if (!llvmTargetDescriptorTy) 1002 return failure(); 1003 1004 // Only contiguous source buffers supported atm. 1005 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1006 if (!sourceStrides) 1007 return failure(); 1008 auto targetStrides = computeContiguousStrides(targetMemRefType); 1009 if (!targetStrides) 1010 return failure(); 1011 // Only support static strides for now, regardless of contiguity. 
1012 if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) 1013 return failure(); 1014 1015 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1016 1017 // Create descriptor. 1018 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1019 Type llvmTargetElementTy = desc.getElementPtrType(); 1020 // Set allocated ptr. 1021 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1022 allocated = 1023 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 1024 desc.setAllocatedPtr(rewriter, loc, allocated); 1025 // Set aligned ptr. 1026 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1027 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 1028 desc.setAlignedPtr(rewriter, loc, ptr); 1029 // Fill offset 0. 1030 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1031 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1032 desc.setOffset(rewriter, loc, zero); 1033 1034 // Fill size and stride descriptors in memref. 1035 for (const auto &indexedSize : 1036 llvm::enumerate(targetMemRefType.getShape())) { 1037 int64_t index = indexedSize.index(); 1038 auto sizeAttr = 1039 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1040 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1041 desc.setSize(rewriter, loc, index, size); 1042 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1043 (*targetStrides)[index]); 1044 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1045 desc.setStride(rewriter, loc, index, stride); 1046 } 1047 1048 rewriter.replaceOp(castOp, {desc}); 1049 return success(); 1050 } 1051 }; 1052 1053 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 1054 /// Non-scalable versions of this operation are handled in Vector Transforms. 
1055 class VectorCreateMaskOpRewritePattern 1056 : public OpRewritePattern<vector::CreateMaskOp> { 1057 public: 1058 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1059 bool enableIndexOpt) 1060 : OpRewritePattern<vector::CreateMaskOp>(context), 1061 force32BitVectorIndices(enableIndexOpt) {} 1062 1063 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1064 PatternRewriter &rewriter) const override { 1065 auto dstType = op.getType(); 1066 if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable()) 1067 return failure(); 1068 IntegerType idxType = 1069 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1070 auto loc = op->getLoc(); 1071 Value indices = rewriter.create<LLVM::StepVectorOp>( 1072 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1073 /*isScalable=*/true)); 1074 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1075 op.getOperand(0)); 1076 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1077 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1078 indices, bounds); 1079 rewriter.replaceOp(op, comp); 1080 return success(); 1081 } 1082 1083 private: 1084 const bool force32BitVectorIndices; 1085 }; 1086 1087 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1088 public: 1089 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1090 1091 // Proof-of-concept lowering implementation that relies on a small 1092 // runtime support library, which only needs to provide a few 1093 // printing methods (single value for all data types, opening/closing 1094 // bracket, comma, newline). The lowering fully unrolls a vector 1095 // in terms of these elementary printing operations. The advantage 1096 // of this approach is that the library can remain unaware of all 1097 // low-level implementation details of vectors while still supporting 1098 // output of any shaped and dimensioned vector. 
Due to full unrolling, 1099 // this approach is less suited for very large vectors though. 1100 // 1101 // TODO: rely solely on libc in future? something else? 1102 // 1103 LogicalResult 1104 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1105 ConversionPatternRewriter &rewriter) const override { 1106 Type printType = printOp.getPrintType(); 1107 1108 if (typeConverter->convertType(printType) == nullptr) 1109 return failure(); 1110 1111 // Make sure element type has runtime support. 1112 PrintConversion conversion = PrintConversion::None; 1113 VectorType vectorType = printType.dyn_cast<VectorType>(); 1114 Type eltType = vectorType ? vectorType.getElementType() : printType; 1115 Operation *printer; 1116 if (eltType.isF32()) { 1117 printer = 1118 LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>()); 1119 } else if (eltType.isF64()) { 1120 printer = 1121 LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>()); 1122 } else if (eltType.isIndex()) { 1123 printer = 1124 LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>()); 1125 } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1126 // Integers need a zero or sign extension on the operand 1127 // (depending on the source type) as well as a signed or 1128 // unsigned print method. Up to 64-bit is supported. 1129 unsigned width = intTy.getWidth(); 1130 if (intTy.isUnsigned()) { 1131 if (width <= 64) { 1132 if (width < 64) 1133 conversion = PrintConversion::ZeroExt64; 1134 printer = LLVM::lookupOrCreatePrintU64Fn( 1135 printOp->getParentOfType<ModuleOp>()); 1136 } else { 1137 return failure(); 1138 } 1139 } else { 1140 assert(intTy.isSignless() || intTy.isSigned()); 1141 if (width <= 64) { 1142 // Note that we *always* zero extend booleans (1-bit integers), 1143 // so that true/false is printed as 1/0 rather than -1/0. 
1144 if (width == 1) 1145 conversion = PrintConversion::ZeroExt64; 1146 else if (width < 64) 1147 conversion = PrintConversion::SignExt64; 1148 printer = LLVM::lookupOrCreatePrintI64Fn( 1149 printOp->getParentOfType<ModuleOp>()); 1150 } else { 1151 return failure(); 1152 } 1153 } 1154 } else { 1155 return failure(); 1156 } 1157 1158 // Unroll vector into elementary print calls. 1159 int64_t rank = vectorType ? vectorType.getRank() : 0; 1160 Type type = vectorType ? vectorType : eltType; 1161 emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, 1162 conversion); 1163 emitCall(rewriter, printOp->getLoc(), 1164 LLVM::lookupOrCreatePrintNewlineFn( 1165 printOp->getParentOfType<ModuleOp>())); 1166 rewriter.eraseOp(printOp); 1167 return success(); 1168 } 1169 1170 private: 1171 enum class PrintConversion { 1172 // clang-format off 1173 None, 1174 ZeroExt64, 1175 SignExt64 1176 // clang-format on 1177 }; 1178 1179 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1180 Value value, Type type, Operation *printer, int64_t rank, 1181 PrintConversion conversion) const { 1182 VectorType vectorType = type.dyn_cast<VectorType>(); 1183 Location loc = op->getLoc(); 1184 if (!vectorType) { 1185 assert(rank == 0 && "The scalar case expects rank == 0"); 1186 switch (conversion) { 1187 case PrintConversion::ZeroExt64: 1188 value = rewriter.create<arith::ExtUIOp>( 1189 loc, IntegerType::get(rewriter.getContext(), 64), value); 1190 break; 1191 case PrintConversion::SignExt64: 1192 value = rewriter.create<arith::ExtSIOp>( 1193 loc, IntegerType::get(rewriter.getContext(), 64), value); 1194 break; 1195 case PrintConversion::None: 1196 break; 1197 } 1198 emitCall(rewriter, loc, printer, value); 1199 return; 1200 } 1201 1202 emitCall(rewriter, loc, 1203 LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>())); 1204 Operation *printComma = 1205 LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>()); 1206 1207 if (rank <= 1) { 1208 auto 
reducedType = vectorType.getElementType(); 1209 auto llvmType = typeConverter->convertType(reducedType); 1210 int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0); 1211 for (int64_t d = 0; d < dim; ++d) { 1212 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1213 llvmType, /*rank=*/0, /*pos=*/d); 1214 emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0, 1215 conversion); 1216 if (d != dim - 1) 1217 emitCall(rewriter, loc, printComma); 1218 } 1219 emitCall( 1220 rewriter, loc, 1221 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1222 return; 1223 } 1224 1225 int64_t dim = vectorType.getDimSize(0); 1226 for (int64_t d = 0; d < dim; ++d) { 1227 auto reducedType = reducedVectorTypeFront(vectorType); 1228 auto llvmType = typeConverter->convertType(reducedType); 1229 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1230 llvmType, rank, d); 1231 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1232 conversion); 1233 if (d != dim - 1) 1234 emitCall(rewriter, loc, printComma); 1235 } 1236 emitCall(rewriter, loc, 1237 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1238 } 1239 1240 // Helper to emit a call. 1241 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1242 Operation *ref, ValueRange params = ValueRange()) { 1243 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1244 params); 1245 } 1246 }; 1247 1248 /// The Splat operation is lowered to an insertelement + a shufflevector 1249 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 
1250 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1251 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1252 1253 LogicalResult 1254 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1255 ConversionPatternRewriter &rewriter) const override { 1256 VectorType resultType = splatOp.getType().cast<VectorType>(); 1257 if (resultType.getRank() > 1) 1258 return failure(); 1259 1260 // First insert it into an undef vector so we can shuffle it. 1261 auto vectorType = typeConverter->convertType(splatOp.getType()); 1262 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1263 auto zero = rewriter.create<LLVM::ConstantOp>( 1264 splatOp.getLoc(), 1265 typeConverter->convertType(rewriter.getIntegerType(32)), 1266 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1267 1268 // For 0-d vector, we simply do `insertelement`. 1269 if (resultType.getRank() == 0) { 1270 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1271 splatOp, vectorType, undef, adaptor.getInput(), zero); 1272 return success(); 1273 } 1274 1275 // For 1-d vector, we additionally do a `vectorshuffle`. 1276 auto v = rewriter.create<LLVM::InsertElementOp>( 1277 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1278 1279 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); 1280 SmallVector<int32_t> zeroValues(width, 0); 1281 1282 // Shuffle the value across the desired number of elements. 1283 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1284 zeroValues); 1285 return success(); 1286 } 1287 }; 1288 1289 /// The Splat operation is lowered to an insertelement + a shufflevector 1290 /// operation. Splat to only 2+-d vector result types are lowered by the 1291 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 
1292 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { 1293 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; 1294 1295 LogicalResult 1296 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, 1297 ConversionPatternRewriter &rewriter) const override { 1298 VectorType resultType = splatOp.getType(); 1299 if (resultType.getRank() <= 1) 1300 return failure(); 1301 1302 // First insert it into an undef vector so we can shuffle it. 1303 auto loc = splatOp.getLoc(); 1304 auto vectorTypeInfo = 1305 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); 1306 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; 1307 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; 1308 if (!llvmNDVectorTy || !llvm1DVectorTy) 1309 return failure(); 1310 1311 // Construct returned value. 1312 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); 1313 1314 // Construct a 1-D vector with the splatted value that we insert in all the 1315 // places within the returned descriptor. 1316 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); 1317 auto zero = rewriter.create<LLVM::ConstantOp>( 1318 loc, typeConverter->convertType(rewriter.getIntegerType(32)), 1319 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1320 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, 1321 adaptor.getInput(), zero); 1322 1323 // Shuffle the value across the desired number of elements. 1324 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1325 SmallVector<int32_t> zeroValues(width, 0); 1326 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1327 1328 // Iterate of linear index, convert to coords space and insert splatted 1-D 1329 // vector in each position. 
1330 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1331 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1332 }); 1333 rewriter.replaceOp(splatOp, desc); 1334 return success(); 1335 } 1336 }; 1337 1338 } // namespace 1339 1340 /// Populate the given list with patterns that convert from Vector to LLVM. 1341 void mlir::populateVectorToLLVMConversionPatterns( 1342 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1343 bool reassociateFPReductions, bool force32BitVectorIndices) { 1344 MLIRContext *ctx = converter.getDialect()->getContext(); 1345 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1346 populateVectorInsertExtractStridedSliceTransforms(patterns); 1347 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1348 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1349 patterns 1350 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1351 VectorExtractElementOpConversion, VectorExtractOpConversion, 1352 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1353 VectorInsertOpConversion, VectorPrintOpConversion, 1354 VectorTypeCastOpConversion, VectorScaleOpConversion, 1355 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1356 VectorLoadStoreConversion<vector::MaskedLoadOp, 1357 vector::MaskedLoadOpAdaptor>, 1358 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1359 VectorLoadStoreConversion<vector::MaskedStoreOp, 1360 vector::MaskedStoreOpAdaptor>, 1361 VectorGatherOpConversion, VectorScatterOpConversion, 1362 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1363 VectorSplatOpLowering, VectorSplatNdOpLowering, 1364 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>( 1365 converter); 1366 // Transfer ops with rank > 1 are handled by VectorToSCF. 
1367 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1368 } 1369 1370 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1371 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1372 patterns.add<VectorMatmulOpConversion>(converter); 1373 patterns.add<VectorFlatTransposeOpConversion>(converter); 1374 } 1375