1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/StandardOps/Ops.h" 14 #include "mlir/Dialect/VectorOps/VectorOps.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/MLIRContext.h" 18 #include "mlir/IR/Module.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/IR/StandardTypes.h" 22 #include "mlir/IR/Types.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "mlir/Transforms/Passes.h" 27 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Module.h" 30 #include "llvm/IR/Type.h" 31 #include "llvm/Support/Allocator.h" 32 #include "llvm/Support/ErrorHandling.h" 33 34 using namespace mlir; 35 using namespace mlir::vector; 36 37 template <typename T> 38 static LLVM::LLVMType getPtrToElementType(T containerType, 39 LLVMTypeConverter &lowering) { 40 return lowering.convertType(containerType.getElementType()) 41 .template cast<LLVM::LLVMType>() 42 .getPointerTo(); 43 } 44 45 // Helper to reduce vector type by one rank at front. 46 static VectorType reducedVectorTypeFront(VectorType tp) { 47 assert((tp.getRank() > 1) && "unlowerable vector type"); 48 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 49 } 50 51 // Helper to reduce vector type by *all* but one rank at back. 52 static VectorType reducedVectorTypeBack(VectorType tp) { 53 assert((tp.getRank() > 1) && "unlowerable vector type"); 54 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 55 } 56 57 // Helper that picks the proper sequence for inserting. 58 static Value insertOne(ConversionPatternRewriter &rewriter, 59 LLVMTypeConverter &lowering, Location loc, Value val1, 60 Value val2, Type llvmType, int64_t rank, int64_t pos) { 61 if (rank == 1) { 62 auto idxType = rewriter.getIndexType(); 63 auto constant = rewriter.create<LLVM::ConstantOp>( 64 loc, lowering.convertType(idxType), 65 rewriter.getIntegerAttr(idxType, pos)); 66 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 67 constant); 68 } 69 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 70 rewriter.getI64ArrayAttr(pos)); 71 } 72 73 // Helper that picks the proper sequence for extracting. 74 static Value extractOne(ConversionPatternRewriter &rewriter, 75 LLVMTypeConverter &lowering, Location loc, Value val, 76 Type llvmType, int64_t rank, int64_t pos) { 77 if (rank == 1) { 78 auto idxType = rewriter.getIndexType(); 79 auto constant = rewriter.create<LLVM::ConstantOp>( 80 loc, lowering.convertType(idxType), 81 rewriter.getIntegerAttr(idxType, pos)); 82 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 83 constant); 84 } 85 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 86 rewriter.getI64ArrayAttr(pos)); 87 } 88 89 class VectorBroadcastOpConversion : public LLVMOpLowering { 90 public: 91 explicit VectorBroadcastOpConversion(MLIRContext *context, 92 LLVMTypeConverter &typeConverter) 93 : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, 94 typeConverter) {} 95 96 PatternMatchResult 97 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 98 ConversionPatternRewriter &rewriter) const override { 99 auto broadcastOp = cast<vector::BroadcastOp>(op); 100 VectorType dstVectorType = broadcastOp.getVectorType(); 101 if (lowering.convertType(dstVectorType) == nullptr) 102 return matchFailure(); 103 // Rewrite when the full vector type can be lowered (which 104 // implies all 'reduced' types can be lowered too). 105 auto adaptor = vector::BroadcastOpOperandAdaptor(operands); 106 VectorType srcVectorType = 107 broadcastOp.getSourceType().dyn_cast<VectorType>(); 108 rewriter.replaceOp( 109 op, expandRanks(adaptor.source(), // source value to be expanded 110 op->getLoc(), // location of original broadcast 111 srcVectorType, dstVectorType, rewriter)); 112 return matchSuccess(); 113 } 114 115 private: 116 // Expands the given source value over all the ranks, as defined 117 // by the source and destination type (a null source type denotes 118 // expansion from a scalar value into a vector). 119 // 120 // TODO(ajcbik): consider replacing this one-pattern lowering 121 // with a two-pattern lowering using other vector 122 // ops once all insert/extract/shuffle operations 123 // are available with lowering implemention. 124 // 125 Value expandRanks(Value value, Location loc, VectorType srcVectorType, 126 VectorType dstVectorType, 127 ConversionPatternRewriter &rewriter) const { 128 assert((dstVectorType != nullptr) && "invalid result type in broadcast"); 129 // Determine rank of source and destination. 130 int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; 131 int64_t dstRank = dstVectorType.getRank(); 132 int64_t curDim = dstVectorType.getDimSize(0); 133 if (srcRank < dstRank) 134 // Duplicate this rank. 135 return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 136 curDim, rewriter); 137 // If all trailing dimensions are the same, the broadcast consists of 138 // simply passing through the source value and we are done. Otherwise, 139 // any non-matching dimension forces a stretch along this rank. 140 assert((srcVectorType != nullptr) && (srcRank > 0) && 141 (srcRank == dstRank) && "invalid rank in broadcast"); 142 for (int64_t r = 0; r < dstRank; r++) { 143 if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { 144 return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 145 curDim, rewriter); 146 } 147 } 148 return value; 149 } 150 151 // Picks the best way to duplicate a single rank. For the 1-D case, a 152 // single insert-elt/shuffle is the most efficient expansion. For higher 153 // dimensions, however, we need dim x insert-values on a new broadcast 154 // with one less leading dimension, which will be lowered "recursively" 155 // to matching LLVM IR. 156 // For example: 157 // v = broadcast s : f32 to vector<4x2xf32> 158 // becomes: 159 // x = broadcast s : f32 to vector<2xf32> 160 // v = [x,x,x,x] 161 // becomes: 162 // x = [s,s] 163 // v = [x,x,x,x] 164 Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, 165 VectorType dstVectorType, int64_t rank, int64_t dim, 166 ConversionPatternRewriter &rewriter) const { 167 Type llvmType = lowering.convertType(dstVectorType); 168 assert((llvmType != nullptr) && "unlowerable vector type"); 169 if (rank == 1) { 170 Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); 171 Value expand = 172 insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); 173 SmallVector<int32_t, 4> zeroValues(dim, 0); 174 return rewriter.create<LLVM::ShuffleVectorOp>( 175 loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); 176 } 177 Value expand = expandRanks(value, loc, srcVectorType, 178 reducedVectorTypeFront(dstVectorType), rewriter); 179 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 180 for (int64_t d = 0; d < dim; ++d) { 181 result = 182 insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); 183 } 184 return result; 185 } 186 187 // Picks the best way to stretch a single rank. For the 1-D case, a 188 // single insert-elt/shuffle is the most efficient expansion when at 189 // a stretch. Otherwise, every dimension needs to be expanded 190 // individually and individually inserted in the resulting vector. 191 // For example: 192 // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> 193 // becomes: 194 // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> 195 // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> 196 // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> 197 // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> 198 // v = [a,b,c,d] 199 // becomes: 200 // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> 201 // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> 202 // a = [x, y] 203 // etc. 204 Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, 205 VectorType dstVectorType, int64_t rank, int64_t dim, 206 ConversionPatternRewriter &rewriter) const { 207 Type llvmType = lowering.convertType(dstVectorType); 208 assert((llvmType != nullptr) && "unlowerable vector type"); 209 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 210 bool atStretch = dim != srcVectorType.getDimSize(0); 211 if (rank == 1) { 212 assert(atStretch); 213 Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); 214 Value one = 215 extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); 216 Value expand = 217 insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); 218 SmallVector<int32_t, 4> zeroValues(dim, 0); 219 return rewriter.create<LLVM::ShuffleVectorOp>( 220 loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); 221 } 222 VectorType redSrcType = reducedVectorTypeFront(srcVectorType); 223 VectorType redDstType = reducedVectorTypeFront(dstVectorType); 224 Type redLlvmType = lowering.convertType(redSrcType); 225 for (int64_t d = 0; d < dim; ++d) { 226 int64_t pos = atStretch ? 0 : d; 227 Value one = 228 extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); 229 Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); 230 result = 231 insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); 232 } 233 return result; 234 } 235 }; 236 237 class VectorShuffleOpConversion : public LLVMOpLowering { 238 public: 239 explicit VectorShuffleOpConversion(MLIRContext *context, 240 LLVMTypeConverter &typeConverter) 241 : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context, 242 typeConverter) {} 243 244 PatternMatchResult 245 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 246 ConversionPatternRewriter &rewriter) const override { 247 auto loc = op->getLoc(); 248 auto adaptor = vector::ShuffleOpOperandAdaptor(operands); 249 auto shuffleOp = cast<vector::ShuffleOp>(op); 250 auto v1Type = shuffleOp.getV1VectorType(); 251 auto v2Type = shuffleOp.getV2VectorType(); 252 auto vectorType = shuffleOp.getVectorType(); 253 Type llvmType = lowering.convertType(vectorType); 254 auto maskArrayAttr = shuffleOp.mask(); 255 256 // Bail if result type cannot be lowered. 257 if (!llvmType) 258 return matchFailure(); 259 260 // Get rank and dimension sizes. 261 int64_t rank = vectorType.getRank(); 262 assert(v1Type.getRank() == rank); 263 assert(v2Type.getRank() == rank); 264 int64_t v1Dim = v1Type.getDimSize(0); 265 266 // For rank 1, where both operands have *exactly* the same vector type, 267 // there is direct shuffle support in LLVM. Use it! 268 if (rank == 1 && v1Type == v2Type) { 269 Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( 270 loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 271 rewriter.replaceOp(op, shuffle); 272 return matchSuccess(); 273 } 274 275 // For all other cases, insert the individual values individually. 276 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 277 int64_t insPos = 0; 278 for (auto en : llvm::enumerate(maskArrayAttr)) { 279 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 280 Value value = adaptor.v1(); 281 if (extPos >= v1Dim) { 282 extPos -= v1Dim; 283 value = adaptor.v2(); 284 } 285 Value extract = 286 extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); 287 insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, 288 rank, insPos++); 289 } 290 rewriter.replaceOp(op, insert); 291 return matchSuccess(); 292 } 293 }; 294 295 class VectorExtractElementOpConversion : public LLVMOpLowering { 296 public: 297 explicit VectorExtractElementOpConversion(MLIRContext *context, 298 LLVMTypeConverter &typeConverter) 299 : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, 300 typeConverter) {} 301 302 PatternMatchResult 303 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 304 ConversionPatternRewriter &rewriter) const override { 305 auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); 306 auto extractEltOp = cast<vector::ExtractElementOp>(op); 307 auto vectorType = extractEltOp.getVectorType(); 308 auto llvmType = lowering.convertType(vectorType.getElementType()); 309 310 // Bail if result type cannot be lowered. 311 if (!llvmType) 312 return matchFailure(); 313 314 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 315 op, llvmType, adaptor.vector(), adaptor.position()); 316 return matchSuccess(); 317 } 318 }; 319 320 class VectorExtractOpConversion : public LLVMOpLowering { 321 public: 322 explicit VectorExtractOpConversion(MLIRContext *context, 323 LLVMTypeConverter &typeConverter) 324 : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, 325 typeConverter) {} 326 327 PatternMatchResult 328 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 329 ConversionPatternRewriter &rewriter) const override { 330 auto loc = op->getLoc(); 331 auto adaptor = vector::ExtractOpOperandAdaptor(operands); 332 auto extractOp = cast<vector::ExtractOp>(op); 333 auto vectorType = extractOp.getVectorType(); 334 auto resultType = extractOp.getResult()->getType(); 335 auto llvmResultType = lowering.convertType(resultType); 336 auto positionArrayAttr = extractOp.position(); 337 338 // Bail if result type cannot be lowered. 339 if (!llvmResultType) 340 return matchFailure(); 341 342 // One-shot extraction of vector from array (only requires extractvalue). 343 if (resultType.isa<VectorType>()) { 344 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 345 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 346 rewriter.replaceOp(op, extracted); 347 return matchSuccess(); 348 } 349 350 // Potential extraction of 1-D vector from array. 351 auto *context = op->getContext(); 352 Value extracted = adaptor.vector(); 353 auto positionAttrs = positionArrayAttr.getValue(); 354 if (positionAttrs.size() > 1) { 355 auto oneDVectorType = reducedVectorTypeBack(vectorType); 356 auto nMinusOnePositionAttrs = 357 ArrayAttr::get(positionAttrs.drop_back(), context); 358 extracted = rewriter.create<LLVM::ExtractValueOp>( 359 loc, lowering.convertType(oneDVectorType), extracted, 360 nMinusOnePositionAttrs); 361 } 362 363 // Remaining extraction of element from 1-D LLVM vector 364 auto position = positionAttrs.back().cast<IntegerAttr>(); 365 auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 366 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 367 extracted = 368 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 369 rewriter.replaceOp(op, extracted); 370 371 return matchSuccess(); 372 } 373 }; 374 375 class VectorInsertElementOpConversion : public LLVMOpLowering { 376 public: 377 explicit VectorInsertElementOpConversion(MLIRContext *context, 378 LLVMTypeConverter &typeConverter) 379 : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, 380 typeConverter) {} 381 382 PatternMatchResult 383 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 384 ConversionPatternRewriter &rewriter) const override { 385 auto adaptor = vector::InsertElementOpOperandAdaptor(operands); 386 auto insertEltOp = cast<vector::InsertElementOp>(op); 387 auto vectorType = insertEltOp.getDestVectorType(); 388 auto llvmType = lowering.convertType(vectorType); 389 390 // Bail if result type cannot be lowered. 391 if (!llvmType) 392 return matchFailure(); 393 394 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 395 op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); 396 return matchSuccess(); 397 } 398 }; 399 400 class VectorInsertOpConversion : public LLVMOpLowering { 401 public: 402 explicit VectorInsertOpConversion(MLIRContext *context, 403 LLVMTypeConverter &typeConverter) 404 : LLVMOpLowering(vector::InsertOp::getOperationName(), context, 405 typeConverter) {} 406 407 PatternMatchResult 408 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 409 ConversionPatternRewriter &rewriter) const override { 410 auto loc = op->getLoc(); 411 auto adaptor = vector::InsertOpOperandAdaptor(operands); 412 auto insertOp = cast<vector::InsertOp>(op); 413 auto sourceType = insertOp.getSourceType(); 414 auto destVectorType = insertOp.getDestVectorType(); 415 auto llvmResultType = lowering.convertType(destVectorType); 416 auto positionArrayAttr = insertOp.position(); 417 418 // Bail if result type cannot be lowered. 419 if (!llvmResultType) 420 return matchFailure(); 421 422 // One-shot insertion of a vector into an array (only requires insertvalue). 423 if (sourceType.isa<VectorType>()) { 424 Value inserted = rewriter.create<LLVM::InsertValueOp>( 425 loc, llvmResultType, adaptor.dest(), adaptor.source(), 426 positionArrayAttr); 427 rewriter.replaceOp(op, inserted); 428 return matchSuccess(); 429 } 430 431 // Potential extraction of 1-D vector from array. 432 auto *context = op->getContext(); 433 Value extracted = adaptor.dest(); 434 auto positionAttrs = positionArrayAttr.getValue(); 435 auto position = positionAttrs.back().cast<IntegerAttr>(); 436 auto oneDVectorType = destVectorType; 437 if (positionAttrs.size() > 1) { 438 oneDVectorType = reducedVectorTypeBack(destVectorType); 439 auto nMinusOnePositionAttrs = 440 ArrayAttr::get(positionAttrs.drop_back(), context); 441 extracted = rewriter.create<LLVM::ExtractValueOp>( 442 loc, lowering.convertType(oneDVectorType), extracted, 443 nMinusOnePositionAttrs); 444 } 445 446 // Insertion of an element into a 1-D LLVM vector. 447 auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 448 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 449 Value inserted = rewriter.create<LLVM::InsertElementOp>( 450 loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), 451 constant); 452 453 // Potential insertion of resulting 1-D vector into array. 454 if (positionAttrs.size() > 1) { 455 auto nMinusOnePositionAttrs = 456 ArrayAttr::get(positionAttrs.drop_back(), context); 457 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 458 adaptor.dest(), inserted, 459 nMinusOnePositionAttrs); 460 } 461 462 rewriter.replaceOp(op, inserted); 463 return matchSuccess(); 464 } 465 }; 466 467 class VectorOuterProductOpConversion : public LLVMOpLowering { 468 public: 469 explicit VectorOuterProductOpConversion(MLIRContext *context, 470 LLVMTypeConverter &typeConverter) 471 : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, 472 typeConverter) {} 473 474 PatternMatchResult 475 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 476 ConversionPatternRewriter &rewriter) const override { 477 auto loc = op->getLoc(); 478 auto adaptor = vector::OuterProductOpOperandAdaptor(operands); 479 auto *ctx = op->getContext(); 480 auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); 481 auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); 482 auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); 483 auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); 484 auto llvmArrayOfVectType = lowering.convertType( 485 cast<vector::OuterProductOp>(op).getResult()->getType()); 486 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); 487 Value a = adaptor.lhs(), b = adaptor.rhs(); 488 Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); 489 SmallVector<Value, 8> lhs, accs; 490 lhs.reserve(rankLHS); 491 accs.reserve(rankLHS); 492 for (unsigned d = 0, e = rankLHS; d < e; ++d) { 493 // shufflevector explicitly requires i32. 494 auto attr = rewriter.getI32IntegerAttr(d); 495 SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); 496 auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); 497 Value aD = nullptr, accD = nullptr; 498 // 1. Broadcast the element a[d] into vector aD. 499 aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); 500 // 2. If acc is present, extract 1-d vector acc[d] into accD. 501 if (acc) 502 accD = rewriter.create<LLVM::ExtractValueOp>( 503 loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); 504 // 3. Compute aD outer b (plus accD, if relevant). 505 Value aOuterbD = 506 accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) 507 .getResult() 508 : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); 509 // 4. Insert as value `d` in the descriptor. 510 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, 511 desc, aOuterbD, 512 rewriter.getI64ArrayAttr(d)); 513 } 514 rewriter.replaceOp(op, desc); 515 return matchSuccess(); 516 } 517 }; 518 519 class VectorTypeCastOpConversion : public LLVMOpLowering { 520 public: 521 explicit VectorTypeCastOpConversion(MLIRContext *context, 522 LLVMTypeConverter &typeConverter) 523 : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, 524 typeConverter) {} 525 526 PatternMatchResult 527 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 528 ConversionPatternRewriter &rewriter) const override { 529 auto loc = op->getLoc(); 530 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 531 MemRefType sourceMemRefType = 532 castOp.getOperand()->getType().cast<MemRefType>(); 533 MemRefType targetMemRefType = 534 castOp.getResult()->getType().cast<MemRefType>(); 535 536 // Only static shape casts supported atm. 537 if (!sourceMemRefType.hasStaticShape() || 538 !targetMemRefType.hasStaticShape()) 539 return matchFailure(); 540 541 auto llvmSourceDescriptorTy = 542 operands[0]->getType().dyn_cast<LLVM::LLVMType>(); 543 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 544 return matchFailure(); 545 MemRefDescriptor sourceMemRef(operands[0]); 546 547 auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) 548 .dyn_cast_or_null<LLVM::LLVMType>(); 549 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 550 return matchFailure(); 551 552 int64_t offset; 553 SmallVector<int64_t, 4> strides; 554 auto successStrides = 555 getStridesAndOffset(sourceMemRefType, strides, offset); 556 bool isContiguous = (strides.back() == 1); 557 if (isContiguous) { 558 auto sizes = sourceMemRefType.getShape(); 559 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 560 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 561 isContiguous = false; 562 break; 563 } 564 } 565 } 566 // Only contiguous source tensors supported atm. 567 if (failed(successStrides) || !isContiguous) 568 return matchFailure(); 569 570 auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 571 572 // Create descriptor. 573 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 574 Type llvmTargetElementTy = desc.getElementType(); 575 // Set allocated ptr. 576 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 577 allocated = 578 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 579 desc.setAllocatedPtr(rewriter, loc, allocated); 580 // Set aligned ptr. 581 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 582 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 583 desc.setAlignedPtr(rewriter, loc, ptr); 584 // Fill offset 0. 585 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 586 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 587 desc.setOffset(rewriter, loc, zero); 588 589 // Fill size and stride descriptors in memref. 590 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 591 int64_t index = indexedSize.index(); 592 auto sizeAttr = 593 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 594 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 595 desc.setSize(rewriter, loc, index, size); 596 auto strideAttr = 597 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 598 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 599 desc.setStride(rewriter, loc, index, stride); 600 } 601 602 rewriter.replaceOp(op, {desc}); 603 return matchSuccess(); 604 } 605 }; 606 607 class VectorPrintOpConversion : public LLVMOpLowering { 608 public: 609 explicit VectorPrintOpConversion(MLIRContext *context, 610 LLVMTypeConverter &typeConverter) 611 : LLVMOpLowering(vector::PrintOp::getOperationName(), context, 612 typeConverter) {} 613 614 // Proof-of-concept lowering implementation that relies on a small 615 // runtime support library, which only needs to provide a few 616 // printing methods (single value for all data types, opening/closing 617 // bracket, comma, newline). The lowering fully unrolls a vector 618 // in terms of these elementary printing operations. The advantage 619 // of this approach is that the library can remain unaware of all 620 // low-level implementation details of vectors while still supporting 621 // output of any shaped and dimensioned vector. Due to full unrolling, 622 // this approach is less suited for very large vectors though. 623 // 624 // TODO(ajcbik): rely solely on libc in future? something else? 625 // 626 PatternMatchResult 627 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 628 ConversionPatternRewriter &rewriter) const override { 629 auto printOp = cast<vector::PrintOp>(op); 630 auto adaptor = vector::PrintOpOperandAdaptor(operands); 631 Type printType = printOp.getPrintType(); 632 633 if (lowering.convertType(printType) == nullptr) 634 return matchFailure(); 635 636 // Make sure element type has runtime support (currently just Float/Double). 637 VectorType vectorType = printType.dyn_cast<VectorType>(); 638 Type eltType = vectorType ? vectorType.getElementType() : printType; 639 int64_t rank = vectorType ? vectorType.getRank() : 0; 640 Operation *printer; 641 if (eltType.isF32()) 642 printer = getPrintFloat(op); 643 else if (eltType.isF64()) 644 printer = getPrintDouble(op); 645 else 646 return matchFailure(); 647 648 // Unroll vector into elementary print calls. 649 emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); 650 emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 651 rewriter.eraseOp(op); 652 return matchSuccess(); 653 } 654 655 private: 656 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 657 Value value, VectorType vectorType, Operation *printer, 658 int64_t rank) const { 659 Location loc = op->getLoc(); 660 if (rank == 0) { 661 emitCall(rewriter, loc, printer, value); 662 return; 663 } 664 665 emitCall(rewriter, loc, getPrintOpen(op)); 666 Operation *printComma = getPrintComma(op); 667 int64_t dim = vectorType.getDimSize(0); 668 for (int64_t d = 0; d < dim; ++d) { 669 auto reducedType = 670 rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 671 auto llvmType = lowering.convertType( 672 rank > 1 ? reducedType : vectorType.getElementType()); 673 Value nestedVal = 674 extractOne(rewriter, lowering, loc, value, llvmType, rank, d); 675 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); 676 if (d != dim - 1) 677 emitCall(rewriter, loc, printComma); 678 } 679 emitCall(rewriter, loc, getPrintClose(op)); 680 } 681 682 // Helper to emit a call. 683 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 684 Operation *ref, ValueRange params = ValueRange()) { 685 rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, 686 rewriter.getSymbolRefAttr(ref), params); 687 } 688 689 // Helper for printer method declaration (first hit) and lookup. 690 static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, 691 StringRef name, ArrayRef<LLVM::LLVMType> params) { 692 auto module = op->getParentOfType<ModuleOp>(); 693 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 694 if (func) 695 return func; 696 OpBuilder moduleBuilder(module.getBodyRegion()); 697 return moduleBuilder.create<LLVM::LLVMFuncOp>( 698 op->getLoc(), name, 699 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), 700 params, /*isVarArg=*/false)); 701 } 702 703 // Helpers for method names. 704 Operation *getPrintFloat(Operation *op) const { 705 LLVM::LLVMDialect *dialect = lowering.getDialect(); 706 return getPrint(op, dialect, "print_f32", 707 LLVM::LLVMType::getFloatTy(dialect)); 708 } 709 Operation *getPrintDouble(Operation *op) const { 710 LLVM::LLVMDialect *dialect = lowering.getDialect(); 711 return getPrint(op, dialect, "print_f64", 712 LLVM::LLVMType::getDoubleTy(dialect)); 713 } 714 Operation *getPrintOpen(Operation *op) const { 715 return getPrint(op, lowering.getDialect(), "print_open", {}); 716 } 717 Operation *getPrintClose(Operation *op) const { 718 return getPrint(op, lowering.getDialect(), "print_close", {}); 719 } 720 Operation *getPrintComma(Operation *op) const { 721 return getPrint(op, lowering.getDialect(), "print_comma", {}); 722 } 723 Operation *getPrintNewline(Operation *op) const { 724 return getPrint(op, lowering.getDialect(), "print_newline", {}); 725 } 726 }; 727 728 // TODO(rriddle): Better support for attribute subtype forwarding + slicing. 729 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 730 unsigned dropFront = 0, 731 unsigned dropBack = 0) { 732 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 733 auto range = arrayAttr.getAsRange<IntegerAttr>(); 734 SmallVector<int64_t, 4> res; 735 res.reserve(arrayAttr.size() - dropFront - dropBack); 736 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 737 it != eit; ++it) 738 res.push_back((*it).getValue().getSExtValue()); 739 return res; 740 } 741 742 /// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank 743 /// of `vector`. 744 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 745 int64_t offset) { 746 auto vectorType = vector.getType().cast<VectorType>(); 747 if (vectorType.getRank() > 1) 748 return rewriter.create<ExtractOp>(loc, vector, offset); 749 return rewriter.create<vector::ExtractElementOp>( 750 loc, vectorType.getElementType(), vector, 751 rewriter.create<ConstantIndexOp>(loc, offset)); 752 } 753 754 /// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank 755 /// of `vector`. 756 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 757 Value into, int64_t offset) { 758 auto vectorType = into.getType().cast<VectorType>(); 759 if (vectorType.getRank() > 1) 760 return rewriter.create<InsertOp>(loc, from, into, offset); 761 return rewriter.create<vector::InsertElementOp>( 762 loc, vectorType, from, into, 763 rewriter.create<ConstantIndexOp>(loc, offset)); 764 } 765 766 /// Progressive lowering of StridedSliceOp to either: 767 /// 1. extractelement + insertelement for the 1-D case 768 /// 2. extract + optional strided_slice + insert for the n-D case. 769 class VectorStridedSliceOpRewritePattern 770 : public OpRewritePattern<StridedSliceOp> { 771 public: 772 using OpRewritePattern<StridedSliceOp>::OpRewritePattern; 773 774 PatternMatchResult matchAndRewrite(StridedSliceOp op, 775 PatternRewriter &rewriter) const override { 776 auto dstType = op.getResult().getType().cast<VectorType>(); 777 778 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 779 780 int64_t offset = 781 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 782 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 783 int64_t stride = 784 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 785 786 auto loc = op.getLoc(); 787 auto elemType = dstType.getElementType(); 788 assert(elemType.isIntOrIndexOrFloat()); 789 Value zero = rewriter.create<ConstantOp>(loc, elemType, 790 rewriter.getZeroAttr(elemType)); 791 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 792 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 793 off += stride, ++idx) { 794 Value extracted = extractOne(rewriter, loc, op.vector(), off); 795 if (op.offsets().getValue().size() > 1) { 796 StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>( 797 loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), 798 getI64SubArray(op.sizes(), /* dropFront=*/1), 799 getI64SubArray(op.strides(), /* dropFront=*/1)); 800 // Call matchAndRewrite recursively from within the pattern. This 801 // circumvents the current limitation that a given pattern cannot 802 // be called multiple times by the PatternRewrite infrastructure (to 803 // avoid infinite recursion, but in this case, infinite recursion 804 // cannot happen because the rank is strictly decreasing). 805 // TODO(rriddle, nicolasvasilache) Implement something like a hook for 806 // a potential function that must decrease and allow the same pattern 807 // multiple times. 808 auto success = matchAndRewrite(stridedSliceOp, rewriter); 809 (void)success; 810 assert(success && "Unexpected failure"); 811 extracted = stridedSliceOp; 812 } 813 res = insertOne(rewriter, loc, extracted, res, idx); 814 } 815 rewriter.replaceOp(op, {res}); 816 return matchSuccess(); 817 } 818 }; 819 820 /// Populate the given list with patterns that convert from Vector to LLVM. 821 void mlir::populateVectorToLLVMConversionPatterns( 822 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 823 MLIRContext *ctx = converter.getDialect()->getContext(); 824 patterns.insert<VectorStridedSliceOpRewritePattern>(ctx); 825 patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, 826 VectorExtractElementOpConversion, VectorExtractOpConversion, 827 VectorInsertElementOpConversion, VectorInsertOpConversion, 828 VectorOuterProductOpConversion, VectorTypeCastOpConversion, 829 VectorPrintOpConversion>(ctx, converter); 830 } 831 832 namespace { 833 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 834 void runOnModule() override; 835 }; 836 } // namespace 837 838 void LowerVectorToLLVMPass::runOnModule() { 839 // Convert to the LLVM IR dialect using the converter defined above. 840 OwningRewritePatternList patterns; 841 LLVMTypeConverter converter(&getContext()); 842 populateVectorToLLVMConversionPatterns(converter, patterns); 843 populateStdToLLVMConversionPatterns(converter, patterns); 844 845 ConversionTarget target(getContext()); 846 target.addLegalDialect<LLVM::LLVMDialect>(); 847 target.addDynamicallyLegalOp<FuncOp>( 848 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 849 if (failed( 850 applyPartialConversion(getModule(), target, patterns, &converter))) { 851 signalPassFailure(); 852 } 853 } 854 855 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { 856 return new LowerVectorToLLVMPass(); 857 } 858 859 static PassRegistration<LowerVectorToLLVMPass> 860 pass("convert-vector-to-llvm", 861 "Lower the operations from the vector dialect into the LLVM dialect"); 862