1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 20 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/VectorOps/VectorOps.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/MLIRContext.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 #include "mlir/IR/Types.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Pass/PassManager.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "mlir/Transforms/Passes.h" 35 36 #include "llvm/IR/DerivedTypes.h" 37 #include "llvm/IR/Module.h" 38 #include "llvm/IR/Type.h" 39 #include "llvm/Support/Allocator.h" 40 #include "llvm/Support/ErrorHandling.h" 41 42 using namespace mlir; 43 44 template <typename T> 45 static LLVM::LLVMType getPtrToElementType(T containerType, 46 LLVMTypeConverter &lowering) { 47 return lowering.convertType(containerType.getElementType()) 48 .template cast<LLVM::LLVMType>() 49 .getPointerTo(); 50 } 51 52 // Helper to reduce vector type by one rank at front. 53 static VectorType reducedVectorTypeFront(VectorType tp) { 54 assert((tp.getRank() > 1) && "unlowerable vector type"); 55 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 56 } 57 58 // Helper to reduce vector type by *all* but one rank at back. 59 static VectorType reducedVectorTypeBack(VectorType tp) { 60 assert((tp.getRank() > 1) && "unlowerable vector type"); 61 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 62 } 63 64 // Helper that picks the proper sequence for inserting. 65 static Value *insertOne(ConversionPatternRewriter &rewriter, 66 LLVMTypeConverter &lowering, Location loc, Value *val1, 67 Value *val2, Type llvmType, int64_t rank, int64_t pos) { 68 if (rank == 1) { 69 auto idxType = rewriter.getIndexType(); 70 auto constant = rewriter.create<LLVM::ConstantOp>( 71 loc, lowering.convertType(idxType), 72 rewriter.getIntegerAttr(idxType, pos)); 73 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 74 constant); 75 } 76 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 77 rewriter.getI64ArrayAttr(pos)); 78 } 79 80 // Helper that picks the proper sequence for extracting. 81 static Value *extractOne(ConversionPatternRewriter &rewriter, 82 LLVMTypeConverter &lowering, Location loc, Value *val, 83 Type llvmType, int64_t rank, int64_t pos) { 84 if (rank == 1) { 85 auto idxType = rewriter.getIndexType(); 86 auto constant = rewriter.create<LLVM::ConstantOp>( 87 loc, lowering.convertType(idxType), 88 rewriter.getIntegerAttr(idxType, pos)); 89 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 90 constant); 91 } 92 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 93 rewriter.getI64ArrayAttr(pos)); 94 } 95 96 class VectorBroadcastOpConversion : public LLVMOpLowering { 97 public: 98 explicit VectorBroadcastOpConversion(MLIRContext *context, 99 LLVMTypeConverter &typeConverter) 100 : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, 101 typeConverter) {} 102 103 PatternMatchResult 104 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 105 ConversionPatternRewriter &rewriter) const override { 106 auto broadcastOp = cast<vector::BroadcastOp>(op); 107 VectorType dstVectorType = broadcastOp.getVectorType(); 108 if (lowering.convertType(dstVectorType) == nullptr) 109 return matchFailure(); 110 // Rewrite when the full vector type can be lowered (which 111 // implies all 'reduced' types can be lowered too). 112 auto adaptor = vector::BroadcastOpOperandAdaptor(operands); 113 VectorType srcVectorType = 114 broadcastOp.getSourceType().dyn_cast<VectorType>(); 115 rewriter.replaceOp( 116 op, expandRanks(adaptor.source(), // source value to be expanded 117 op->getLoc(), // location of original broadcast 118 srcVectorType, dstVectorType, rewriter)); 119 return matchSuccess(); 120 } 121 122 private: 123 // Expands the given source value over all the ranks, as defined 124 // by the source and destination type (a null source type denotes 125 // expansion from a scalar value into a vector). 126 // 127 // TODO(ajcbik): consider replacing this one-pattern lowering 128 // with a two-pattern lowering using other vector 129 // ops once all insert/extract/shuffle operations 130 // are available with lowering implemention. 131 // 132 Value *expandRanks(Value *value, Location loc, VectorType srcVectorType, 133 VectorType dstVectorType, 134 ConversionPatternRewriter &rewriter) const { 135 assert((dstVectorType != nullptr) && "invalid result type in broadcast"); 136 // Determine rank of source and destination. 137 int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; 138 int64_t dstRank = dstVectorType.getRank(); 139 int64_t curDim = dstVectorType.getDimSize(0); 140 if (srcRank < dstRank) 141 // Duplicate this rank. 142 return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 143 curDim, rewriter); 144 // If all trailing dimensions are the same, the broadcast consists of 145 // simply passing through the source value and we are done. Otherwise, 146 // any non-matching dimension forces a stretch along this rank. 147 assert((srcVectorType != nullptr) && (srcRank > 0) && 148 (srcRank == dstRank) && "invalid rank in broadcast"); 149 for (int64_t r = 0; r < dstRank; r++) { 150 if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { 151 return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 152 curDim, rewriter); 153 } 154 } 155 return value; 156 } 157 158 // Picks the best way to duplicate a single rank. For the 1-D case, a 159 // single insert-elt/shuffle is the most efficient expansion. For higher 160 // dimensions, however, we need dim x insert-values on a new broadcast 161 // with one less leading dimension, which will be lowered "recursively" 162 // to matching LLVM IR. 163 // For example: 164 // v = broadcast s : f32 to vector<4x2xf32> 165 // becomes: 166 // x = broadcast s : f32 to vector<2xf32> 167 // v = [x,x,x,x] 168 // becomes: 169 // x = [s,s] 170 // v = [x,x,x,x] 171 Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType, 172 VectorType dstVectorType, int64_t rank, int64_t dim, 173 ConversionPatternRewriter &rewriter) const { 174 Type llvmType = lowering.convertType(dstVectorType); 175 assert((llvmType != nullptr) && "unlowerable vector type"); 176 if (rank == 1) { 177 Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); 178 Value *expand = 179 insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); 180 SmallVector<int32_t, 4> zeroValues(dim, 0); 181 return rewriter.create<LLVM::ShuffleVectorOp>( 182 loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); 183 } 184 Value *expand = 185 expandRanks(value, loc, srcVectorType, 186 reducedVectorTypeFront(dstVectorType), rewriter); 187 Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 188 for (int64_t d = 0; d < dim; ++d) { 189 result = 190 insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); 191 } 192 return result; 193 } 194 195 // Picks the best way to stretch a single rank. For the 1-D case, a 196 // single insert-elt/shuffle is the most efficient expansion when at 197 // a stretch. Otherwise, every dimension needs to be expanded 198 // individually and individually inserted in the resulting vector. 199 // For example: 200 // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> 201 // becomes: 202 // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> 203 // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> 204 // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> 205 // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> 206 // v = [a,b,c,d] 207 // becomes: 208 // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> 209 // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> 210 // a = [x, y] 211 // etc. 212 Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType, 213 VectorType dstVectorType, int64_t rank, int64_t dim, 214 ConversionPatternRewriter &rewriter) const { 215 Type llvmType = lowering.convertType(dstVectorType); 216 assert((llvmType != nullptr) && "unlowerable vector type"); 217 Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 218 bool atStretch = dim != srcVectorType.getDimSize(0); 219 if (rank == 1) { 220 assert(atStretch); 221 Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); 222 Value *one = 223 extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); 224 Value *expand = 225 insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); 226 SmallVector<int32_t, 4> zeroValues(dim, 0); 227 return rewriter.create<LLVM::ShuffleVectorOp>( 228 loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); 229 } 230 VectorType redSrcType = reducedVectorTypeFront(srcVectorType); 231 VectorType redDstType = reducedVectorTypeFront(dstVectorType); 232 Type redLlvmType = lowering.convertType(redSrcType); 233 for (int64_t d = 0; d < dim; ++d) { 234 int64_t pos = atStretch ? 0 : d; 235 Value *one = 236 extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); 237 Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); 238 result = 239 insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); 240 } 241 return result; 242 } 243 }; 244 245 class VectorShuffleOpConversion : public LLVMOpLowering { 246 public: 247 explicit VectorShuffleOpConversion(MLIRContext *context, 248 LLVMTypeConverter &typeConverter) 249 : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context, 250 typeConverter) {} 251 252 PatternMatchResult 253 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 254 ConversionPatternRewriter &rewriter) const override { 255 auto loc = op->getLoc(); 256 auto adaptor = vector::ShuffleOpOperandAdaptor(operands); 257 auto shuffleOp = cast<vector::ShuffleOp>(op); 258 auto v1Type = shuffleOp.getV1VectorType(); 259 auto v2Type = shuffleOp.getV2VectorType(); 260 auto vectorType = shuffleOp.getVectorType(); 261 Type llvmType = lowering.convertType(vectorType); 262 auto maskArrayAttr = shuffleOp.mask(); 263 264 // Bail if result type cannot be lowered. 265 if (!llvmType) 266 return matchFailure(); 267 268 // Get rank and dimension sizes. 269 int64_t rank = vectorType.getRank(); 270 assert(v1Type.getRank() == rank); 271 assert(v2Type.getRank() == rank); 272 int64_t v1Dim = v1Type.getDimSize(0); 273 274 // For rank 1, where both operands have *exactly* the same vector type, 275 // there is direct shuffle support in LLVM. Use it! 276 if (rank == 1 && v1Type == v2Type) { 277 Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>( 278 loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 279 rewriter.replaceOp(op, shuffle); 280 return matchSuccess(); 281 } 282 283 // For all other cases, insert the individual values individually. 284 Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 285 int64_t insPos = 0; 286 for (auto en : llvm::enumerate(maskArrayAttr)) { 287 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 288 Value *value = adaptor.v1(); 289 if (extPos >= v1Dim) { 290 extPos -= v1Dim; 291 value = adaptor.v2(); 292 } 293 Value *extract = 294 extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); 295 insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, 296 rank, insPos++); 297 } 298 rewriter.replaceOp(op, insert); 299 return matchSuccess(); 300 } 301 }; 302 303 class VectorExtractElementOpConversion : public LLVMOpLowering { 304 public: 305 explicit VectorExtractElementOpConversion(MLIRContext *context, 306 LLVMTypeConverter &typeConverter) 307 : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, 308 typeConverter) {} 309 310 PatternMatchResult 311 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 312 ConversionPatternRewriter &rewriter) const override { 313 auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); 314 auto extractEltOp = cast<vector::ExtractElementOp>(op); 315 auto vectorType = extractEltOp.getVectorType(); 316 auto llvmType = lowering.convertType(vectorType.getElementType()); 317 318 // Bail if result type cannot be lowered. 319 if (!llvmType) 320 return matchFailure(); 321 322 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 323 op, llvmType, adaptor.vector(), adaptor.position()); 324 return matchSuccess(); 325 } 326 }; 327 328 class VectorExtractOpConversion : public LLVMOpLowering { 329 public: 330 explicit VectorExtractOpConversion(MLIRContext *context, 331 LLVMTypeConverter &typeConverter) 332 : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, 333 typeConverter) {} 334 335 PatternMatchResult 336 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 337 ConversionPatternRewriter &rewriter) const override { 338 auto loc = op->getLoc(); 339 auto adaptor = vector::ExtractOpOperandAdaptor(operands); 340 auto extractOp = cast<vector::ExtractOp>(op); 341 auto vectorType = extractOp.getVectorType(); 342 auto resultType = extractOp.getResult()->getType(); 343 auto llvmResultType = lowering.convertType(resultType); 344 auto positionArrayAttr = extractOp.position(); 345 346 // Bail if result type cannot be lowered. 347 if (!llvmResultType) 348 return matchFailure(); 349 350 // One-shot extraction of vector from array (only requires extractvalue). 351 if (resultType.isa<VectorType>()) { 352 Value *extracted = rewriter.create<LLVM::ExtractValueOp>( 353 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 354 rewriter.replaceOp(op, extracted); 355 return matchSuccess(); 356 } 357 358 // Potential extraction of 1-D vector from array. 359 auto *context = op->getContext(); 360 Value *extracted = adaptor.vector(); 361 auto positionAttrs = positionArrayAttr.getValue(); 362 if (positionAttrs.size() > 1) { 363 auto oneDVectorType = reducedVectorTypeBack(vectorType); 364 auto nMinusOnePositionAttrs = 365 ArrayAttr::get(positionAttrs.drop_back(), context); 366 extracted = rewriter.create<LLVM::ExtractValueOp>( 367 loc, lowering.convertType(oneDVectorType), extracted, 368 nMinusOnePositionAttrs); 369 } 370 371 // Remaining extraction of element from 1-D LLVM vector 372 auto position = positionAttrs.back().cast<IntegerAttr>(); 373 auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 374 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 375 extracted = 376 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 377 rewriter.replaceOp(op, extracted); 378 379 return matchSuccess(); 380 } 381 }; 382 383 class VectorInsertElementOpConversion : public LLVMOpLowering { 384 public: 385 explicit VectorInsertElementOpConversion(MLIRContext *context, 386 LLVMTypeConverter &typeConverter) 387 : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, 388 typeConverter) {} 389 390 PatternMatchResult 391 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 392 ConversionPatternRewriter &rewriter) const override { 393 auto adaptor = vector::InsertElementOpOperandAdaptor(operands); 394 auto insertEltOp = cast<vector::InsertElementOp>(op); 395 auto vectorType = insertEltOp.getDestVectorType(); 396 auto llvmType = lowering.convertType(vectorType); 397 398 // Bail if result type cannot be lowered. 399 if (!llvmType) 400 return matchFailure(); 401 402 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 403 op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); 404 return matchSuccess(); 405 } 406 }; 407 408 class VectorInsertOpConversion : public LLVMOpLowering { 409 public: 410 explicit VectorInsertOpConversion(MLIRContext *context, 411 LLVMTypeConverter &typeConverter) 412 : LLVMOpLowering(vector::InsertOp::getOperationName(), context, 413 typeConverter) {} 414 415 PatternMatchResult 416 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 417 ConversionPatternRewriter &rewriter) const override { 418 auto loc = op->getLoc(); 419 auto adaptor = vector::InsertOpOperandAdaptor(operands); 420 auto insertOp = cast<vector::InsertOp>(op); 421 auto sourceType = insertOp.getSourceType(); 422 auto destVectorType = insertOp.getDestVectorType(); 423 auto llvmResultType = lowering.convertType(destVectorType); 424 auto positionArrayAttr = insertOp.position(); 425 426 // Bail if result type cannot be lowered. 427 if (!llvmResultType) 428 return matchFailure(); 429 430 // One-shot insertion of a vector into an array (only requires insertvalue). 431 if (sourceType.isa<VectorType>()) { 432 Value *inserted = rewriter.create<LLVM::InsertValueOp>( 433 loc, llvmResultType, adaptor.dest(), adaptor.source(), 434 positionArrayAttr); 435 rewriter.replaceOp(op, inserted); 436 return matchSuccess(); 437 } 438 439 // Potential extraction of 1-D vector from array. 440 auto *context = op->getContext(); 441 Value *extracted = adaptor.dest(); 442 auto positionAttrs = positionArrayAttr.getValue(); 443 auto position = positionAttrs.back().cast<IntegerAttr>(); 444 auto oneDVectorType = destVectorType; 445 if (positionAttrs.size() > 1) { 446 oneDVectorType = reducedVectorTypeBack(destVectorType); 447 auto nMinusOnePositionAttrs = 448 ArrayAttr::get(positionAttrs.drop_back(), context); 449 extracted = rewriter.create<LLVM::ExtractValueOp>( 450 loc, lowering.convertType(oneDVectorType), extracted, 451 nMinusOnePositionAttrs); 452 } 453 454 // Insertion of an element into a 1-D LLVM vector. 455 auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 456 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 457 Value *inserted = rewriter.create<LLVM::InsertElementOp>( 458 loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), 459 constant); 460 461 // Potential insertion of resulting 1-D vector into array. 462 if (positionAttrs.size() > 1) { 463 auto nMinusOnePositionAttrs = 464 ArrayAttr::get(positionAttrs.drop_back(), context); 465 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 466 adaptor.dest(), inserted, 467 nMinusOnePositionAttrs); 468 } 469 470 rewriter.replaceOp(op, inserted); 471 return matchSuccess(); 472 } 473 }; 474 475 class VectorOuterProductOpConversion : public LLVMOpLowering { 476 public: 477 explicit VectorOuterProductOpConversion(MLIRContext *context, 478 LLVMTypeConverter &typeConverter) 479 : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, 480 typeConverter) {} 481 482 PatternMatchResult 483 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 484 ConversionPatternRewriter &rewriter) const override { 485 auto loc = op->getLoc(); 486 auto adaptor = vector::OuterProductOpOperandAdaptor(operands); 487 auto *ctx = op->getContext(); 488 auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); 489 auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); 490 auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); 491 auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); 492 auto llvmArrayOfVectType = lowering.convertType( 493 cast<vector::OuterProductOp>(op).getResult()->getType()); 494 Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); 495 Value *a = adaptor.lhs(), *b = adaptor.rhs(); 496 Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); 497 SmallVector<Value *, 8> lhs, accs; 498 lhs.reserve(rankLHS); 499 accs.reserve(rankLHS); 500 for (unsigned d = 0, e = rankLHS; d < e; ++d) { 501 // shufflevector explicitly requires i32. 502 auto attr = rewriter.getI32IntegerAttr(d); 503 SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); 504 auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); 505 Value *aD = nullptr, *accD = nullptr; 506 // 1. Broadcast the element a[d] into vector aD. 507 aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); 508 // 2. If acc is present, extract 1-d vector acc[d] into accD. 509 if (acc) 510 accD = rewriter.create<LLVM::ExtractValueOp>( 511 loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); 512 // 3. Compute aD outer b (plus accD, if relevant). 513 Value *aOuterbD = 514 accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) 515 .getResult() 516 : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); 517 // 4. Insert as value `d` in the descriptor. 518 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, 519 desc, aOuterbD, 520 rewriter.getI64ArrayAttr(d)); 521 } 522 rewriter.replaceOp(op, desc); 523 return matchSuccess(); 524 } 525 }; 526 527 class VectorTypeCastOpConversion : public LLVMOpLowering { 528 public: 529 explicit VectorTypeCastOpConversion(MLIRContext *context, 530 LLVMTypeConverter &typeConverter) 531 : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, 532 typeConverter) {} 533 534 PatternMatchResult 535 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 536 ConversionPatternRewriter &rewriter) const override { 537 auto loc = op->getLoc(); 538 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 539 MemRefType sourceMemRefType = 540 castOp.getOperand()->getType().cast<MemRefType>(); 541 MemRefType targetMemRefType = 542 castOp.getResult()->getType().cast<MemRefType>(); 543 544 // Only static shape casts supported atm. 545 if (!sourceMemRefType.hasStaticShape() || 546 !targetMemRefType.hasStaticShape()) 547 return matchFailure(); 548 549 auto llvmSourceDescriptorTy = 550 operands[0]->getType().dyn_cast<LLVM::LLVMType>(); 551 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 552 return matchFailure(); 553 MemRefDescriptor sourceMemRef(operands[0]); 554 555 auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) 556 .dyn_cast_or_null<LLVM::LLVMType>(); 557 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 558 return matchFailure(); 559 560 int64_t offset; 561 SmallVector<int64_t, 4> strides; 562 auto successStrides = 563 getStridesAndOffset(sourceMemRefType, strides, offset); 564 bool isContiguous = (strides.back() == 1); 565 if (isContiguous) { 566 auto sizes = sourceMemRefType.getShape(); 567 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 568 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 569 isContiguous = false; 570 break; 571 } 572 } 573 } 574 // Only contiguous source tensors supported atm. 575 if (failed(successStrides) || !isContiguous) 576 return matchFailure(); 577 578 auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 579 580 // Create descriptor. 581 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 582 Type llvmTargetElementTy = desc.getElementType(); 583 // Set allocated ptr. 584 Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); 585 allocated = 586 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 587 desc.setAllocatedPtr(rewriter, loc, allocated); 588 // Set aligned ptr. 589 Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); 590 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 591 desc.setAlignedPtr(rewriter, loc, ptr); 592 // Fill offset 0. 593 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 594 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 595 desc.setOffset(rewriter, loc, zero); 596 597 // Fill size and stride descriptors in memref. 598 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 599 int64_t index = indexedSize.index(); 600 auto sizeAttr = 601 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 602 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 603 desc.setSize(rewriter, loc, index, size); 604 auto strideAttr = 605 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 606 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 607 desc.setStride(rewriter, loc, index, stride); 608 } 609 610 rewriter.replaceOp(op, {desc}); 611 return matchSuccess(); 612 } 613 }; 614 615 class VectorPrintOpConversion : public LLVMOpLowering { 616 public: 617 explicit VectorPrintOpConversion(MLIRContext *context, 618 LLVMTypeConverter &typeConverter) 619 : LLVMOpLowering(vector::PrintOp::getOperationName(), context, 620 typeConverter) {} 621 622 // Proof-of-concept lowering implementation that relies on a small 623 // runtime support library, which only needs to provide a few 624 // printing methods (single value for all data types, opening/closing 625 // bracket, comma, newline). The lowering fully unrolls a vector 626 // in terms of these elementary printing operations. The advantage 627 // of this approach is that the library can remain unaware of all 628 // low-level implementation details of vectors while still supporting 629 // output of any shaped and dimensioned vector. Due to full unrolling, 630 // this approach is less suited for very large vectors though. 631 // 632 // TODO(ajcbik): rely solely on libc in future? something else? 633 // 634 PatternMatchResult 635 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 636 ConversionPatternRewriter &rewriter) const override { 637 auto printOp = cast<vector::PrintOp>(op); 638 auto adaptor = vector::PrintOpOperandAdaptor(operands); 639 Type printType = printOp.getPrintType(); 640 641 if (lowering.convertType(printType) == nullptr) 642 return matchFailure(); 643 644 // Make sure element type has runtime support (currently just Float/Double). 645 VectorType vectorType = printType.dyn_cast<VectorType>(); 646 Type eltType = vectorType ? vectorType.getElementType() : printType; 647 int64_t rank = vectorType ? vectorType.getRank() : 0; 648 Operation *printer; 649 if (eltType.isF32()) 650 printer = getPrintFloat(op); 651 else if (eltType.isF64()) 652 printer = getPrintDouble(op); 653 else 654 return matchFailure(); 655 656 // Unroll vector into elementary print calls. 657 emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); 658 emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 659 rewriter.eraseOp(op); 660 return matchSuccess(); 661 } 662 663 private: 664 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 665 Value *value, VectorType vectorType, Operation *printer, 666 int64_t rank) const { 667 Location loc = op->getLoc(); 668 if (rank == 0) { 669 emitCall(rewriter, loc, printer, value); 670 return; 671 } 672 673 emitCall(rewriter, loc, getPrintOpen(op)); 674 Operation *printComma = getPrintComma(op); 675 int64_t dim = vectorType.getDimSize(0); 676 for (int64_t d = 0; d < dim; ++d) { 677 auto reducedType = 678 rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 679 auto llvmType = lowering.convertType( 680 rank > 1 ? reducedType : vectorType.getElementType()); 681 Value *nestedVal = 682 extractOne(rewriter, lowering, loc, value, llvmType, rank, d); 683 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); 684 if (d != dim - 1) 685 emitCall(rewriter, loc, printComma); 686 } 687 emitCall(rewriter, loc, getPrintClose(op)); 688 } 689 690 // Helper to emit a call. 691 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 692 Operation *ref, ValueRange params = ValueRange()) { 693 rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, 694 rewriter.getSymbolRefAttr(ref), params); 695 } 696 697 // Helper for printer method declaration (first hit) and lookup. 698 static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, 699 StringRef name, ArrayRef<LLVM::LLVMType> params) { 700 auto module = op->getParentOfType<ModuleOp>(); 701 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 702 if (func) 703 return func; 704 OpBuilder moduleBuilder(module.getBodyRegion()); 705 return moduleBuilder.create<LLVM::LLVMFuncOp>( 706 op->getLoc(), name, 707 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), 708 params, /*isVarArg=*/false)); 709 } 710 711 // Helpers for method names. 712 Operation *getPrintFloat(Operation *op) const { 713 LLVM::LLVMDialect *dialect = lowering.getDialect(); 714 return getPrint(op, dialect, "print_f32", 715 LLVM::LLVMType::getFloatTy(dialect)); 716 } 717 Operation *getPrintDouble(Operation *op) const { 718 LLVM::LLVMDialect *dialect = lowering.getDialect(); 719 return getPrint(op, dialect, "print_f64", 720 LLVM::LLVMType::getDoubleTy(dialect)); 721 } 722 Operation *getPrintOpen(Operation *op) const { 723 return getPrint(op, lowering.getDialect(), "print_open", {}); 724 } 725 Operation *getPrintClose(Operation *op) const { 726 return getPrint(op, lowering.getDialect(), "print_close", {}); 727 } 728 Operation *getPrintComma(Operation *op) const { 729 return getPrint(op, lowering.getDialect(), "print_comma", {}); 730 } 731 Operation *getPrintNewline(Operation *op) const { 732 return getPrint(op, lowering.getDialect(), "print_newline", {}); 733 } 734 }; 735 736 /// Populate the given list with patterns that convert from Vector to LLVM. 737 void mlir::populateVectorToLLVMConversionPatterns( 738 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 739 patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, 740 VectorExtractElementOpConversion, VectorExtractOpConversion, 741 VectorInsertElementOpConversion, VectorInsertOpConversion, 742 VectorOuterProductOpConversion, VectorTypeCastOpConversion, 743 VectorPrintOpConversion>(converter.getDialect()->getContext(), 744 converter); 745 } 746 747 namespace { 748 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 749 void runOnModule() override; 750 }; 751 } // namespace 752 753 void LowerVectorToLLVMPass::runOnModule() { 754 // Convert to the LLVM IR dialect using the converter defined above. 755 OwningRewritePatternList patterns; 756 LLVMTypeConverter converter(&getContext()); 757 populateVectorToLLVMConversionPatterns(converter, patterns); 758 populateStdToLLVMConversionPatterns(converter, patterns); 759 760 ConversionTarget target(getContext()); 761 target.addLegalDialect<LLVM::LLVMDialect>(); 762 target.addDynamicallyLegalOp<FuncOp>( 763 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 764 if (failed( 765 applyPartialConversion(getModule(), target, patterns, &converter))) { 766 signalPassFailure(); 767 } 768 } 769 770 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { 771 return new LowerVectorToLLVMPass(); 772 } 773 774 static PassRegistration<LowerVectorToLLVMPass> 775 pass("convert-vector-to-llvm", 776 "Lower the operations from the vector dialect into the LLVM dialect"); 777