1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Vector/VectorOps.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/MLIRContext.h" 21 #include "mlir/IR/Module.h" 22 #include "mlir/IR/Operation.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/IR/StandardTypes.h" 25 #include "mlir/IR/Types.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 #include "mlir/Transforms/Passes.h" 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Module.h" 30 #include "llvm/IR/Type.h" 31 #include "llvm/Support/Allocator.h" 32 #include "llvm/Support/ErrorHandling.h" 33 34 using namespace mlir; 35 using namespace mlir::vector; 36 37 template <typename T> 38 static LLVM::LLVMType getPtrToElementType(T containerType, 39 LLVMTypeConverter &typeConverter) { 40 return typeConverter.convertType(containerType.getElementType()) 41 .template cast<LLVM::LLVMType>() 42 .getPointerTo(); 43 } 44 45 // Helper to reduce vector type by one rank at front. 46 static VectorType reducedVectorTypeFront(VectorType tp) { 47 assert((tp.getRank() > 1) && "unlowerable vector type"); 48 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 49 } 50 51 // Helper to reduce vector type by *all* but one rank at back. 
52 static VectorType reducedVectorTypeBack(VectorType tp) { 53 assert((tp.getRank() > 1) && "unlowerable vector type"); 54 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 55 } 56 57 // Helper that picks the proper sequence for inserting. 58 static Value insertOne(ConversionPatternRewriter &rewriter, 59 LLVMTypeConverter &typeConverter, Location loc, 60 Value val1, Value val2, Type llvmType, int64_t rank, 61 int64_t pos) { 62 if (rank == 1) { 63 auto idxType = rewriter.getIndexType(); 64 auto constant = rewriter.create<LLVM::ConstantOp>( 65 loc, typeConverter.convertType(idxType), 66 rewriter.getIntegerAttr(idxType, pos)); 67 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 68 constant); 69 } 70 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 71 rewriter.getI64ArrayAttr(pos)); 72 } 73 74 // Helper that picks the proper sequence for inserting. 75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 76 Value into, int64_t offset) { 77 auto vectorType = into.getType().cast<VectorType>(); 78 if (vectorType.getRank() > 1) 79 return rewriter.create<InsertOp>(loc, from, into, offset); 80 return rewriter.create<vector::InsertElementOp>( 81 loc, vectorType, from, into, 82 rewriter.create<ConstantIndexOp>(loc, offset)); 83 } 84 85 // Helper that picks the proper sequence for extracting. 
86 static Value extractOne(ConversionPatternRewriter &rewriter, 87 LLVMTypeConverter &typeConverter, Location loc, 88 Value val, Type llvmType, int64_t rank, int64_t pos) { 89 if (rank == 1) { 90 auto idxType = rewriter.getIndexType(); 91 auto constant = rewriter.create<LLVM::ConstantOp>( 92 loc, typeConverter.convertType(idxType), 93 rewriter.getIntegerAttr(idxType, pos)); 94 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 95 constant); 96 } 97 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 98 rewriter.getI64ArrayAttr(pos)); 99 } 100 101 // Helper that picks the proper sequence for extracting. 102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 103 int64_t offset) { 104 auto vectorType = vector.getType().cast<VectorType>(); 105 if (vectorType.getRank() > 1) 106 return rewriter.create<ExtractOp>(loc, vector, offset); 107 return rewriter.create<vector::ExtractElementOp>( 108 loc, vectorType.getElementType(), vector, 109 rewriter.create<ConstantIndexOp>(loc, offset)); 110 } 111 112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 113 // TODO: Better support for attribute subtype forwarding + slicing. 
114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 115 unsigned dropFront = 0, 116 unsigned dropBack = 0) { 117 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 118 auto range = arrayAttr.getAsRange<IntegerAttr>(); 119 SmallVector<int64_t, 4> res; 120 res.reserve(arrayAttr.size() - dropFront - dropBack); 121 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 122 it != eit; ++it) 123 res.push_back((*it).getValue().getSExtValue()); 124 return res; 125 } 126 127 template <typename TransferOp> 128 LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter, 129 TransferOp xferOp, unsigned &align) { 130 Type elementTy = 131 typeConverter.convertType(xferOp.getMemRefType().getElementType()); 132 if (!elementTy) 133 return failure(); 134 135 auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout(); 136 align = dataLayout.getPrefTypeAlignment( 137 elementTy.cast<LLVM::LLVMType>().getUnderlyingType()); 138 return success(); 139 } 140 141 static LogicalResult 142 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 143 LLVMTypeConverter &typeConverter, Location loc, 144 TransferReadOp xferOp, 145 ArrayRef<Value> operands, Value dataPtr) { 146 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr); 147 return success(); 148 } 149 150 static LogicalResult 151 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 152 LLVMTypeConverter &typeConverter, Location loc, 153 TransferReadOp xferOp, ArrayRef<Value> operands, 154 Value dataPtr, Value mask) { 155 auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 156 VectorType fillType = xferOp.getVectorType(); 157 Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 158 fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 159 160 Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 161 if (!vecTy) 162 return failure(); 163 164 unsigned 
align; 165 if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) 166 return failure(); 167 168 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 169 xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 170 rewriter.getI32IntegerAttr(align)); 171 return success(); 172 } 173 174 static LogicalResult 175 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 176 LLVMTypeConverter &typeConverter, Location loc, 177 TransferWriteOp xferOp, 178 ArrayRef<Value> operands, Value dataPtr) { 179 auto adaptor = TransferWriteOpAdaptor(operands); 180 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr); 181 return success(); 182 } 183 184 static LogicalResult 185 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 186 LLVMTypeConverter &typeConverter, Location loc, 187 TransferWriteOp xferOp, ArrayRef<Value> operands, 188 Value dataPtr, Value mask) { 189 unsigned align; 190 if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) 191 return failure(); 192 193 auto adaptor = TransferWriteOpAdaptor(operands); 194 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 195 xferOp, adaptor.vector(), dataPtr, mask, 196 rewriter.getI32IntegerAttr(align)); 197 return success(); 198 } 199 200 static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 201 ArrayRef<Value> operands) { 202 return TransferReadOpAdaptor(operands); 203 } 204 205 static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 206 ArrayRef<Value> operands) { 207 return TransferWriteOpAdaptor(operands); 208 } 209 210 namespace { 211 212 /// Conversion pattern for a vector.matrix_multiply. 213 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    // 1-1 mapping: forward the flattened operands and the row/column shape
    // attributes directly to the matrix-multiply intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    // 1-1 mapping: forward the flattened matrix and its row/column counts.
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for vector.reduction, lowered to the matching
/// llvm.intr.experimental.vector.reduce.* intrinsic. Only i32/i64 and
/// f32/f64 element types are handled; any other element type, or an
/// unrecognized reduction kind, makes the pattern fail (no-op).
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      // Note: min/max lower to the *signed* variants (smin/smax).
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        // The reassociation flag allows LLVM to treat the fadd reduction as
        // reassociable (faster tree reduction, relaxed FP semantics).
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  // Whether FP add/mul reductions may be reassociated.
  const bool reassociateFPReductions;
};

/// Conversion pattern for vector.shuffle. Rank-1 shuffles with identical
/// operand types map to a single llvm.shufflevector; everything else is
/// expanded into per-element (or per-subvector) extract/insert pairs.
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the individual values individually.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      // Mask indices >= v1Dim select from the second operand (rebased).
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

/// Conversion pattern for vector.extractelement: 1-1 mapping onto
/// llvm.extractelement.
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for vector.extract. An n-D vector lowers to an LLVM
/// array-of-vectors, so a static position is handled as:
///   - extraction of a subvector: a single llvm.extractvalue;
///   - extraction of a scalar: llvm.extractvalue down to the innermost 1-D
///     vector, then llvm.extractelement with the last position index.
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      // All but the last position index address the surrounding array.
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fma %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    // Only 1-D vectors map directly onto the LLVM intrinsic; n-D FMAs are
    // first rank-reduced by VectorFMAOpNDRewritePattern.
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
                                             adaptor.acc());
    return success();
  }
};

/// Conversion pattern for vector.insertelement: 1-1 mapping onto
/// llvm.insertelement.
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for vector.insert. Mirrors VectorExtractOpConversion:
///   - inserting a subvector: a single llvm.insertvalue;
///   - inserting a scalar: extract the innermost 1-D vector, insertelement
///     into it, then insert the updated 1-D vector back into the array.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      // All but the last position index address the surrounding array.
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      // Re-insert the updated innermost vector into the original destination
      // aggregate at the same (rank - 1) position.
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
602 /// 603 /// Example: 604 /// ``` 605 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 606 /// ``` 607 /// is rewritten into: 608 /// ``` 609 /// %r = splat %f0: vector<2x4xf32> 610 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 611 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 612 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 613 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 614 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 615 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 616 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 617 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 618 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 619 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 620 /// // %r3 holds the final value. 621 /// ``` 622 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 623 public: 624 using OpRewritePattern<FMAOp>::OpRewritePattern; 625 626 LogicalResult matchAndRewrite(FMAOp op, 627 PatternRewriter &rewriter) const override { 628 auto vType = op.getVectorType(); 629 if (vType.getRank() < 2) 630 return failure(); 631 632 auto loc = op.getLoc(); 633 auto elemType = vType.getElementType(); 634 Value zero = rewriter.create<ConstantOp>(loc, elemType, 635 rewriter.getZeroAttr(elemType)); 636 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 637 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 638 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 639 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 640 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 641 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 642 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 643 } 644 rewriter.replaceOp(op, desc); 645 return success(); 646 } 647 }; 648 649 // When ranks are different, InsertStridedSlice needs to extract a properly 650 // ranked vector from 
the destination vector into which to insert. This pattern 651 // only takes care of this part and forwards the rest of the conversion to 652 // another pattern that converts InsertStridedSlice for operands of the same 653 // rank. 654 // 655 // RewritePattern for InsertStridedSliceOp where source and destination vectors 656 // have different ranks. In this case: 657 // 1. the proper subvector is extracted from the destination vector 658 // 2. a new InsertStridedSlice op is created to insert the source in the 659 // destination subvector 660 // 3. the destination subvector is inserted back in the proper place 661 // 4. the op is replaced by the result of step 3. 662 // The new InsertStridedSlice from step 2. will be picked up by a 663 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 664 class VectorInsertStridedSliceOpDifferentRankRewritePattern 665 : public OpRewritePattern<InsertStridedSliceOp> { 666 public: 667 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 668 669 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 670 PatternRewriter &rewriter) const override { 671 auto srcType = op.getSourceVectorType(); 672 auto dstType = op.getDestVectorType(); 673 674 if (op.offsets().getValue().empty()) 675 return failure(); 676 677 auto loc = op.getLoc(); 678 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 679 assert(rankDiff >= 0); 680 if (rankDiff == 0) 681 return failure(); 682 683 int64_t rankRest = dstType.getRank() - rankDiff; 684 // Extract / insert the subvector of matching rank and InsertStridedSlice 685 // on it. 686 Value extracted = 687 rewriter.create<ExtractOp>(loc, op.dest(), 688 getI64SubArray(op.offsets(), /*dropFront=*/0, 689 /*dropFront=*/rankRest)); 690 // A different pattern will kick in for InsertStridedSlice with matching 691 // ranks. 
692 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 693 loc, op.source(), extracted, 694 getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 695 getI64SubArray(op.strides(), /*dropFront=*/0)); 696 rewriter.replaceOpWithNewOp<InsertOp>( 697 op, stridedSliceInnerOp.getResult(), op.dest(), 698 getI64SubArray(op.offsets(), /*dropFront=*/0, 699 /*dropFront=*/rankRest)); 700 return success(); 701 } 702 }; 703 704 // RewritePattern for InsertStridedSliceOp where source and destination vectors 705 // have the same rank. In this case, we reduce 706 // 1. the proper subvector is extracted from the destination vector 707 // 2. a new InsertStridedSlice op is created to insert the source in the 708 // destination subvector 709 // 3. the destination subvector is inserted back in the proper place 710 // 4. the op is replaced by the result of step 3. 711 // The new InsertStridedSlice from step 2. will be picked up by a 712 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 713 class VectorInsertStridedSliceOpSameRankRewritePattern 714 : public OpRewritePattern<InsertStridedSliceOp> { 715 public: 716 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 717 718 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 719 PatternRewriter &rewriter) const override { 720 auto srcType = op.getSourceVectorType(); 721 auto dstType = op.getDestVectorType(); 722 723 if (op.offsets().getValue().empty()) 724 return failure(); 725 726 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 727 assert(rankDiff >= 0); 728 if (rankDiff != 0) 729 return failure(); 730 731 if (srcType == dstType) { 732 rewriter.replaceOp(op, op.source()); 733 return success(); 734 } 735 736 int64_t offset = 737 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 738 int64_t size = srcType.getShape().front(); 739 int64_t stride = 740 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 741 742 auto loc = op.getLoc(); 743 Value res = op.dest(); 744 // For each 
slice of the source vector along the most major dimension. 745 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 746 off += stride, ++idx) { 747 // 1. extract the proper subvector (or element) from source 748 Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 749 if (extractedSource.getType().isa<VectorType>()) { 750 // 2. If we have a vector, extract the proper subvector from destination 751 // Otherwise we are at the element level and no need to recurse. 752 Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 753 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 754 // smaller rank. 755 extractedSource = rewriter.create<InsertStridedSliceOp>( 756 loc, extractedSource, extractedDest, 757 getI64SubArray(op.offsets(), /* dropFront=*/1), 758 getI64SubArray(op.strides(), /* dropFront=*/1)); 759 } 760 // 4. Insert the extractedSource into the res vector. 761 res = insertOne(rewriter, loc, extractedSource, res, off); 762 } 763 764 rewriter.replaceOp(op, res); 765 return success(); 766 } 767 /// This pattern creates recursive InsertStridedSliceOp, but the recursion is 768 /// bounded as the rank is strictly decreasing. 
769 bool hasBoundedRewriteRecursion() const final { return true; } 770 }; 771 772 class VectorTypeCastOpConversion : public ConvertToLLVMPattern { 773 public: 774 explicit VectorTypeCastOpConversion(MLIRContext *context, 775 LLVMTypeConverter &typeConverter) 776 : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, 777 typeConverter) {} 778 779 LogicalResult 780 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 781 ConversionPatternRewriter &rewriter) const override { 782 auto loc = op->getLoc(); 783 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 784 MemRefType sourceMemRefType = 785 castOp.getOperand().getType().cast<MemRefType>(); 786 MemRefType targetMemRefType = 787 castOp.getResult().getType().cast<MemRefType>(); 788 789 // Only static shape casts supported atm. 790 if (!sourceMemRefType.hasStaticShape() || 791 !targetMemRefType.hasStaticShape()) 792 return failure(); 793 794 auto llvmSourceDescriptorTy = 795 operands[0].getType().dyn_cast<LLVM::LLVMType>(); 796 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 797 return failure(); 798 MemRefDescriptor sourceMemRef(operands[0]); 799 800 auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) 801 .dyn_cast_or_null<LLVM::LLVMType>(); 802 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 803 return failure(); 804 805 int64_t offset; 806 SmallVector<int64_t, 4> strides; 807 auto successStrides = 808 getStridesAndOffset(sourceMemRefType, strides, offset); 809 bool isContiguous = (strides.back() == 1); 810 if (isContiguous) { 811 auto sizes = sourceMemRefType.getShape(); 812 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 813 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 814 isContiguous = false; 815 break; 816 } 817 } 818 } 819 // Only contiguous source tensors supported atm. 
820 if (failed(successStrides) || !isContiguous) 821 return failure(); 822 823 auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 824 825 // Create descriptor. 826 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 827 Type llvmTargetElementTy = desc.getElementType(); 828 // Set allocated ptr. 829 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 830 allocated = 831 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 832 desc.setAllocatedPtr(rewriter, loc, allocated); 833 // Set aligned ptr. 834 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 835 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 836 desc.setAlignedPtr(rewriter, loc, ptr); 837 // Fill offset 0. 838 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 839 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 840 desc.setOffset(rewriter, loc, zero); 841 842 // Fill size and stride descriptors in memref. 843 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 844 int64_t index = indexedSize.index(); 845 auto sizeAttr = 846 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 847 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 848 desc.setSize(rewriter, loc, index, size); 849 auto strideAttr = 850 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 851 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 852 desc.setStride(rewriter, loc, index, stride); 853 } 854 855 rewriter.replaceOp(op, {desc}); 856 return success(); 857 } 858 }; 859 860 /// Conversion pattern that converts a 1-D vector transfer read/write op in a 861 /// sequence of: 862 /// 1. Bitcast or addrspacecast to vector form. 863 /// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 864 /// 3. Create a mask where offsetVector is compared against memref upper bound. 865 /// 4. 
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    // Only 1-D transfers with at least one index are handled here.
    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    // The permutation map must be the minor identity (i.e. the transfer reads
    // or writes along the innermost memref dimension).
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    // Shorthand for converting a builtin type to its LLVM dialect form.
    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    // If the innermost dimension is not masked, this lowers to a plain
    // (unmasked) load or store.
    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO: when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim the memref dimension, compute the vector comparison mask:
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};

/// Lowers vector.print into calls to a small runtime support library.
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    // Give up when the printed type cannot be converted to LLVM.
    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support (i1/i32/i64/f32/f64).
    // Dispatch on the (scalar or vector element) type to pick the runtime
    // print helper; unsupported element types fail the pattern.
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op); // i1 is widened to i32 in emitRanks.
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Recursively emits print calls for `value`:
  //  - rank 0: a single scalar print (i1 widened to i32 first);
  //  - rank > 0: "open", the subvalues separated by commas, then "close".
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() ==
          LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
        auto trueVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        auto falseVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      // Peel off the leading dimension; at rank 1 the extracted value is a
      // scalar and reducedType is intentionally null.
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call (void return) to the given runtime function.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  // Declares `void name(params...)` at module scope on first use and
  // returns the (possibly pre-existing) LLVMFuncOp for the symbol.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
/// 1. extractelement + insertelement for the 1-D case
/// 2. extract + optional strided_slice + insert for the n-D case.
1103 class VectorStridedSliceOpConversion 1104 : public OpRewritePattern<ExtractStridedSliceOp> { 1105 public: 1106 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 1107 1108 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 1109 PatternRewriter &rewriter) const override { 1110 auto dstType = op.getResult().getType().cast<VectorType>(); 1111 1112 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 1113 1114 int64_t offset = 1115 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 1116 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 1117 int64_t stride = 1118 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 1119 1120 auto loc = op.getLoc(); 1121 auto elemType = dstType.getElementType(); 1122 assert(elemType.isSignlessIntOrIndexOrFloat()); 1123 Value zero = rewriter.create<ConstantOp>(loc, elemType, 1124 rewriter.getZeroAttr(elemType)); 1125 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 1126 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 1127 off += stride, ++idx) { 1128 Value extracted = extractOne(rewriter, loc, op.vector(), off); 1129 if (op.offsets().getValue().size() > 1) { 1130 extracted = rewriter.create<ExtractStridedSliceOp>( 1131 loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), 1132 getI64SubArray(op.sizes(), /* dropFront=*/1), 1133 getI64SubArray(op.strides(), /* dropFront=*/1)); 1134 } 1135 res = insertOne(rewriter, loc, extracted, res, idx); 1136 } 1137 rewriter.replaceOp(op, {res}); 1138 return success(); 1139 } 1140 /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is 1141 /// bounded as the rank is strictly decreasing. 1142 bool hasBoundedRewriteRecursion() const final { return true; } 1143 }; 1144 1145 } // namespace 1146 1147 /// Populate the given list with patterns that convert from Vector to LLVM. 
1148 void mlir::populateVectorToLLVMConversionPatterns( 1149 LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1150 bool reassociateFPReductions) { 1151 MLIRContext *ctx = converter.getDialect()->getContext(); 1152 // clang-format off 1153 patterns.insert<VectorFMAOpNDRewritePattern, 1154 VectorInsertStridedSliceOpDifferentRankRewritePattern, 1155 VectorInsertStridedSliceOpSameRankRewritePattern, 1156 VectorStridedSliceOpConversion>(ctx); 1157 patterns.insert<VectorReductionOpConversion>( 1158 ctx, converter, reassociateFPReductions); 1159 patterns 1160 .insert<VectorShuffleOpConversion, 1161 VectorExtractElementOpConversion, 1162 VectorExtractOpConversion, 1163 VectorFMAOp1DConversion, 1164 VectorInsertElementOpConversion, 1165 VectorInsertOpConversion, 1166 VectorPrintOpConversion, 1167 VectorTransferConversion<TransferReadOp>, 1168 VectorTransferConversion<TransferWriteOp>, 1169 VectorTypeCastOpConversion>(ctx, converter); 1170 // clang-format on 1171 } 1172 1173 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1174 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1175 MLIRContext *ctx = converter.getDialect()->getContext(); 1176 patterns.insert<VectorMatmulOpConversion>(ctx, converter); 1177 patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter); 1178 } 1179 1180 namespace { 1181 struct LowerVectorToLLVMPass 1182 : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> { 1183 void runOnOperation() override; 1184 }; 1185 } // namespace 1186 1187 void LowerVectorToLLVMPass::runOnOperation() { 1188 // Perform progressive lowering of operations on slices and 1189 // all contraction operations. Also applies folding and DCE. 
1190 { 1191 OwningRewritePatternList patterns; 1192 populateVectorToVectorCanonicalizationPatterns(patterns, &getContext()); 1193 populateVectorSlicesLoweringPatterns(patterns, &getContext()); 1194 populateVectorContractLoweringPatterns(patterns, &getContext()); 1195 applyPatternsAndFoldGreedily(getOperation(), patterns); 1196 } 1197 1198 // Convert to the LLVM IR dialect. 1199 LLVMTypeConverter converter(&getContext()); 1200 OwningRewritePatternList patterns; 1201 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1202 populateVectorToLLVMConversionPatterns(converter, patterns, 1203 reassociateFPReductions); 1204 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1205 populateStdToLLVMConversionPatterns(converter, patterns); 1206 1207 LLVMConversionTarget target(getContext()); 1208 if (failed(applyPartialConversion(getOperation(), target, patterns))) { 1209 signalPassFailure(); 1210 } 1211 } 1212 1213 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() { 1214 return std::make_unique<LowerVectorToLLVMPass>(); 1215 } 1216