1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/VectorOps/VectorOps.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/MLIRContext.h" 19 #include "mlir/IR/Module.h" 20 #include "mlir/IR/Operation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/StandardTypes.h" 23 #include "mlir/IR/Types.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 #include "mlir/Transforms/Passes.h" 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Module.h" 30 #include "llvm/IR/Type.h" 31 #include "llvm/Support/Allocator.h" 32 #include "llvm/Support/ErrorHandling.h" 33 34 using namespace mlir; 35 using namespace mlir::vector; 36 37 template <typename T> 38 static LLVM::LLVMType getPtrToElementType(T containerType, 39 LLVMTypeConverter &typeConverter) { 40 return typeConverter.convertType(containerType.getElementType()) 41 .template cast<LLVM::LLVMType>() 42 .getPointerTo(); 43 } 44 45 // Helper to reduce vector type by one rank at front. 46 static VectorType reducedVectorTypeFront(VectorType tp) { 47 assert((tp.getRank() > 1) && "unlowerable vector type"); 48 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 49 } 50 51 // Helper to reduce vector type by *all* but one rank at back. 52 static VectorType reducedVectorTypeBack(VectorType tp) { 53 assert((tp.getRank() > 1) && "unlowerable vector type"); 54 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 55 } 56 57 // Helper that picks the proper sequence for inserting. 58 static Value insertOne(ConversionPatternRewriter &rewriter, 59 LLVMTypeConverter &typeConverter, Location loc, 60 Value val1, Value val2, Type llvmType, int64_t rank, 61 int64_t pos) { 62 if (rank == 1) { 63 auto idxType = rewriter.getIndexType(); 64 auto constant = rewriter.create<LLVM::ConstantOp>( 65 loc, typeConverter.convertType(idxType), 66 rewriter.getIntegerAttr(idxType, pos)); 67 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 68 constant); 69 } 70 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 71 rewriter.getI64ArrayAttr(pos)); 72 } 73 74 // Helper that picks the proper sequence for inserting. 75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 76 Value into, int64_t offset) { 77 auto vectorType = into.getType().cast<VectorType>(); 78 if (vectorType.getRank() > 1) 79 return rewriter.create<InsertOp>(loc, from, into, offset); 80 return rewriter.create<vector::InsertElementOp>( 81 loc, vectorType, from, into, 82 rewriter.create<ConstantIndexOp>(loc, offset)); 83 } 84 85 // Helper that picks the proper sequence for extracting. 86 static Value extractOne(ConversionPatternRewriter &rewriter, 87 LLVMTypeConverter &typeConverter, Location loc, 88 Value val, Type llvmType, int64_t rank, int64_t pos) { 89 if (rank == 1) { 90 auto idxType = rewriter.getIndexType(); 91 auto constant = rewriter.create<LLVM::ConstantOp>( 92 loc, typeConverter.convertType(idxType), 93 rewriter.getIntegerAttr(idxType, pos)); 94 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 95 constant); 96 } 97 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 98 rewriter.getI64ArrayAttr(pos)); 99 } 100 101 // Helper that picks the proper sequence for extracting. 102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 103 int64_t offset) { 104 auto vectorType = vector.getType().cast<VectorType>(); 105 if (vectorType.getRank() > 1) 106 return rewriter.create<ExtractOp>(loc, vector, offset); 107 return rewriter.create<vector::ExtractElementOp>( 108 loc, vectorType.getElementType(), vector, 109 rewriter.create<ConstantIndexOp>(loc, offset)); 110 } 111 112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing. 114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 115 unsigned dropFront = 0, 116 unsigned dropBack = 0) { 117 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 118 auto range = arrayAttr.getAsRange<IntegerAttr>(); 119 SmallVector<int64_t, 4> res; 120 res.reserve(arrayAttr.size() - dropFront - dropBack); 121 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 122 it != eit; ++it) 123 res.push_back((*it).getValue().getSExtValue()); 124 return res; 125 } 126 127 namespace { 128 129 class VectorBroadcastOpConversion : public ConvertToLLVMPattern { 130 public: 131 explicit VectorBroadcastOpConversion(MLIRContext *context, 132 LLVMTypeConverter &typeConverter) 133 : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context, 134 typeConverter) {} 135 136 PatternMatchResult 137 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 138 ConversionPatternRewriter &rewriter) const override { 139 auto broadcastOp = cast<vector::BroadcastOp>(op); 140 VectorType dstVectorType = broadcastOp.getVectorType(); 141 if (typeConverter.convertType(dstVectorType) == nullptr) 142 return matchFailure(); 143 // Rewrite when the full vector type can be lowered (which 144 // implies all 'reduced' types can be lowered too). 145 auto adaptor = vector::BroadcastOpOperandAdaptor(operands); 146 VectorType srcVectorType = 147 broadcastOp.getSourceType().dyn_cast<VectorType>(); 148 rewriter.replaceOp( 149 op, expandRanks(adaptor.source(), // source value to be expanded 150 op->getLoc(), // location of original broadcast 151 srcVectorType, dstVectorType, rewriter)); 152 return matchSuccess(); 153 } 154 155 private: 156 // Expands the given source value over all the ranks, as defined 157 // by the source and destination type (a null source type denotes 158 // expansion from a scalar value into a vector). 159 // 160 // TODO(ajcbik): consider replacing this one-pattern lowering 161 // with a two-pattern lowering using other vector 162 // ops once all insert/extract/shuffle operations 163 // are available with lowering implementation. 164 // 165 Value expandRanks(Value value, Location loc, VectorType srcVectorType, 166 VectorType dstVectorType, 167 ConversionPatternRewriter &rewriter) const { 168 assert((dstVectorType != nullptr) && "invalid result type in broadcast"); 169 // Determine rank of source and destination. 170 int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; 171 int64_t dstRank = dstVectorType.getRank(); 172 int64_t curDim = dstVectorType.getDimSize(0); 173 if (srcRank < dstRank) 174 // Duplicate this rank. 175 return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 176 curDim, rewriter); 177 // If all trailing dimensions are the same, the broadcast consists of 178 // simply passing through the source value and we are done. Otherwise, 179 // any non-matching dimension forces a stretch along this rank. 180 assert((srcVectorType != nullptr) && (srcRank > 0) && 181 (srcRank == dstRank) && "invalid rank in broadcast"); 182 for (int64_t r = 0; r < dstRank; r++) { 183 if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { 184 return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 185 curDim, rewriter); 186 } 187 } 188 return value; 189 } 190 191 // Picks the best way to duplicate a single rank. For the 1-D case, a 192 // single insert-elt/shuffle is the most efficient expansion. For higher 193 // dimensions, however, we need dim x insert-values on a new broadcast 194 // with one less leading dimension, which will be lowered "recursively" 195 // to matching LLVM IR. 196 // For example: 197 // v = broadcast s : f32 to vector<4x2xf32> 198 // becomes: 199 // x = broadcast s : f32 to vector<2xf32> 200 // v = [x,x,x,x] 201 // becomes: 202 // x = [s,s] 203 // v = [x,x,x,x] 204 Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, 205 VectorType dstVectorType, int64_t rank, int64_t dim, 206 ConversionPatternRewriter &rewriter) const { 207 Type llvmType = typeConverter.convertType(dstVectorType); 208 assert((llvmType != nullptr) && "unlowerable vector type"); 209 if (rank == 1) { 210 Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); 211 Value expand = insertOne(rewriter, typeConverter, loc, undef, value, 212 llvmType, rank, 0); 213 SmallVector<int32_t, 4> zeroValues(dim, 0); 214 return rewriter.create<LLVM::ShuffleVectorOp>( 215 loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); 216 } 217 Value expand = expandRanks(value, loc, srcVectorType, 218 reducedVectorTypeFront(dstVectorType), rewriter); 219 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 220 for (int64_t d = 0; d < dim; ++d) { 221 result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, 222 rank, d); 223 } 224 return result; 225 } 226 227 // Picks the best way to stretch a single rank. For the 1-D case, a 228 // single insert-elt/shuffle is the most efficient expansion when at 229 // a stretch. Otherwise, every dimension needs to be expanded 230 // individually and individually inserted in the resulting vector. 231 // For example: 232 // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> 233 // becomes: 234 // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> 235 // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> 236 // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> 237 // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> 238 // v = [a,b,c,d] 239 // becomes: 240 // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> 241 // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> 242 // a = [x, y] 243 // etc. 244 Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, 245 VectorType dstVectorType, int64_t rank, int64_t dim, 246 ConversionPatternRewriter &rewriter) const { 247 Type llvmType = typeConverter.convertType(dstVectorType); 248 assert((llvmType != nullptr) && "unlowerable vector type"); 249 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 250 bool atStretch = dim != srcVectorType.getDimSize(0); 251 if (rank == 1) { 252 assert(atStretch); 253 Type redLlvmType = 254 typeConverter.convertType(dstVectorType.getElementType()); 255 Value one = 256 extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0); 257 Value expand = insertOne(rewriter, typeConverter, loc, result, one, 258 llvmType, rank, 0); 259 SmallVector<int32_t, 4> zeroValues(dim, 0); 260 return rewriter.create<LLVM::ShuffleVectorOp>( 261 loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); 262 } 263 VectorType redSrcType = reducedVectorTypeFront(srcVectorType); 264 VectorType redDstType = reducedVectorTypeFront(dstVectorType); 265 Type redLlvmType = typeConverter.convertType(redSrcType); 266 for (int64_t d = 0; d < dim; ++d) { 267 int64_t pos = atStretch ? 0 : d; 268 Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType, 269 rank, pos); 270 Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); 271 result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, 272 rank, d); 273 } 274 return result; 275 } 276 }; 277 278 /// Conversion pattern for a vector.matrix_multiply. 279 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 280 class VectorMatmulOpConversion : public ConvertToLLVMPattern { 281 public: 282 explicit VectorMatmulOpConversion(MLIRContext *context, 283 LLVMTypeConverter &typeConverter) 284 : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, 285 typeConverter) {} 286 287 PatternMatchResult 288 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 289 ConversionPatternRewriter &rewriter) const override { 290 auto matmulOp = cast<vector::MatmulOp>(op); 291 auto adaptor = vector::MatmulOpOperandAdaptor(operands); 292 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 293 op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), 294 adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), 295 matmulOp.rhs_columns()); 296 return matchSuccess(); 297 } 298 }; 299 300 class VectorReductionOpConversion : public ConvertToLLVMPattern { 301 public: 302 explicit VectorReductionOpConversion(MLIRContext *context, 303 LLVMTypeConverter &typeConverter) 304 : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, 305 typeConverter) {} 306 307 PatternMatchResult 308 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 309 ConversionPatternRewriter &rewriter) const override { 310 auto reductionOp = cast<vector::ReductionOp>(op); 311 auto kind = reductionOp.kind(); 312 Type eltType = reductionOp.dest().getType(); 313 Type llvmType = typeConverter.convertType(eltType); 314 if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) { 315 // Integer reductions: add/mul/min/max/and/or/xor. 316 if (kind == "add") 317 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>( 318 op, llvmType, operands[0]); 319 else if (kind == "mul") 320 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>( 321 op, llvmType, operands[0]); 322 else if (kind == "min") 323 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>( 324 op, llvmType, operands[0]); 325 else if (kind == "max") 326 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>( 327 op, llvmType, operands[0]); 328 else if (kind == "and") 329 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>( 330 op, llvmType, operands[0]); 331 else if (kind == "or") 332 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>( 333 op, llvmType, operands[0]); 334 else if (kind == "xor") 335 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>( 336 op, llvmType, operands[0]); 337 else 338 return matchFailure(); 339 return matchSuccess(); 340 341 } else if (eltType.isF32() || eltType.isF64()) { 342 // Floating-point reductions: add/mul/min/max 343 if (kind == "add") { 344 // Optional accumulator (or zero). 345 Value acc = operands.size() > 1 ? operands[1] 346 : rewriter.create<LLVM::ConstantOp>( 347 op->getLoc(), llvmType, 348 rewriter.getZeroAttr(eltType)); 349 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>( 350 op, llvmType, acc, operands[0]); 351 } else if (kind == "mul") { 352 // Optional accumulator (or one). 353 Value acc = operands.size() > 1 354 ? operands[1] 355 : rewriter.create<LLVM::ConstantOp>( 356 op->getLoc(), llvmType, 357 rewriter.getFloatAttr(eltType, 1.0)); 358 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>( 359 op, llvmType, acc, operands[0]); 360 } else if (kind == "min") 361 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>( 362 op, llvmType, operands[0]); 363 else if (kind == "max") 364 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>( 365 op, llvmType, operands[0]); 366 else 367 return matchFailure(); 368 return matchSuccess(); 369 } 370 return matchFailure(); 371 } 372 }; 373 374 class VectorShuffleOpConversion : public ConvertToLLVMPattern { 375 public: 376 explicit VectorShuffleOpConversion(MLIRContext *context, 377 LLVMTypeConverter &typeConverter) 378 : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, 379 typeConverter) {} 380 381 PatternMatchResult 382 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 383 ConversionPatternRewriter &rewriter) const override { 384 auto loc = op->getLoc(); 385 auto adaptor = vector::ShuffleOpOperandAdaptor(operands); 386 auto shuffleOp = cast<vector::ShuffleOp>(op); 387 auto v1Type = shuffleOp.getV1VectorType(); 388 auto v2Type = shuffleOp.getV2VectorType(); 389 auto vectorType = shuffleOp.getVectorType(); 390 Type llvmType = typeConverter.convertType(vectorType); 391 auto maskArrayAttr = shuffleOp.mask(); 392 393 // Bail if result type cannot be lowered. 394 if (!llvmType) 395 return matchFailure(); 396 397 // Get rank and dimension sizes. 398 int64_t rank = vectorType.getRank(); 399 assert(v1Type.getRank() == rank); 400 assert(v2Type.getRank() == rank); 401 int64_t v1Dim = v1Type.getDimSize(0); 402 403 // For rank 1, where both operands have *exactly* the same vector type, 404 // there is direct shuffle support in LLVM. Use it! 405 if (rank == 1 && v1Type == v2Type) { 406 Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( 407 loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 408 rewriter.replaceOp(op, shuffle); 409 return matchSuccess(); 410 } 411 412 // For all other cases, insert the individual values individually. 413 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 414 int64_t insPos = 0; 415 for (auto en : llvm::enumerate(maskArrayAttr)) { 416 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 417 Value value = adaptor.v1(); 418 if (extPos >= v1Dim) { 419 extPos -= v1Dim; 420 value = adaptor.v2(); 421 } 422 Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, 423 rank, extPos); 424 insert = insertOne(rewriter, typeConverter, loc, insert, extract, 425 llvmType, rank, insPos++); 426 } 427 rewriter.replaceOp(op, insert); 428 return matchSuccess(); 429 } 430 }; 431 432 class VectorExtractElementOpConversion : public ConvertToLLVMPattern { 433 public: 434 explicit VectorExtractElementOpConversion(MLIRContext *context, 435 LLVMTypeConverter &typeConverter) 436 : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), 437 context, typeConverter) {} 438 439 PatternMatchResult 440 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 441 ConversionPatternRewriter &rewriter) const override { 442 auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); 443 auto extractEltOp = cast<vector::ExtractElementOp>(op); 444 auto vectorType = extractEltOp.getVectorType(); 445 auto llvmType = typeConverter.convertType(vectorType.getElementType()); 446 447 // Bail if result type cannot be lowered. 448 if (!llvmType) 449 return matchFailure(); 450 451 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 452 op, llvmType, adaptor.vector(), adaptor.position()); 453 return matchSuccess(); 454 } 455 }; 456 457 class VectorExtractOpConversion : public ConvertToLLVMPattern { 458 public: 459 explicit VectorExtractOpConversion(MLIRContext *context, 460 LLVMTypeConverter &typeConverter) 461 : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, 462 typeConverter) {} 463 464 PatternMatchResult 465 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 466 ConversionPatternRewriter &rewriter) const override { 467 auto loc = op->getLoc(); 468 auto adaptor = vector::ExtractOpOperandAdaptor(operands); 469 auto extractOp = cast<vector::ExtractOp>(op); 470 auto vectorType = extractOp.getVectorType(); 471 auto resultType = extractOp.getResult().getType(); 472 auto llvmResultType = typeConverter.convertType(resultType); 473 auto positionArrayAttr = extractOp.position(); 474 475 // Bail if result type cannot be lowered. 476 if (!llvmResultType) 477 return matchFailure(); 478 479 // One-shot extraction of vector from array (only requires extractvalue). 480 if (resultType.isa<VectorType>()) { 481 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 482 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 483 rewriter.replaceOp(op, extracted); 484 return matchSuccess(); 485 } 486 487 // Potential extraction of 1-D vector from array. 488 auto *context = op->getContext(); 489 Value extracted = adaptor.vector(); 490 auto positionAttrs = positionArrayAttr.getValue(); 491 if (positionAttrs.size() > 1) { 492 auto oneDVectorType = reducedVectorTypeBack(vectorType); 493 auto nMinusOnePositionAttrs = 494 ArrayAttr::get(positionAttrs.drop_back(), context); 495 extracted = rewriter.create<LLVM::ExtractValueOp>( 496 loc, typeConverter.convertType(oneDVectorType), extracted, 497 nMinusOnePositionAttrs); 498 } 499 500 // Remaining extraction of element from 1-D LLVM vector 501 auto position = positionAttrs.back().cast<IntegerAttr>(); 502 auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 503 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 504 extracted = 505 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 506 rewriter.replaceOp(op, extracted); 507 508 return matchSuccess(); 509 } 510 }; 511 512 /// Conversion pattern that turns a vector.fma on a 1-D vector 513 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 514 /// This does not match vectors of n >= 2 rank. 515 /// 516 /// Example: 517 /// ``` 518 /// vector.fma %a, %a, %a : vector<8xf32> 519 /// ``` 520 /// is converted to: 521 /// ``` 522 /// llvm.intr.fma %va, %va, %va: 523 /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) 524 /// -> !llvm<"<8 x float>"> 525 /// ``` 526 class VectorFMAOp1DConversion : public ConvertToLLVMPattern { 527 public: 528 explicit VectorFMAOp1DConversion(MLIRContext *context, 529 LLVMTypeConverter &typeConverter) 530 : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, 531 typeConverter) {} 532 533 PatternMatchResult 534 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 535 ConversionPatternRewriter &rewriter) const override { 536 auto adaptor = vector::FMAOpOperandAdaptor(operands); 537 vector::FMAOp fmaOp = cast<vector::FMAOp>(op); 538 VectorType vType = fmaOp.getVectorType(); 539 if (vType.getRank() != 1) 540 return matchFailure(); 541 rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(), 542 adaptor.acc()); 543 return matchSuccess(); 544 } 545 }; 546 547 class VectorInsertElementOpConversion : public ConvertToLLVMPattern { 548 public: 549 explicit VectorInsertElementOpConversion(MLIRContext *context, 550 LLVMTypeConverter &typeConverter) 551 : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), 552 context, typeConverter) {} 553 554 PatternMatchResult 555 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 556 ConversionPatternRewriter &rewriter) const override { 557 auto adaptor = vector::InsertElementOpOperandAdaptor(operands); 558 auto insertEltOp = cast<vector::InsertElementOp>(op); 559 auto vectorType = insertEltOp.getDestVectorType(); 560 auto llvmType = typeConverter.convertType(vectorType); 561 562 // Bail if result type cannot be lowered. 563 if (!llvmType) 564 return matchFailure(); 565 566 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 567 op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); 568 return matchSuccess(); 569 } 570 }; 571 572 class VectorInsertOpConversion : public ConvertToLLVMPattern { 573 public: 574 explicit VectorInsertOpConversion(MLIRContext *context, 575 LLVMTypeConverter &typeConverter) 576 : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, 577 typeConverter) {} 578 579 PatternMatchResult 580 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 581 ConversionPatternRewriter &rewriter) const override { 582 auto loc = op->getLoc(); 583 auto adaptor = vector::InsertOpOperandAdaptor(operands); 584 auto insertOp = cast<vector::InsertOp>(op); 585 auto sourceType = insertOp.getSourceType(); 586 auto destVectorType = insertOp.getDestVectorType(); 587 auto llvmResultType = typeConverter.convertType(destVectorType); 588 auto positionArrayAttr = insertOp.position(); 589 590 // Bail if result type cannot be lowered. 591 if (!llvmResultType) 592 return matchFailure(); 593 594 // One-shot insertion of a vector into an array (only requires insertvalue). 595 if (sourceType.isa<VectorType>()) { 596 Value inserted = rewriter.create<LLVM::InsertValueOp>( 597 loc, llvmResultType, adaptor.dest(), adaptor.source(), 598 positionArrayAttr); 599 rewriter.replaceOp(op, inserted); 600 return matchSuccess(); 601 } 602 603 // Potential extraction of 1-D vector from array. 604 auto *context = op->getContext(); 605 Value extracted = adaptor.dest(); 606 auto positionAttrs = positionArrayAttr.getValue(); 607 auto position = positionAttrs.back().cast<IntegerAttr>(); 608 auto oneDVectorType = destVectorType; 609 if (positionAttrs.size() > 1) { 610 oneDVectorType = reducedVectorTypeBack(destVectorType); 611 auto nMinusOnePositionAttrs = 612 ArrayAttr::get(positionAttrs.drop_back(), context); 613 extracted = rewriter.create<LLVM::ExtractValueOp>( 614 loc, typeConverter.convertType(oneDVectorType), extracted, 615 nMinusOnePositionAttrs); 616 } 617 618 // Insertion of an element into a 1-D LLVM vector. 619 auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 620 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 621 Value inserted = rewriter.create<LLVM::InsertElementOp>( 622 loc, typeConverter.convertType(oneDVectorType), extracted, 623 adaptor.source(), constant); 624 625 // Potential insertion of resulting 1-D vector into array. 626 if (positionAttrs.size() > 1) { 627 auto nMinusOnePositionAttrs = 628 ArrayAttr::get(positionAttrs.drop_back(), context); 629 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 630 adaptor.dest(), inserted, 631 nMinusOnePositionAttrs); 632 } 633 634 rewriter.replaceOp(op, inserted); 635 return matchSuccess(); 636 } 637 }; 638 639 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 640 /// 641 /// Example: 642 /// ``` 643 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 644 /// ``` 645 /// is rewritten into: 646 /// ``` 647 /// %r = splat %f0: vector<2x4xf32> 648 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 649 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 650 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 651 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 652 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 653 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 654 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 655 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 656 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 657 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 658 /// // %r3 holds the final value. 659 /// ``` 660 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 661 public: 662 using OpRewritePattern<FMAOp>::OpRewritePattern; 663 664 PatternMatchResult matchAndRewrite(FMAOp op, 665 PatternRewriter &rewriter) const override { 666 auto vType = op.getVectorType(); 667 if (vType.getRank() < 2) 668 return matchFailure(); 669 670 auto loc = op.getLoc(); 671 auto elemType = vType.getElementType(); 672 Value zero = rewriter.create<ConstantOp>(loc, elemType, 673 rewriter.getZeroAttr(elemType)); 674 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 675 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 676 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 677 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 678 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 679 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 680 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 681 } 682 rewriter.replaceOp(op, desc); 683 return matchSuccess(); 684 } 685 }; 686 687 // When ranks are different, InsertStridedSlice needs to extract a properly 688 // ranked vector from the destination vector into which to insert. This pattern 689 // only takes care of this part and forwards the rest of the conversion to 690 // another pattern that converts InsertStridedSlice for operands of the same 691 // rank. 692 // 693 // RewritePattern for InsertStridedSliceOp where source and destination vectors 694 // have different ranks. In this case: 695 // 1. the proper subvector is extracted from the destination vector 696 // 2. a new InsertStridedSlice op is created to insert the source in the 697 // destination subvector 698 // 3. the destination subvector is inserted back in the proper place 699 // 4. the op is replaced by the result of step 3. 700 // The new InsertStridedSlice from step 2. will be picked up by a 701 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 702 class VectorInsertStridedSliceOpDifferentRankRewritePattern 703 : public OpRewritePattern<InsertStridedSliceOp> { 704 public: 705 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 706 707 PatternMatchResult matchAndRewrite(InsertStridedSliceOp op, 708 PatternRewriter &rewriter) const override { 709 auto srcType = op.getSourceVectorType(); 710 auto dstType = op.getDestVectorType(); 711 712 if (op.offsets().getValue().empty()) 713 return matchFailure(); 714 715 auto loc = op.getLoc(); 716 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 717 assert(rankDiff >= 0); 718 if (rankDiff == 0) 719 return matchFailure(); 720 721 int64_t rankRest = dstType.getRank() - rankDiff; 722 // Extract / insert the subvector of matching rank and InsertStridedSlice 723 // on it. 724 Value extracted = 725 rewriter.create<ExtractOp>(loc, op.dest(), 726 getI64SubArray(op.offsets(), /*dropFront=*/0, 727 /*dropFront=*/rankRest)); 728 // A different pattern will kick in for InsertStridedSlice with matching 729 // ranks. 730 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 731 loc, op.source(), extracted, 732 getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 733 getI64SubArray(op.strides(), /*dropFront=*/0)); 734 rewriter.replaceOpWithNewOp<InsertOp>( 735 op, stridedSliceInnerOp.getResult(), op.dest(), 736 getI64SubArray(op.offsets(), /*dropFront=*/0, 737 /*dropFront=*/rankRest)); 738 return matchSuccess(); 739 } 740 }; 741 742 // RewritePattern for InsertStridedSliceOp where source and destination vectors 743 // have the same rank. In this case, we reduce 744 // 1. the proper subvector is extracted from the destination vector 745 // 2. a new InsertStridedSlice op is created to insert the source in the 746 // destination subvector 747 // 3. the destination subvector is inserted back in the proper place 748 // 4. the op is replaced by the result of step 3. 749 // The new InsertStridedSlice from step 2. will be picked up by a 750 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 751 class VectorInsertStridedSliceOpSameRankRewritePattern 752 : public OpRewritePattern<InsertStridedSliceOp> { 753 public: 754 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 755 756 PatternMatchResult matchAndRewrite(InsertStridedSliceOp op, 757 PatternRewriter &rewriter) const override { 758 auto srcType = op.getSourceVectorType(); 759 auto dstType = op.getDestVectorType(); 760 761 if (op.offsets().getValue().empty()) 762 return matchFailure(); 763 764 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 765 assert(rankDiff >= 0); 766 if (rankDiff != 0) 767 return matchFailure(); 768 769 if (srcType == dstType) { 770 rewriter.replaceOp(op, op.source()); 771 return matchSuccess(); 772 } 773 774 int64_t offset = 775 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 776 int64_t size = srcType.getShape().front(); 777 int64_t stride = 778 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 779 780 auto loc = op.getLoc(); 781 Value res = op.dest(); 782 // For each slice of the source vector along the most major dimension. 783 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 784 off += stride, ++idx) { 785 // 1. extract the proper subvector (or element) from source 786 Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 787 if (extractedSource.getType().isa<VectorType>()) { 788 // 2. If we have a vector, extract the proper subvector from destination 789 // Otherwise we are at the element level and no need to recurse. 790 Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 791 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 792 // smaller rank. 793 InsertStridedSliceOp insertStridedSliceOp = 794 rewriter.create<InsertStridedSliceOp>( 795 loc, extractedSource, extractedDest, 796 getI64SubArray(op.offsets(), /* dropFront=*/1), 797 getI64SubArray(op.strides(), /* dropFront=*/1)); 798 // Call matchAndRewrite recursively from within the pattern. This 799 // circumvents the current limitation that a given pattern cannot 800 // be called multiple times by the PatternRewrite infrastructure (to 801 // avoid infinite recursion, but in this case, infinite recursion 802 // cannot happen because the rank is strictly decreasing). 803 // TODO(rriddle, nicolasvasilache) Implement something like a hook for 804 // a potential function that must decrease and allow the same pattern 805 // multiple times. 806 auto success = matchAndRewrite(insertStridedSliceOp, rewriter); 807 (void)success; 808 assert(success && "Unexpected failure"); 809 extractedSource = insertStridedSliceOp; 810 } 811 // 4. Insert the extractedSource into the res vector. 812 res = insertOne(rewriter, loc, extractedSource, res, off); 813 } 814 815 rewriter.replaceOp(op, res); 816 return matchSuccess(); 817 } 818 }; 819 820 class VectorOuterProductOpConversion : public ConvertToLLVMPattern { 821 public: 822 explicit VectorOuterProductOpConversion(MLIRContext *context, 823 LLVMTypeConverter &typeConverter) 824 : ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(), 825 context, typeConverter) {} 826 827 PatternMatchResult 828 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 829 ConversionPatternRewriter &rewriter) const override { 830 auto loc = op->getLoc(); 831 auto adaptor = vector::OuterProductOpOperandAdaptor(operands); 832 auto *ctx = op->getContext(); 833 auto vLHS = adaptor.lhs().getType().cast<LLVM::LLVMType>(); 834 auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>(); 835 auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); 836 auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); 837 auto llvmArrayOfVectType = typeConverter.convertType( 838 cast<vector::OuterProductOp>(op).getResult().getType()); 839 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); 840 Value a = adaptor.lhs(), b = adaptor.rhs(); 841 Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); 842 SmallVector<Value, 8> lhs, accs; 843 lhs.reserve(rankLHS); 844 accs.reserve(rankLHS); 845 for (unsigned d = 0, e = rankLHS; d < e; ++d) { 846 // shufflevector explicitly requires i32. 847 auto attr = rewriter.getI32IntegerAttr(d); 848 SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); 849 auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); 850 Value aD = nullptr, accD = nullptr; 851 // 1. Broadcast the element a[d] into vector aD. 852 aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); 853 // 2. If acc is present, extract 1-d vector acc[d] into accD. 854 if (acc) 855 accD = rewriter.create<LLVM::ExtractValueOp>( 856 loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); 857 // 3. Compute aD outer b (plus accD, if relevant). 858 Value aOuterbD = 859 accD 860 ? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult() 861 : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); 862 // 4. Insert as value `d` in the descriptor. 863 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, 864 desc, aOuterbD, 865 rewriter.getI64ArrayAttr(d)); 866 } 867 rewriter.replaceOp(op, desc); 868 return matchSuccess(); 869 } 870 }; 871 872 class VectorTypeCastOpConversion : public ConvertToLLVMPattern { 873 public: 874 explicit VectorTypeCastOpConversion(MLIRContext *context, 875 LLVMTypeConverter &typeConverter) 876 : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, 877 typeConverter) {} 878 879 PatternMatchResult 880 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 881 ConversionPatternRewriter &rewriter) const override { 882 auto loc = op->getLoc(); 883 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 884 MemRefType sourceMemRefType = 885 castOp.getOperand().getType().cast<MemRefType>(); 886 MemRefType targetMemRefType = 887 castOp.getResult().getType().cast<MemRefType>(); 888 889 // Only static shape casts supported atm. 890 if (!sourceMemRefType.hasStaticShape() || 891 !targetMemRefType.hasStaticShape()) 892 return matchFailure(); 893 894 auto llvmSourceDescriptorTy = 895 operands[0].getType().dyn_cast<LLVM::LLVMType>(); 896 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 897 return matchFailure(); 898 MemRefDescriptor sourceMemRef(operands[0]); 899 900 auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) 901 .dyn_cast_or_null<LLVM::LLVMType>(); 902 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 903 return matchFailure(); 904 905 int64_t offset; 906 SmallVector<int64_t, 4> strides; 907 auto successStrides = 908 getStridesAndOffset(sourceMemRefType, strides, offset); 909 bool isContiguous = (strides.back() == 1); 910 if (isContiguous) { 911 auto sizes = sourceMemRefType.getShape(); 912 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 913 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 914 isContiguous = false; 915 break; 916 } 917 } 918 } 919 // Only contiguous source tensors supported atm. 920 if (failed(successStrides) || !isContiguous) 921 return matchFailure(); 922 923 auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 924 925 // Create descriptor. 926 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 927 Type llvmTargetElementTy = desc.getElementType(); 928 // Set allocated ptr. 929 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 930 allocated = 931 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 932 desc.setAllocatedPtr(rewriter, loc, allocated); 933 // Set aligned ptr. 934 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 935 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 936 desc.setAlignedPtr(rewriter, loc, ptr); 937 // Fill offset 0. 938 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 939 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 940 desc.setOffset(rewriter, loc, zero); 941 942 // Fill size and stride descriptors in memref. 943 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 944 int64_t index = indexedSize.index(); 945 auto sizeAttr = 946 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 947 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 948 desc.setSize(rewriter, loc, index, size); 949 auto strideAttr = 950 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 951 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 952 desc.setStride(rewriter, loc, index, stride); 953 } 954 955 rewriter.replaceOp(op, {desc}); 956 return matchSuccess(); 957 } 958 }; 959 960 class VectorPrintOpConversion : public ConvertToLLVMPattern { 961 public: 962 explicit VectorPrintOpConversion(MLIRContext *context, 963 LLVMTypeConverter &typeConverter) 964 : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, 965 typeConverter) {} 966 967 // Proof-of-concept lowering implementation that relies on a small 968 // runtime support library, which only needs to provide a few 969 // printing methods (single value for all data types, opening/closing 970 // bracket, comma, newline). The lowering fully unrolls a vector 971 // in terms of these elementary printing operations. The advantage 972 // of this approach is that the library can remain unaware of all 973 // low-level implementation details of vectors while still supporting 974 // output of any shaped and dimensioned vector. Due to full unrolling, 975 // this approach is less suited for very large vectors though. 976 // 977 // TODO(ajcbik): rely solely on libc in future? something else? 978 // 979 PatternMatchResult 980 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 981 ConversionPatternRewriter &rewriter) const override { 982 auto printOp = cast<vector::PrintOp>(op); 983 auto adaptor = vector::PrintOpOperandAdaptor(operands); 984 Type printType = printOp.getPrintType(); 985 986 if (typeConverter.convertType(printType) == nullptr) 987 return matchFailure(); 988 989 // Make sure element type has runtime support (currently just Float/Double). 990 VectorType vectorType = printType.dyn_cast<VectorType>(); 991 Type eltType = vectorType ? vectorType.getElementType() : printType; 992 int64_t rank = vectorType ? vectorType.getRank() : 0; 993 Operation *printer; 994 if (eltType.isSignlessInteger(32)) 995 printer = getPrintI32(op); 996 else if (eltType.isSignlessInteger(64)) 997 printer = getPrintI64(op); 998 else if (eltType.isF32()) 999 printer = getPrintFloat(op); 1000 else if (eltType.isF64()) 1001 printer = getPrintDouble(op); 1002 else 1003 return matchFailure(); 1004 1005 // Unroll vector into elementary print calls. 1006 emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); 1007 emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 1008 rewriter.eraseOp(op); 1009 return matchSuccess(); 1010 } 1011 1012 private: 1013 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1014 Value value, VectorType vectorType, Operation *printer, 1015 int64_t rank) const { 1016 Location loc = op->getLoc(); 1017 if (rank == 0) { 1018 emitCall(rewriter, loc, printer, value); 1019 return; 1020 } 1021 1022 emitCall(rewriter, loc, getPrintOpen(op)); 1023 Operation *printComma = getPrintComma(op); 1024 int64_t dim = vectorType.getDimSize(0); 1025 for (int64_t d = 0; d < dim; ++d) { 1026 auto reducedType = 1027 rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 1028 auto llvmType = typeConverter.convertType( 1029 rank > 1 ? reducedType : vectorType.getElementType()); 1030 Value nestedVal = 1031 extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); 1032 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); 1033 if (d != dim - 1) 1034 emitCall(rewriter, loc, printComma); 1035 } 1036 emitCall(rewriter, loc, getPrintClose(op)); 1037 } 1038 1039 // Helper to emit a call. 1040 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1041 Operation *ref, ValueRange params = ValueRange()) { 1042 rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, 1043 rewriter.getSymbolRefAttr(ref), params); 1044 } 1045 1046 // Helper for printer method declaration (first hit) and lookup. 1047 static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, 1048 StringRef name, ArrayRef<LLVM::LLVMType> params) { 1049 auto module = op->getParentOfType<ModuleOp>(); 1050 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1051 if (func) 1052 return func; 1053 OpBuilder moduleBuilder(module.getBodyRegion()); 1054 return moduleBuilder.create<LLVM::LLVMFuncOp>( 1055 op->getLoc(), name, 1056 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), 1057 params, /*isVarArg=*/false)); 1058 } 1059 1060 // Helpers for method names. 1061 Operation *getPrintI32(Operation *op) const { 1062 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1063 return getPrint(op, dialect, "print_i32", 1064 LLVM::LLVMType::getInt32Ty(dialect)); 1065 } 1066 Operation *getPrintI64(Operation *op) const { 1067 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1068 return getPrint(op, dialect, "print_i64", 1069 LLVM::LLVMType::getInt64Ty(dialect)); 1070 } 1071 Operation *getPrintFloat(Operation *op) const { 1072 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1073 return getPrint(op, dialect, "print_f32", 1074 LLVM::LLVMType::getFloatTy(dialect)); 1075 } 1076 Operation *getPrintDouble(Operation *op) const { 1077 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1078 return getPrint(op, dialect, "print_f64", 1079 LLVM::LLVMType::getDoubleTy(dialect)); 1080 } 1081 Operation *getPrintOpen(Operation *op) const { 1082 return getPrint(op, typeConverter.getDialect(), "print_open", {}); 1083 } 1084 Operation *getPrintClose(Operation *op) const { 1085 return getPrint(op, typeConverter.getDialect(), "print_close", {}); 1086 } 1087 Operation *getPrintComma(Operation *op) const { 1088 return getPrint(op, typeConverter.getDialect(), "print_comma", {}); 1089 } 1090 Operation *getPrintNewline(Operation *op) const { 1091 return getPrint(op, typeConverter.getDialect(), "print_newline", {}); 1092 } 1093 }; 1094 1095 /// Progressive lowering of StridedSliceOp to either: 1096 /// 1. extractelement + insertelement for the 1-D case 1097 /// 2. extract + optional strided_slice + insert for the n-D case. 1098 class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> { 1099 public: 1100 using OpRewritePattern<StridedSliceOp>::OpRewritePattern; 1101 1102 PatternMatchResult matchAndRewrite(StridedSliceOp op, 1103 PatternRewriter &rewriter) const override { 1104 auto dstType = op.getResult().getType().cast<VectorType>(); 1105 1106 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 1107 1108 int64_t offset = 1109 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 1110 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 1111 int64_t stride = 1112 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 1113 1114 auto loc = op.getLoc(); 1115 auto elemType = dstType.getElementType(); 1116 assert(elemType.isSignlessIntOrIndexOrFloat()); 1117 Value zero = rewriter.create<ConstantOp>(loc, elemType, 1118 rewriter.getZeroAttr(elemType)); 1119 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 1120 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 1121 off += stride, ++idx) { 1122 Value extracted = extractOne(rewriter, loc, op.vector(), off); 1123 if (op.offsets().getValue().size() > 1) { 1124 StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>( 1125 loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), 1126 getI64SubArray(op.sizes(), /* dropFront=*/1), 1127 getI64SubArray(op.strides(), /* dropFront=*/1)); 1128 // Call matchAndRewrite recursively from within the pattern. This 1129 // circumvents the current limitation that a given pattern cannot 1130 // be called multiple times by the PatternRewrite infrastructure (to 1131 // avoid infinite recursion, but in this case, infinite recursion 1132 // cannot happen because the rank is strictly decreasing). 1133 // TODO(rriddle, nicolasvasilache) Implement something like a hook for 1134 // a potential function that must decrease and allow the same pattern 1135 // multiple times. 1136 auto success = matchAndRewrite(stridedSliceOp, rewriter); 1137 (void)success; 1138 assert(success && "Unexpected failure"); 1139 extracted = stridedSliceOp; 1140 } 1141 res = insertOne(rewriter, loc, extracted, res, idx); 1142 } 1143 rewriter.replaceOp(op, {res}); 1144 return matchSuccess(); 1145 } 1146 }; 1147 1148 } // namespace 1149 1150 /// Populate the given list with patterns that convert from Vector to LLVM. 1151 void mlir::populateVectorToLLVMConversionPatterns( 1152 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1153 MLIRContext *ctx = converter.getDialect()->getContext(); 1154 patterns.insert<VectorFMAOpNDRewritePattern, 1155 VectorInsertStridedSliceOpDifferentRankRewritePattern, 1156 VectorInsertStridedSliceOpSameRankRewritePattern, 1157 VectorStridedSliceOpConversion>(ctx); 1158 patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion, 1159 VectorShuffleOpConversion, VectorExtractElementOpConversion, 1160 VectorExtractOpConversion, VectorFMAOp1DConversion, 1161 VectorInsertElementOpConversion, VectorInsertOpConversion, 1162 VectorOuterProductOpConversion, VectorTypeCastOpConversion, 1163 VectorPrintOpConversion>(ctx, converter); 1164 } 1165 1166 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1167 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1168 MLIRContext *ctx = converter.getDialect()->getContext(); 1169 patterns.insert<VectorMatmulOpConversion>(ctx, converter); 1170 } 1171 1172 namespace { 1173 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 1174 void runOnModule() override; 1175 }; 1176 } // namespace 1177 1178 void LowerVectorToLLVMPass::runOnModule() { 1179 // Perform progressive lowering of operations on "slices" and 1180 // all contraction operations. Also applies folding and DCE. 1181 { 1182 OwningRewritePatternList patterns; 1183 populateVectorSlicesLoweringPatterns(patterns, &getContext()); 1184 populateVectorContractLoweringPatterns(patterns, &getContext()); 1185 applyPatternsGreedily(getModule(), patterns); 1186 } 1187 1188 // Convert to the LLVM IR dialect. 1189 LLVMTypeConverter converter(&getContext()); 1190 OwningRewritePatternList patterns; 1191 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1192 populateVectorToLLVMConversionPatterns(converter, patterns); 1193 populateStdToLLVMConversionPatterns(converter, patterns); 1194 1195 LLVMConversionTarget target(getContext()); 1196 target.addDynamicallyLegalOp<FuncOp>( 1197 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1198 if (failed( 1199 applyPartialConversion(getModule(), target, patterns, &converter))) { 1200 signalPassFailure(); 1201 } 1202 } 1203 1204 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { 1205 return new LowerVectorToLLVMPass(); 1206 } 1207 1208 static PassRegistration<LowerVectorToLLVMPass> 1209 pass("convert-vector-to-llvm", 1210 "Lower the operations from the vector dialect into the LLVM dialect"); 1211