//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  // The scalable-dim count only drops when the front dimension itself is
  // scalable, which (scalable dims being trailing) happens exactly when
  // *all* dimensions are scalable.
  unsigned numScalableDims = tp.getNumScalableDims();
  if (tp.getShape().size() == numScalableDims)
    --numScalableDims;
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
                         numScalableDims);
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  // Keep at most one trailing scalable dimension for the resulting 1-D type.
  unsigned numScalableDims = tp.getNumScalableDims();
  if (numScalableDims > 0)
    --numScalableDims;
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         numScalableDims);
}

// Helper that picks the proper sequence for inserting.
47 static Value insertOne(ConversionPatternRewriter &rewriter, 48 LLVMTypeConverter &typeConverter, Location loc, 49 Value val1, Value val2, Type llvmType, int64_t rank, 50 int64_t pos) { 51 assert(rank > 0 && "0-D vector corner case should have been handled already"); 52 if (rank == 1) { 53 auto idxType = rewriter.getIndexType(); 54 auto constant = rewriter.create<LLVM::ConstantOp>( 55 loc, typeConverter.convertType(idxType), 56 rewriter.getIntegerAttr(idxType, pos)); 57 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 58 constant); 59 } 60 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 61 rewriter.getI64ArrayAttr(pos)); 62 } 63 64 // Helper that picks the proper sequence for extracting. 65 static Value extractOne(ConversionPatternRewriter &rewriter, 66 LLVMTypeConverter &typeConverter, Location loc, 67 Value val, Type llvmType, int64_t rank, int64_t pos) { 68 if (rank <= 1) { 69 auto idxType = rewriter.getIndexType(); 70 auto constant = rewriter.create<LLVM::ConstantOp>( 71 loc, typeConverter.convertType(idxType), 72 rewriter.getIntegerAttr(idxType, pos)); 73 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 74 constant); 75 } 76 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 77 rewriter.getI64ArrayAttr(pos)); 78 } 79 80 // Helper that returns data layout alignment of a memref. 81 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 82 MemRefType memrefType, unsigned &align) { 83 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 84 if (!elementTy) 85 return failure(); 86 87 // TODO: this should use the MLIR data layout when it becomes available and 88 // stop depending on translation. 89 llvm::LLVMContext llvmContext; 90 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 91 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 92 return success(); 93 } 94 95 // Add an index vector component to a base pointer. 
// This almost always succeeds
// unless the last stride is non-unit or the memory space is not zero.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  // Only unit innermost stride in the default (0) memory space is supported.
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  // Build a vector-of-pointers by GEP'ing the base with the index vector.
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
// Overload for vector.load: plain aligned LLVM load.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

// Overload for vector.maskedload: llvm.intr.masked.load with pass-through.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

// Overload for vector.store: plain aligned LLVM store.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

// Overload for vector.maskedstore: llvm.intr.masked.store.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment from the element type's data-layout preference.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: strided element pointer, then cast to vector pointer.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    // Dispatch to the per-op overload above.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: vector-of-pointers built from base + index vector.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address (same vector-of-pointers scheme as gather).
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getOperands()[0];
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == vector::CombiningKind::ADD)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
                                                             llvmType, operand);
      else if (kind == vector::CombiningKind::MUL)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
                                                             llvmType, operand);
      else if (kind == vector::CombiningKind::MINUI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::MINSI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::MAXUI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::MAXSI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::AND)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
                                                             llvmType, operand);
      else if (kind == vector::CombiningKind::OR)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
                                                            llvmType, operand);
      else if (kind == vector::CombiningKind::XOR)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
                                                             llvmType, operand);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == vector::CombiningKind::ADD) {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MUL) {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MINF)
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    else if (kind == vector::CombiningKind::MAXF)
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    else
      return failure();
    return success();
  }

private:
  // Whether llvm.intr.vector.reduce.fadd/fmul may reassociate (faster but
  // not bit-exact for FP).
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.shuffle.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      // Mask entries >= v1Dim select from the second operand.
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Conversion pattern for a vector.extractelement.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // 0-D vectors have no position operand; extract at constant index 0.
    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Conversion pattern for a vector.extract.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      // Peel off all but the last position with a single extractvalue.
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // n-D FMAs are first rank-reduced by VectorFMAOpNDRewritePattern below.
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // 0-D vectors have no position operand; insert at constant index 0.
    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Conversion pattern for a vector.insert.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      // Pull out the innermost 1-D vector that will receive the element.
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.getDest(), inserted,
          nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    // Rank-1 FMAs are handled directly by VectorFMAOp1DConversion.
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    // Seed the result with a zero splat, then fill one leading-dim slice at
    // a time with an (n-1)-D FMA.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  // A non-unit innermost stride is never contiguous.
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    // Contiguity invariant: stride[i] == stride[i+1] * size[i+1].
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

/// Conversion pattern for a vector.type_cast: rebuilds the target memref
/// descriptor from the source descriptor's pointers plus static sizes/strides.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  /// `enableIndexOpt` selects i32 (instead of i64) as the integer type used
  /// for the step-vector/bound comparison below.
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        indexOptimizations(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    // Only 1-D scalable masks are handled here (see class comment).
    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
      return failure();
    IntegerType idxType =
        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    // mask = llvm.intr.experimental.stepvector < splat(bound): lanes whose
    // index is below the bound become true.
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  // When true, use i32 for the comparison; otherwise i64.
  const bool indexOptimizations;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      // Index is printed via the unsigned 64-bit entry point.
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Wider-than-64-bit unsigned integers are unsupported.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Wider-than-64-bit signed integers are unsupported.
          return failure();
        }
      }
    } else {
      // No runtime print support for this element type.
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Extension applied to each scalar before it is handed to the printer.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  /// Recursively unrolls `value` of the given (vector or scalar) `type` into
  /// elementary print calls: scalars are (optionally extended and) printed
  /// directly; vectors print "(", their elements separated by ",", then ")".
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    // Innermost dimension (or 0-D vector): extract scalars one by one.
    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      // A 0-D vector still holds a single printable element.
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    // Outer dimensions: peel the front dimension and recurse one rank lower.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
1101 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1102 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1103 1104 LogicalResult 1105 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1106 ConversionPatternRewriter &rewriter) const override { 1107 VectorType resultType = splatOp.getType().cast<VectorType>(); 1108 if (resultType.getRank() > 1) 1109 return failure(); 1110 1111 // First insert it into an undef vector so we can shuffle it. 1112 auto vectorType = typeConverter->convertType(splatOp.getType()); 1113 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1114 auto zero = rewriter.create<LLVM::ConstantOp>( 1115 splatOp.getLoc(), 1116 typeConverter->convertType(rewriter.getIntegerType(32)), 1117 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1118 1119 // For 0-d vector, we simply do `insertelement`. 1120 if (resultType.getRank() == 0) { 1121 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1122 splatOp, vectorType, undef, adaptor.getInput(), zero); 1123 return success(); 1124 } 1125 1126 // For 1-d vector, we additionally do a `vectorshuffle`. 1127 auto v = rewriter.create<LLVM::InsertElementOp>( 1128 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1129 1130 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); 1131 SmallVector<int32_t, 4> zeroValues(width, 0); 1132 1133 // Shuffle the value across the desired number of elements. 1134 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); 1135 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1136 zeroAttrs); 1137 return success(); 1138 } 1139 }; 1140 1141 /// The Splat operation is lowered to an insertelement + a shufflevector 1142 /// operation. Splat to only 2+-d vector result types are lowered by the 1143 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  /// Lowers a rank >= 2 vector.splat: builds one broadcast 1-D vector
  /// (insertelement + all-zeros shufflevector) and inserts it at every
  /// position of the N-D LLVM aggregate. Rank 0/1 splats are rejected and
  /// handled by VectorSplatOpLowering.
  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    // Bail out if the N-D type could not be decomposed into LLVM types.
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    // The innermost dimension determines the width of the 1-D vector.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);

    // Iterate over the linear index, convert to coords space and insert
    // splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
                                                  position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
/// `reassociateFPReductions` is forwarded to the reduction lowering;
/// `indexOptimizations` selects i32 vs. i64 indices in the create_mask
/// lowering.
void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns,
                                                  bool reassociateFPReductions,
                                                  bool indexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion, VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

/// Populate the matrix-related conversion patterns
/// (vector.matrix_multiply and vector.flat_transpose lowerings).
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}