1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/VectorOps/VectorOps.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/MLIRContext.h" 19 #include "mlir/IR/Module.h" 20 #include "mlir/IR/Operation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/StandardTypes.h" 23 #include "mlir/IR/Types.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 #include "mlir/Transforms/Passes.h" 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Module.h" 30 #include "llvm/IR/Type.h" 31 #include "llvm/Support/Allocator.h" 32 #include "llvm/Support/ErrorHandling.h" 33 34 using namespace mlir; 35 using namespace mlir::vector; 36 37 template <typename T> 38 static LLVM::LLVMType getPtrToElementType(T containerType, 39 LLVMTypeConverter &typeConverter) { 40 return typeConverter.convertType(containerType.getElementType()) 41 .template cast<LLVM::LLVMType>() 42 .getPointerTo(); 43 } 44 45 // Helper to reduce vector type by one rank at front. 46 static VectorType reducedVectorTypeFront(VectorType tp) { 47 assert((tp.getRank() > 1) && "unlowerable vector type"); 48 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 49 } 50 51 // Helper to reduce vector type by *all* but one rank at back. 52 static VectorType reducedVectorTypeBack(VectorType tp) { 53 assert((tp.getRank() > 1) && "unlowerable vector type"); 54 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 55 } 56 57 // Helper that picks the proper sequence for inserting. 58 static Value insertOne(ConversionPatternRewriter &rewriter, 59 LLVMTypeConverter &typeConverter, Location loc, 60 Value val1, Value val2, Type llvmType, int64_t rank, 61 int64_t pos) { 62 if (rank == 1) { 63 auto idxType = rewriter.getIndexType(); 64 auto constant = rewriter.create<LLVM::ConstantOp>( 65 loc, typeConverter.convertType(idxType), 66 rewriter.getIntegerAttr(idxType, pos)); 67 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 68 constant); 69 } 70 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 71 rewriter.getI64ArrayAttr(pos)); 72 } 73 74 // Helper that picks the proper sequence for inserting. 75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 76 Value into, int64_t offset) { 77 auto vectorType = into.getType().cast<VectorType>(); 78 if (vectorType.getRank() > 1) 79 return rewriter.create<InsertOp>(loc, from, into, offset); 80 return rewriter.create<vector::InsertElementOp>( 81 loc, vectorType, from, into, 82 rewriter.create<ConstantIndexOp>(loc, offset)); 83 } 84 85 // Helper that picks the proper sequence for extracting. 86 static Value extractOne(ConversionPatternRewriter &rewriter, 87 LLVMTypeConverter &typeConverter, Location loc, 88 Value val, Type llvmType, int64_t rank, int64_t pos) { 89 if (rank == 1) { 90 auto idxType = rewriter.getIndexType(); 91 auto constant = rewriter.create<LLVM::ConstantOp>( 92 loc, typeConverter.convertType(idxType), 93 rewriter.getIntegerAttr(idxType, pos)); 94 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 95 constant); 96 } 97 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 98 rewriter.getI64ArrayAttr(pos)); 99 } 100 101 // Helper that picks the proper sequence for extracting. 102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 103 int64_t offset) { 104 auto vectorType = vector.getType().cast<VectorType>(); 105 if (vectorType.getRank() > 1) 106 return rewriter.create<ExtractOp>(loc, vector, offset); 107 return rewriter.create<vector::ExtractElementOp>( 108 loc, vectorType.getElementType(), vector, 109 rewriter.create<ConstantIndexOp>(loc, offset)); 110 } 111 112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing. 114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 115 unsigned dropFront = 0, 116 unsigned dropBack = 0) { 117 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 118 auto range = arrayAttr.getAsRange<IntegerAttr>(); 119 SmallVector<int64_t, 4> res; 120 res.reserve(arrayAttr.size() - dropFront - dropBack); 121 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 122 it != eit; ++it) 123 res.push_back((*it).getValue().getSExtValue()); 124 return res; 125 } 126 127 namespace { 128 129 class VectorBroadcastOpConversion : public ConvertToLLVMPattern { 130 public: 131 explicit VectorBroadcastOpConversion(MLIRContext *context, 132 LLVMTypeConverter &typeConverter) 133 : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context, 134 typeConverter) {} 135 136 PatternMatchResult 137 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 138 ConversionPatternRewriter &rewriter) const override { 139 auto broadcastOp = cast<vector::BroadcastOp>(op); 140 VectorType dstVectorType = broadcastOp.getVectorType(); 141 if (typeConverter.convertType(dstVectorType) == nullptr) 142 return matchFailure(); 143 // Rewrite when the full vector type can be lowered (which 144 // implies all 'reduced' types can be lowered too). 145 auto adaptor = vector::BroadcastOpOperandAdaptor(operands); 146 VectorType srcVectorType = 147 broadcastOp.getSourceType().dyn_cast<VectorType>(); 148 rewriter.replaceOp( 149 op, expandRanks(adaptor.source(), // source value to be expanded 150 op->getLoc(), // location of original broadcast 151 srcVectorType, dstVectorType, rewriter)); 152 return matchSuccess(); 153 } 154 155 private: 156 // Expands the given source value over all the ranks, as defined 157 // by the source and destination type (a null source type denotes 158 // expansion from a scalar value into a vector). 159 // 160 // TODO(ajcbik): consider replacing this one-pattern lowering 161 // with a two-pattern lowering using other vector 162 // ops once all insert/extract/shuffle operations 163 // are available with lowering implementation. 164 // 165 Value expandRanks(Value value, Location loc, VectorType srcVectorType, 166 VectorType dstVectorType, 167 ConversionPatternRewriter &rewriter) const { 168 assert((dstVectorType != nullptr) && "invalid result type in broadcast"); 169 // Determine rank of source and destination. 170 int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; 171 int64_t dstRank = dstVectorType.getRank(); 172 int64_t curDim = dstVectorType.getDimSize(0); 173 if (srcRank < dstRank) 174 // Duplicate this rank. 175 return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 176 curDim, rewriter); 177 // If all trailing dimensions are the same, the broadcast consists of 178 // simply passing through the source value and we are done. Otherwise, 179 // any non-matching dimension forces a stretch along this rank. 180 assert((srcVectorType != nullptr) && (srcRank > 0) && 181 (srcRank == dstRank) && "invalid rank in broadcast"); 182 for (int64_t r = 0; r < dstRank; r++) { 183 if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { 184 return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 185 curDim, rewriter); 186 } 187 } 188 return value; 189 } 190 191 // Picks the best way to duplicate a single rank. For the 1-D case, a 192 // single insert-elt/shuffle is the most efficient expansion. For higher 193 // dimensions, however, we need dim x insert-values on a new broadcast 194 // with one less leading dimension, which will be lowered "recursively" 195 // to matching LLVM IR. 196 // For example: 197 // v = broadcast s : f32 to vector<4x2xf32> 198 // becomes: 199 // x = broadcast s : f32 to vector<2xf32> 200 // v = [x,x,x,x] 201 // becomes: 202 // x = [s,s] 203 // v = [x,x,x,x] 204 Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, 205 VectorType dstVectorType, int64_t rank, int64_t dim, 206 ConversionPatternRewriter &rewriter) const { 207 Type llvmType = typeConverter.convertType(dstVectorType); 208 assert((llvmType != nullptr) && "unlowerable vector type"); 209 if (rank == 1) { 210 Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); 211 Value expand = insertOne(rewriter, typeConverter, loc, undef, value, 212 llvmType, rank, 0); 213 SmallVector<int32_t, 4> zeroValues(dim, 0); 214 return rewriter.create<LLVM::ShuffleVectorOp>( 215 loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); 216 } 217 Value expand = expandRanks(value, loc, srcVectorType, 218 reducedVectorTypeFront(dstVectorType), rewriter); 219 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 220 for (int64_t d = 0; d < dim; ++d) { 221 result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, 222 rank, d); 223 } 224 return result; 225 } 226 227 // Picks the best way to stretch a single rank. For the 1-D case, a 228 // single insert-elt/shuffle is the most efficient expansion when at 229 // a stretch. Otherwise, every dimension needs to be expanded 230 // individually and individually inserted in the resulting vector. 231 // For example: 232 // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> 233 // becomes: 234 // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> 235 // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> 236 // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> 237 // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> 238 // v = [a,b,c,d] 239 // becomes: 240 // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> 241 // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> 242 // a = [x, y] 243 // etc. 244 Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, 245 VectorType dstVectorType, int64_t rank, int64_t dim, 246 ConversionPatternRewriter &rewriter) const { 247 Type llvmType = typeConverter.convertType(dstVectorType); 248 assert((llvmType != nullptr) && "unlowerable vector type"); 249 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 250 bool atStretch = dim != srcVectorType.getDimSize(0); 251 if (rank == 1) { 252 assert(atStretch); 253 Type redLlvmType = 254 typeConverter.convertType(dstVectorType.getElementType()); 255 Value one = 256 extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0); 257 Value expand = insertOne(rewriter, typeConverter, loc, result, one, 258 llvmType, rank, 0); 259 SmallVector<int32_t, 4> zeroValues(dim, 0); 260 return rewriter.create<LLVM::ShuffleVectorOp>( 261 loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); 262 } 263 VectorType redSrcType = reducedVectorTypeFront(srcVectorType); 264 VectorType redDstType = reducedVectorTypeFront(dstVectorType); 265 Type redLlvmType = typeConverter.convertType(redSrcType); 266 for (int64_t d = 0; d < dim; ++d) { 267 int64_t pos = atStretch ? 0 : d; 268 Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType, 269 rank, pos); 270 Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); 271 result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, 272 rank, d); 273 } 274 return result; 275 } 276 }; 277 278 /// Conversion pattern for a vector.matrix_multiply. 279 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 280 class VectorMatmulOpConversion : public ConvertToLLVMPattern { 281 public: 282 explicit VectorMatmulOpConversion(MLIRContext *context, 283 LLVMTypeConverter &typeConverter) 284 : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, 285 typeConverter) {} 286 287 PatternMatchResult 288 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 289 ConversionPatternRewriter &rewriter) const override { 290 auto matmulOp = cast<vector::MatmulOp>(op); 291 auto adaptor = vector::MatmulOpOperandAdaptor(operands); 292 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 293 op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), 294 adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), 295 matmulOp.rhs_columns()); 296 return matchSuccess(); 297 } 298 }; 299 300 class VectorReductionOpConversion : public ConvertToLLVMPattern { 301 public: 302 explicit VectorReductionOpConversion(MLIRContext *context, 303 LLVMTypeConverter &typeConverter) 304 : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, 305 typeConverter) {} 306 307 PatternMatchResult 308 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 309 ConversionPatternRewriter &rewriter) const override { 310 auto reductionOp = cast<vector::ReductionOp>(op); 311 auto kind = reductionOp.kind(); 312 Type eltType = reductionOp.dest().getType(); 313 Type llvmType = typeConverter.convertType(eltType); 314 if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) { 315 // Integer reductions: add/mul/min/max/and/or/xor. 316 if (kind == "add") 317 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>( 318 op, llvmType, operands[0]); 319 else if (kind == "mul") 320 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>( 321 op, llvmType, operands[0]); 322 else if (kind == "min") 323 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>( 324 op, llvmType, operands[0]); 325 else if (kind == "max") 326 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>( 327 op, llvmType, operands[0]); 328 else if (kind == "and") 329 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>( 330 op, llvmType, operands[0]); 331 else if (kind == "or") 332 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>( 333 op, llvmType, operands[0]); 334 else if (kind == "xor") 335 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>( 336 op, llvmType, operands[0]); 337 else 338 return matchFailure(); 339 return matchSuccess(); 340 341 } else if (eltType.isF32() || eltType.isF64()) { 342 // Floating-point reductions: add/mul/min/max 343 if (kind == "add") { 344 // Optional accumulator (or zero). 345 Value acc = operands.size() > 1 ? operands[1] 346 : rewriter.create<LLVM::ConstantOp>( 347 op->getLoc(), llvmType, 348 rewriter.getZeroAttr(eltType)); 349 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>( 350 op, llvmType, acc, operands[0]); 351 } else if (kind == "mul") { 352 // Optional accumulator (or one). 353 Value acc = operands.size() > 1 354 ? operands[1] 355 : rewriter.create<LLVM::ConstantOp>( 356 op->getLoc(), llvmType, 357 rewriter.getFloatAttr(eltType, 1.0)); 358 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>( 359 op, llvmType, acc, operands[0]); 360 } else if (kind == "min") 361 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>( 362 op, llvmType, operands[0]); 363 else if (kind == "max") 364 rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>( 365 op, llvmType, operands[0]); 366 else 367 return matchFailure(); 368 return matchSuccess(); 369 } 370 return matchFailure(); 371 } 372 }; 373 374 class VectorShuffleOpConversion : public ConvertToLLVMPattern { 375 public: 376 explicit VectorShuffleOpConversion(MLIRContext *context, 377 LLVMTypeConverter &typeConverter) 378 : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, 379 typeConverter) {} 380 381 PatternMatchResult 382 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 383 ConversionPatternRewriter &rewriter) const override { 384 auto loc = op->getLoc(); 385 auto adaptor = vector::ShuffleOpOperandAdaptor(operands); 386 auto shuffleOp = cast<vector::ShuffleOp>(op); 387 auto v1Type = shuffleOp.getV1VectorType(); 388 auto v2Type = shuffleOp.getV2VectorType(); 389 auto vectorType = shuffleOp.getVectorType(); 390 Type llvmType = typeConverter.convertType(vectorType); 391 auto maskArrayAttr = shuffleOp.mask(); 392 393 // Bail if result type cannot be lowered. 394 if (!llvmType) 395 return matchFailure(); 396 397 // Get rank and dimension sizes. 398 int64_t rank = vectorType.getRank(); 399 assert(v1Type.getRank() == rank); 400 assert(v2Type.getRank() == rank); 401 int64_t v1Dim = v1Type.getDimSize(0); 402 403 // For rank 1, where both operands have *exactly* the same vector type, 404 // there is direct shuffle support in LLVM. Use it! 405 if (rank == 1 && v1Type == v2Type) { 406 Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( 407 loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 408 rewriter.replaceOp(op, shuffle); 409 return matchSuccess(); 410 } 411 412 // For all other cases, insert the individual values individually. 413 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 414 int64_t insPos = 0; 415 for (auto en : llvm::enumerate(maskArrayAttr)) { 416 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 417 Value value = adaptor.v1(); 418 if (extPos >= v1Dim) { 419 extPos -= v1Dim; 420 value = adaptor.v2(); 421 } 422 Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, 423 rank, extPos); 424 insert = insertOne(rewriter, typeConverter, loc, insert, extract, 425 llvmType, rank, insPos++); 426 } 427 rewriter.replaceOp(op, insert); 428 return matchSuccess(); 429 } 430 }; 431 432 class VectorExtractElementOpConversion : public ConvertToLLVMPattern { 433 public: 434 explicit VectorExtractElementOpConversion(MLIRContext *context, 435 LLVMTypeConverter &typeConverter) 436 : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), 437 context, typeConverter) {} 438 439 PatternMatchResult 440 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 441 ConversionPatternRewriter &rewriter) const override { 442 auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); 443 auto extractEltOp = cast<vector::ExtractElementOp>(op); 444 auto vectorType = extractEltOp.getVectorType(); 445 auto llvmType = typeConverter.convertType(vectorType.getElementType()); 446 447 // Bail if result type cannot be lowered. 448 if (!llvmType) 449 return matchFailure(); 450 451 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 452 op, llvmType, adaptor.vector(), adaptor.position()); 453 return matchSuccess(); 454 } 455 }; 456 457 class VectorExtractOpConversion : public ConvertToLLVMPattern { 458 public: 459 explicit VectorExtractOpConversion(MLIRContext *context, 460 LLVMTypeConverter &typeConverter) 461 : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, 462 typeConverter) {} 463 464 PatternMatchResult 465 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 466 ConversionPatternRewriter &rewriter) const override { 467 auto loc = op->getLoc(); 468 auto adaptor = vector::ExtractOpOperandAdaptor(operands); 469 auto extractOp = cast<vector::ExtractOp>(op); 470 auto vectorType = extractOp.getVectorType(); 471 auto resultType = extractOp.getResult().getType(); 472 auto llvmResultType = typeConverter.convertType(resultType); 473 auto positionArrayAttr = extractOp.position(); 474 475 // Bail if result type cannot be lowered. 476 if (!llvmResultType) 477 return matchFailure(); 478 479 // One-shot extraction of vector from array (only requires extractvalue). 480 if (resultType.isa<VectorType>()) { 481 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 482 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 483 rewriter.replaceOp(op, extracted); 484 return matchSuccess(); 485 } 486 487 // Potential extraction of 1-D vector from array. 488 auto *context = op->getContext(); 489 Value extracted = adaptor.vector(); 490 auto positionAttrs = positionArrayAttr.getValue(); 491 if (positionAttrs.size() > 1) { 492 auto oneDVectorType = reducedVectorTypeBack(vectorType); 493 auto nMinusOnePositionAttrs = 494 ArrayAttr::get(positionAttrs.drop_back(), context); 495 extracted = rewriter.create<LLVM::ExtractValueOp>( 496 loc, typeConverter.convertType(oneDVectorType), extracted, 497 nMinusOnePositionAttrs); 498 } 499 500 // Remaining extraction of element from 1-D LLVM vector 501 auto position = positionAttrs.back().cast<IntegerAttr>(); 502 auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 503 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 504 extracted = 505 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 506 rewriter.replaceOp(op, extracted); 507 508 return matchSuccess(); 509 } 510 }; 511 512 /// Conversion pattern that turns a vector.fma on a 1-D vector 513 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 514 /// This does not match vectors of n >= 2 rank. 515 /// 516 /// Example: 517 /// ``` 518 /// vector.fma %a, %a, %a : vector<8xf32> 519 /// ``` 520 /// is converted to: 521 /// ``` 522 /// llvm.intr.fma %va, %va, %va: 523 /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) 524 /// -> !llvm<"<8 x float>"> 525 /// ``` 526 class VectorFMAOp1DConversion : public ConvertToLLVMPattern { 527 public: 528 explicit VectorFMAOp1DConversion(MLIRContext *context, 529 LLVMTypeConverter &typeConverter) 530 : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, 531 typeConverter) {} 532 533 PatternMatchResult 534 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 535 ConversionPatternRewriter &rewriter) const override { 536 auto adaptor = vector::FMAOpOperandAdaptor(operands); 537 vector::FMAOp fmaOp = cast<vector::FMAOp>(op); 538 VectorType vType = fmaOp.getVectorType(); 539 if (vType.getRank() != 1) 540 return matchFailure(); 541 rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(), 542 adaptor.acc()); 543 return matchSuccess(); 544 } 545 }; 546 547 class VectorInsertElementOpConversion : public ConvertToLLVMPattern { 548 public: 549 explicit VectorInsertElementOpConversion(MLIRContext *context, 550 LLVMTypeConverter &typeConverter) 551 : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), 552 context, typeConverter) {} 553 554 PatternMatchResult 555 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 556 ConversionPatternRewriter &rewriter) const override { 557 auto adaptor = vector::InsertElementOpOperandAdaptor(operands); 558 auto insertEltOp = cast<vector::InsertElementOp>(op); 559 auto vectorType = insertEltOp.getDestVectorType(); 560 auto llvmType = typeConverter.convertType(vectorType); 561 562 // Bail if result type cannot be lowered. 563 if (!llvmType) 564 return matchFailure(); 565 566 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 567 op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); 568 return matchSuccess(); 569 } 570 }; 571 572 class VectorInsertOpConversion : public ConvertToLLVMPattern { 573 public: 574 explicit VectorInsertOpConversion(MLIRContext *context, 575 LLVMTypeConverter &typeConverter) 576 : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, 577 typeConverter) {} 578 579 PatternMatchResult 580 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 581 ConversionPatternRewriter &rewriter) const override { 582 auto loc = op->getLoc(); 583 auto adaptor = vector::InsertOpOperandAdaptor(operands); 584 auto insertOp = cast<vector::InsertOp>(op); 585 auto sourceType = insertOp.getSourceType(); 586 auto destVectorType = insertOp.getDestVectorType(); 587 auto llvmResultType = typeConverter.convertType(destVectorType); 588 auto positionArrayAttr = insertOp.position(); 589 590 // Bail if result type cannot be lowered. 591 if (!llvmResultType) 592 return matchFailure(); 593 594 // One-shot insertion of a vector into an array (only requires insertvalue). 595 if (sourceType.isa<VectorType>()) { 596 Value inserted = rewriter.create<LLVM::InsertValueOp>( 597 loc, llvmResultType, adaptor.dest(), adaptor.source(), 598 positionArrayAttr); 599 rewriter.replaceOp(op, inserted); 600 return matchSuccess(); 601 } 602 603 // Potential extraction of 1-D vector from array. 604 auto *context = op->getContext(); 605 Value extracted = adaptor.dest(); 606 auto positionAttrs = positionArrayAttr.getValue(); 607 auto position = positionAttrs.back().cast<IntegerAttr>(); 608 auto oneDVectorType = destVectorType; 609 if (positionAttrs.size() > 1) { 610 oneDVectorType = reducedVectorTypeBack(destVectorType); 611 auto nMinusOnePositionAttrs = 612 ArrayAttr::get(positionAttrs.drop_back(), context); 613 extracted = rewriter.create<LLVM::ExtractValueOp>( 614 loc, typeConverter.convertType(oneDVectorType), extracted, 615 nMinusOnePositionAttrs); 616 } 617 618 // Insertion of an element into a 1-D LLVM vector. 619 auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 620 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 621 Value inserted = rewriter.create<LLVM::InsertElementOp>( 622 loc, typeConverter.convertType(oneDVectorType), extracted, 623 adaptor.source(), constant); 624 625 // Potential insertion of resulting 1-D vector into array. 626 if (positionAttrs.size() > 1) { 627 auto nMinusOnePositionAttrs = 628 ArrayAttr::get(positionAttrs.drop_back(), context); 629 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 630 adaptor.dest(), inserted, 631 nMinusOnePositionAttrs); 632 } 633 634 rewriter.replaceOp(op, inserted); 635 return matchSuccess(); 636 } 637 }; 638 639 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 640 /// 641 /// Example: 642 /// ``` 643 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 644 /// ``` 645 /// is rewritten into: 646 /// ``` 647 /// %r = splat %f0: vector<2x4xf32> 648 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 649 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 650 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 651 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 652 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 653 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 654 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 655 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 656 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 657 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 658 /// // %r3 holds the final value. 659 /// ``` 660 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 661 public: 662 using OpRewritePattern<FMAOp>::OpRewritePattern; 663 664 PatternMatchResult matchAndRewrite(FMAOp op, 665 PatternRewriter &rewriter) const override { 666 auto vType = op.getVectorType(); 667 if (vType.getRank() < 2) 668 return matchFailure(); 669 670 auto loc = op.getLoc(); 671 auto elemType = vType.getElementType(); 672 Value zero = rewriter.create<ConstantOp>(loc, elemType, 673 rewriter.getZeroAttr(elemType)); 674 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 675 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 676 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 677 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 678 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 679 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 680 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 681 } 682 rewriter.replaceOp(op, desc); 683 return matchSuccess(); 684 } 685 }; 686 687 // When ranks are different, InsertStridedSlice needs to extract a properly 688 // ranked vector from the destination vector into which to insert. This pattern 689 // only takes care of this part and forwards the rest of the conversion to 690 // another pattern that converts InsertStridedSlice for operands of the same 691 // rank. 692 // 693 // RewritePattern for InsertStridedSliceOp where source and destination vectors 694 // have different ranks. In this case: 695 // 1. the proper subvector is extracted from the destination vector 696 // 2. a new InsertStridedSlice op is created to insert the source in the 697 // destination subvector 698 // 3. the destination subvector is inserted back in the proper place 699 // 4. the op is replaced by the result of step 3. 700 // The new InsertStridedSlice from step 2. will be picked up by a 701 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 702 class VectorInsertStridedSliceOpDifferentRankRewritePattern 703 : public OpRewritePattern<InsertStridedSliceOp> { 704 public: 705 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 706 707 PatternMatchResult matchAndRewrite(InsertStridedSliceOp op, 708 PatternRewriter &rewriter) const override { 709 auto srcType = op.getSourceVectorType(); 710 auto dstType = op.getDestVectorType(); 711 712 if (op.offsets().getValue().empty()) 713 return matchFailure(); 714 715 auto loc = op.getLoc(); 716 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 717 assert(rankDiff >= 0); 718 if (rankDiff == 0) 719 return matchFailure(); 720 721 int64_t rankRest = dstType.getRank() - rankDiff; 722 // Extract / insert the subvector of matching rank and InsertStridedSlice 723 // on it. 724 Value extracted = 725 rewriter.create<ExtractOp>(loc, op.dest(), 726 getI64SubArray(op.offsets(), /*dropFront=*/0, 727 /*dropFront=*/rankRest)); 728 // A different pattern will kick in for InsertStridedSlice with matching 729 // ranks. 730 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 731 loc, op.source(), extracted, 732 getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 733 getI64SubArray(op.strides(), /*dropFront=*/0)); 734 rewriter.replaceOpWithNewOp<InsertOp>( 735 op, stridedSliceInnerOp.getResult(), op.dest(), 736 getI64SubArray(op.offsets(), /*dropFront=*/0, 737 /*dropFront=*/rankRest)); 738 return matchSuccess(); 739 } 740 }; 741 742 // RewritePattern for InsertStridedSliceOp where source and destination vectors 743 // have the same rank. In this case, we reduce 744 // 1. the proper subvector is extracted from the destination vector 745 // 2. a new InsertStridedSlice op is created to insert the source in the 746 // destination subvector 747 // 3. the destination subvector is inserted back in the proper place 748 // 4. the op is replaced by the result of step 3. 749 // The new InsertStridedSlice from step 2. will be picked up by a 750 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 751 class VectorInsertStridedSliceOpSameRankRewritePattern 752 : public OpRewritePattern<InsertStridedSliceOp> { 753 public: 754 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 755 756 PatternMatchResult matchAndRewrite(InsertStridedSliceOp op, 757 PatternRewriter &rewriter) const override { 758 auto srcType = op.getSourceVectorType(); 759 auto dstType = op.getDestVectorType(); 760 761 if (op.offsets().getValue().empty()) 762 return matchFailure(); 763 764 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 765 assert(rankDiff >= 0); 766 if (rankDiff != 0) 767 return matchFailure(); 768 769 if (srcType == dstType) { 770 rewriter.replaceOp(op, op.source()); 771 return matchSuccess(); 772 } 773 774 int64_t offset = 775 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 776 int64_t size = srcType.getShape().front(); 777 int64_t stride = 778 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 779 780 auto loc = op.getLoc(); 781 Value res = op.dest(); 782 // For each slice of the source vector along the most major dimension. 783 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 784 off += stride, ++idx) { 785 // 1. extract the proper subvector (or element) from source 786 Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 787 if (extractedSource.getType().isa<VectorType>()) { 788 // 2. If we have a vector, extract the proper subvector from destination 789 // Otherwise we are at the element level and no need to recurse. 790 Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 791 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 792 // smaller rank. 793 InsertStridedSliceOp insertStridedSliceOp = 794 rewriter.create<InsertStridedSliceOp>( 795 loc, extractedSource, extractedDest, 796 getI64SubArray(op.offsets(), /* dropFront=*/1), 797 getI64SubArray(op.strides(), /* dropFront=*/1)); 798 // Call matchAndRewrite recursively from within the pattern. This 799 // circumvents the current limitation that a given pattern cannot 800 // be called multiple times by the PatternRewrite infrastructure (to 801 // avoid infinite recursion, but in this case, infinite recursion 802 // cannot happen because the rank is strictly decreasing). 803 // TODO(rriddle, nicolasvasilache) Implement something like a hook for 804 // a potential function that must decrease and allow the same pattern 805 // multiple times. 806 auto success = matchAndRewrite(insertStridedSliceOp, rewriter); 807 (void)success; 808 assert(succeeded(success) && "Unexpected failure"); 809 extractedSource = insertStridedSliceOp; 810 } 811 // 4. Insert the extractedSource into the res vector. 812 res = insertOne(rewriter, loc, extractedSource, res, off); 813 } 814 815 rewriter.replaceOp(op, res); 816 return matchSuccess(); 817 } 818 }; 819 820 class VectorTypeCastOpConversion : public ConvertToLLVMPattern { 821 public: 822 explicit VectorTypeCastOpConversion(MLIRContext *context, 823 LLVMTypeConverter &typeConverter) 824 : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, 825 typeConverter) {} 826 827 PatternMatchResult 828 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 829 ConversionPatternRewriter &rewriter) const override { 830 auto loc = op->getLoc(); 831 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 832 MemRefType sourceMemRefType = 833 castOp.getOperand().getType().cast<MemRefType>(); 834 MemRefType targetMemRefType = 835 castOp.getResult().getType().cast<MemRefType>(); 836 837 // Only static shape casts supported atm. 838 if (!sourceMemRefType.hasStaticShape() || 839 !targetMemRefType.hasStaticShape()) 840 return matchFailure(); 841 842 auto llvmSourceDescriptorTy = 843 operands[0].getType().dyn_cast<LLVM::LLVMType>(); 844 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 845 return matchFailure(); 846 MemRefDescriptor sourceMemRef(operands[0]); 847 848 auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) 849 .dyn_cast_or_null<LLVM::LLVMType>(); 850 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 851 return matchFailure(); 852 853 int64_t offset; 854 SmallVector<int64_t, 4> strides; 855 auto successStrides = 856 getStridesAndOffset(sourceMemRefType, strides, offset); 857 bool isContiguous = (strides.back() == 1); 858 if (isContiguous) { 859 auto sizes = sourceMemRefType.getShape(); 860 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 861 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 862 isContiguous = false; 863 break; 864 } 865 } 866 } 867 // Only contiguous source tensors supported atm. 868 if (failed(successStrides) || !isContiguous) 869 return matchFailure(); 870 871 auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 872 873 // Create descriptor. 874 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 875 Type llvmTargetElementTy = desc.getElementType(); 876 // Set allocated ptr. 877 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 878 allocated = 879 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 880 desc.setAllocatedPtr(rewriter, loc, allocated); 881 // Set aligned ptr. 882 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 883 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 884 desc.setAlignedPtr(rewriter, loc, ptr); 885 // Fill offset 0. 886 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 887 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 888 desc.setOffset(rewriter, loc, zero); 889 890 // Fill size and stride descriptors in memref. 891 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 892 int64_t index = indexedSize.index(); 893 auto sizeAttr = 894 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 895 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 896 desc.setSize(rewriter, loc, index, size); 897 auto strideAttr = 898 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 899 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 900 desc.setStride(rewriter, loc, index, stride); 901 } 902 903 rewriter.replaceOp(op, {desc}); 904 return matchSuccess(); 905 } 906 }; 907 908 class VectorPrintOpConversion : public ConvertToLLVMPattern { 909 public: 910 explicit VectorPrintOpConversion(MLIRContext *context, 911 LLVMTypeConverter &typeConverter) 912 : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, 913 typeConverter) {} 914 915 // Proof-of-concept lowering implementation that relies on a small 916 // runtime support library, which only needs to provide a few 917 // printing methods (single value for all data types, opening/closing 918 // bracket, comma, newline). The lowering fully unrolls a vector 919 // in terms of these elementary printing operations. The advantage 920 // of this approach is that the library can remain unaware of all 921 // low-level implementation details of vectors while still supporting 922 // output of any shaped and dimensioned vector. Due to full unrolling, 923 // this approach is less suited for very large vectors though. 924 // 925 // TODO(ajcbik): rely solely on libc in future? something else? 926 // 927 PatternMatchResult 928 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 929 ConversionPatternRewriter &rewriter) const override { 930 auto printOp = cast<vector::PrintOp>(op); 931 auto adaptor = vector::PrintOpOperandAdaptor(operands); 932 Type printType = printOp.getPrintType(); 933 934 if (typeConverter.convertType(printType) == nullptr) 935 return matchFailure(); 936 937 // Make sure element type has runtime support (currently just Float/Double). 938 VectorType vectorType = printType.dyn_cast<VectorType>(); 939 Type eltType = vectorType ? vectorType.getElementType() : printType; 940 int64_t rank = vectorType ? vectorType.getRank() : 0; 941 Operation *printer; 942 if (eltType.isSignlessInteger(32)) 943 printer = getPrintI32(op); 944 else if (eltType.isSignlessInteger(64)) 945 printer = getPrintI64(op); 946 else if (eltType.isF32()) 947 printer = getPrintFloat(op); 948 else if (eltType.isF64()) 949 printer = getPrintDouble(op); 950 else 951 return matchFailure(); 952 953 // Unroll vector into elementary print calls. 954 emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); 955 emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 956 rewriter.eraseOp(op); 957 return matchSuccess(); 958 } 959 960 private: 961 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 962 Value value, VectorType vectorType, Operation *printer, 963 int64_t rank) const { 964 Location loc = op->getLoc(); 965 if (rank == 0) { 966 emitCall(rewriter, loc, printer, value); 967 return; 968 } 969 970 emitCall(rewriter, loc, getPrintOpen(op)); 971 Operation *printComma = getPrintComma(op); 972 int64_t dim = vectorType.getDimSize(0); 973 for (int64_t d = 0; d < dim; ++d) { 974 auto reducedType = 975 rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 976 auto llvmType = typeConverter.convertType( 977 rank > 1 ? reducedType : vectorType.getElementType()); 978 Value nestedVal = 979 extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); 980 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); 981 if (d != dim - 1) 982 emitCall(rewriter, loc, printComma); 983 } 984 emitCall(rewriter, loc, getPrintClose(op)); 985 } 986 987 // Helper to emit a call. 988 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 989 Operation *ref, ValueRange params = ValueRange()) { 990 rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, 991 rewriter.getSymbolRefAttr(ref), params); 992 } 993 994 // Helper for printer method declaration (first hit) and lookup. 995 static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, 996 StringRef name, ArrayRef<LLVM::LLVMType> params) { 997 auto module = op->getParentOfType<ModuleOp>(); 998 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 999 if (func) 1000 return func; 1001 OpBuilder moduleBuilder(module.getBodyRegion()); 1002 return moduleBuilder.create<LLVM::LLVMFuncOp>( 1003 op->getLoc(), name, 1004 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), 1005 params, /*isVarArg=*/false)); 1006 } 1007 1008 // Helpers for method names. 1009 Operation *getPrintI32(Operation *op) const { 1010 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1011 return getPrint(op, dialect, "print_i32", 1012 LLVM::LLVMType::getInt32Ty(dialect)); 1013 } 1014 Operation *getPrintI64(Operation *op) const { 1015 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1016 return getPrint(op, dialect, "print_i64", 1017 LLVM::LLVMType::getInt64Ty(dialect)); 1018 } 1019 Operation *getPrintFloat(Operation *op) const { 1020 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1021 return getPrint(op, dialect, "print_f32", 1022 LLVM::LLVMType::getFloatTy(dialect)); 1023 } 1024 Operation *getPrintDouble(Operation *op) const { 1025 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1026 return getPrint(op, dialect, "print_f64", 1027 LLVM::LLVMType::getDoubleTy(dialect)); 1028 } 1029 Operation *getPrintOpen(Operation *op) const { 1030 return getPrint(op, typeConverter.getDialect(), "print_open", {}); 1031 } 1032 Operation *getPrintClose(Operation *op) const { 1033 return getPrint(op, typeConverter.getDialect(), "print_close", {}); 1034 } 1035 Operation *getPrintComma(Operation *op) const { 1036 return getPrint(op, typeConverter.getDialect(), "print_comma", {}); 1037 } 1038 Operation *getPrintNewline(Operation *op) const { 1039 return getPrint(op, typeConverter.getDialect(), "print_newline", {}); 1040 } 1041 }; 1042 1043 /// Progressive lowering of StridedSliceOp to either: 1044 /// 1. extractelement + insertelement for the 1-D case 1045 /// 2. extract + optional strided_slice + insert for the n-D case. 1046 class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> { 1047 public: 1048 using OpRewritePattern<StridedSliceOp>::OpRewritePattern; 1049 1050 PatternMatchResult matchAndRewrite(StridedSliceOp op, 1051 PatternRewriter &rewriter) const override { 1052 auto dstType = op.getResult().getType().cast<VectorType>(); 1053 1054 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 1055 1056 int64_t offset = 1057 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 1058 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 1059 int64_t stride = 1060 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 1061 1062 auto loc = op.getLoc(); 1063 auto elemType = dstType.getElementType(); 1064 assert(elemType.isSignlessIntOrIndexOrFloat()); 1065 Value zero = rewriter.create<ConstantOp>(loc, elemType, 1066 rewriter.getZeroAttr(elemType)); 1067 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 1068 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 1069 off += stride, ++idx) { 1070 Value extracted = extractOne(rewriter, loc, op.vector(), off); 1071 if (op.offsets().getValue().size() > 1) { 1072 StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>( 1073 loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), 1074 getI64SubArray(op.sizes(), /* dropFront=*/1), 1075 getI64SubArray(op.strides(), /* dropFront=*/1)); 1076 // Call matchAndRewrite recursively from within the pattern. This 1077 // circumvents the current limitation that a given pattern cannot 1078 // be called multiple times by the PatternRewrite infrastructure (to 1079 // avoid infinite recursion, but in this case, infinite recursion 1080 // cannot happen because the rank is strictly decreasing). 1081 // TODO(rriddle, nicolasvasilache) Implement something like a hook for 1082 // a potential function that must decrease and allow the same pattern 1083 // multiple times. 1084 auto success = matchAndRewrite(stridedSliceOp, rewriter); 1085 (void)success; 1086 assert(succeeded(success) && "Unexpected failure"); 1087 extracted = stridedSliceOp; 1088 } 1089 res = insertOne(rewriter, loc, extracted, res, idx); 1090 } 1091 rewriter.replaceOp(op, {res}); 1092 return matchSuccess(); 1093 } 1094 }; 1095 1096 } // namespace 1097 1098 /// Populate the given list with patterns that convert from Vector to LLVM. 1099 void mlir::populateVectorToLLVMConversionPatterns( 1100 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1101 MLIRContext *ctx = converter.getDialect()->getContext(); 1102 patterns.insert<VectorFMAOpNDRewritePattern, 1103 VectorInsertStridedSliceOpDifferentRankRewritePattern, 1104 VectorInsertStridedSliceOpSameRankRewritePattern, 1105 VectorStridedSliceOpConversion>(ctx); 1106 patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion, 1107 VectorShuffleOpConversion, VectorExtractElementOpConversion, 1108 VectorExtractOpConversion, VectorFMAOp1DConversion, 1109 VectorInsertElementOpConversion, VectorInsertOpConversion, 1110 VectorTypeCastOpConversion, VectorPrintOpConversion>( 1111 ctx, converter); 1112 } 1113 1114 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1115 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1116 MLIRContext *ctx = converter.getDialect()->getContext(); 1117 patterns.insert<VectorMatmulOpConversion>(ctx, converter); 1118 } 1119 1120 namespace { 1121 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 1122 void runOnModule() override; 1123 }; 1124 } // namespace 1125 1126 void LowerVectorToLLVMPass::runOnModule() { 1127 // Perform progressive lowering of operations on slices and 1128 // all contraction operations. Also applies folding and DCE. 1129 { 1130 OwningRewritePatternList patterns; 1131 populateVectorSlicesLoweringPatterns(patterns, &getContext()); 1132 populateVectorContractLoweringPatterns(patterns, &getContext()); 1133 applyPatternsGreedily(getModule(), patterns); 1134 } 1135 1136 // Convert to the LLVM IR dialect. 1137 LLVMTypeConverter converter(&getContext()); 1138 OwningRewritePatternList patterns; 1139 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1140 populateVectorToLLVMConversionPatterns(converter, patterns); 1141 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1142 populateStdToLLVMConversionPatterns(converter, patterns); 1143 1144 LLVMConversionTarget target(getContext()); 1145 target.addDynamicallyLegalOp<FuncOp>( 1146 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1147 if (failed( 1148 applyPartialConversion(getModule(), target, patterns, &converter))) { 1149 signalPassFailure(); 1150 } 1151 } 1152 1153 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { 1154 return new LowerVectorToLLVMPass(); 1155 } 1156 1157 static PassRegistration<LowerVectorToLLVMPass> 1158 pass("convert-vector-to-llvm", 1159 "Lower the operations from the vector dialect into the LLVM dialect"); 1160