//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper that returns an LLVM pointer type to the (converted) element type of
// `containerType`. Works for any type exposing getElementType(), e.g.
// VectorType or MemRefType.
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
                                          LLVMTypeConverter &typeConverter) {
  return typeConverter.convertType(containerType.getElementType())
      .template cast<LLVM::LLVMType>()
      .getPointerTo();
}

// Helper to reduce vector type by one rank at front.
// E.g. vector<2x4xf32> -> vector<4xf32>. Requires rank > 1 (asserted).
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
52 static VectorType reducedVectorTypeBack(VectorType tp) { 53 assert((tp.getRank() > 1) && "unlowerable vector type"); 54 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 55 } 56 57 // Helper that picks the proper sequence for inserting. 58 static Value insertOne(ConversionPatternRewriter &rewriter, 59 LLVMTypeConverter &typeConverter, Location loc, 60 Value val1, Value val2, Type llvmType, int64_t rank, 61 int64_t pos) { 62 if (rank == 1) { 63 auto idxType = rewriter.getIndexType(); 64 auto constant = rewriter.create<LLVM::ConstantOp>( 65 loc, typeConverter.convertType(idxType), 66 rewriter.getIntegerAttr(idxType, pos)); 67 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 68 constant); 69 } 70 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 71 rewriter.getI64ArrayAttr(pos)); 72 } 73 74 // Helper that picks the proper sequence for inserting. 75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 76 Value into, int64_t offset) { 77 auto vectorType = into.getType().cast<VectorType>(); 78 if (vectorType.getRank() > 1) 79 return rewriter.create<InsertOp>(loc, from, into, offset); 80 return rewriter.create<vector::InsertElementOp>( 81 loc, vectorType, from, into, 82 rewriter.create<ConstantIndexOp>(loc, offset)); 83 } 84 85 // Helper that picks the proper sequence for extracting. 
86 static Value extractOne(ConversionPatternRewriter &rewriter, 87 LLVMTypeConverter &typeConverter, Location loc, 88 Value val, Type llvmType, int64_t rank, int64_t pos) { 89 if (rank == 1) { 90 auto idxType = rewriter.getIndexType(); 91 auto constant = rewriter.create<LLVM::ConstantOp>( 92 loc, typeConverter.convertType(idxType), 93 rewriter.getIntegerAttr(idxType, pos)); 94 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 95 constant); 96 } 97 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 98 rewriter.getI64ArrayAttr(pos)); 99 } 100 101 // Helper that picks the proper sequence for extracting. 102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 103 int64_t offset) { 104 auto vectorType = vector.getType().cast<VectorType>(); 105 if (vectorType.getRank() > 1) 106 return rewriter.create<ExtractOp>(loc, vector, offset); 107 return rewriter.create<vector::ExtractElementOp>( 108 loc, vectorType.getElementType(), vector, 109 rewriter.create<ConstantIndexOp>(loc, offset)); 110 } 111 112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing. 114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 115 unsigned dropFront = 0, 116 unsigned dropBack = 0) { 117 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 118 auto range = arrayAttr.getAsRange<IntegerAttr>(); 119 SmallVector<int64_t, 4> res; 120 res.reserve(arrayAttr.size() - dropFront - dropBack); 121 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 122 it != eit; ++it) 123 res.push_back((*it).getValue().getSExtValue()); 124 return res; 125 } 126 127 namespace { 128 129 /// Conversion pattern for a vector.matrix_multiply. 130 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpOperandAdaptor(operands);
    // Forward the shape attributes (lhs_rows/lhs_columns/rhs_columns)
    // straight onto the intrinsic; the result type is converted 1-1.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.reduction. Lowered to the matching
/// llvm.intr.experimental.vector.reduce.* intrinsic, selected by the
/// reduction kind string and the element type.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    // Only 32/64-bit signless integers and f32/f64 are handled below;
    // anything else makes the pattern fail to match.
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      // Note: "min"/"max" lower to the *signed* variants (smin/smax) since
      // only signless integer element types reach this branch.
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }
};

/// Conversion pattern for vector.shuffle. Uses a single llvm.shufflevector
/// when both operands are 1-D with identical types; otherwise expands into a
/// sequence of per-position extract/insert operations.
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Mask entries < v1Dim select from v1; entries >= v1Dim select from v2
    // (after rebasing the position by v1Dim).
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

/// Conversion pattern for vector.extractelement: a trivial 1-1 lowering to
/// llvm.extractelement with a dynamic position.
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for vector.extract with static positions. A vector
/// result is a one-shot llvm.extractvalue; a scalar result goes through an
/// llvm.extractvalue down to the innermost 1-D vector followed by an
/// llvm.extractelement.
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpOperandAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    // All but the last position index select the innermost 1-D vector from
    // the LLVM aggregate with a single extractvalue.
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fma %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpOperandAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    // Rank > 1 is handled by VectorFMAOpNDRewritePattern, which reduces the
    // rank until this pattern applies.
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
                                             adaptor.acc());
    return success();
  }
};

/// Conversion pattern for vector.insertelement: a trivial 1-1 lowering to
/// llvm.insertelement with a dynamic position.
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for vector.insert with static positions. A vector
/// source is a one-shot llvm.insertvalue; a scalar source goes through an
/// extractvalue / insertelement / insertvalue sequence on the innermost 1-D
/// vector.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpOperandAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    // All but the last position index select the innermost 1-D vector from
    // the destination aggregate with a single extractvalue.
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // 1-D FMAs are handled directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    // Peel off the leading dimension: emit one (n-1)-D FMA per outer slice
    // and accumulate the results with vector.insert.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    // Note: the second getI64SubArray argument here is `dropBack` (the
    // original annotation mislabeled it as a second `dropFront`).
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    // Note: the second getI64SubArray argument here is `dropBack` (the
    // original annotation mislabeled it as a second `dropFront`).
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each position of the slice along the outermost
// dimension:
//   1. the (n-1)-D subvector (or scalar element) is extracted from the source
//   2. if it is a vector, the matching subvector is extracted from the
//      destination and a new, rank-reduced InsertStridedSliceOp combining the
//      two is emitted (to be rewritten recursively by this same pattern)
//   3. the result is inserted back into the accumulated destination value.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    // Different ranks are handled by
    // VectorInsertStridedSliceOpDifferentRankRewritePattern.
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    // Fully overwriting the destination is just the source.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
658 bool hasBoundedRewriteRecursion() const final { return true; } 659 }; 660 661 class VectorTypeCastOpConversion : public ConvertToLLVMPattern { 662 public: 663 explicit VectorTypeCastOpConversion(MLIRContext *context, 664 LLVMTypeConverter &typeConverter) 665 : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, 666 typeConverter) {} 667 668 LogicalResult 669 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 670 ConversionPatternRewriter &rewriter) const override { 671 auto loc = op->getLoc(); 672 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 673 MemRefType sourceMemRefType = 674 castOp.getOperand().getType().cast<MemRefType>(); 675 MemRefType targetMemRefType = 676 castOp.getResult().getType().cast<MemRefType>(); 677 678 // Only static shape casts supported atm. 679 if (!sourceMemRefType.hasStaticShape() || 680 !targetMemRefType.hasStaticShape()) 681 return failure(); 682 683 auto llvmSourceDescriptorTy = 684 operands[0].getType().dyn_cast<LLVM::LLVMType>(); 685 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 686 return failure(); 687 MemRefDescriptor sourceMemRef(operands[0]); 688 689 auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) 690 .dyn_cast_or_null<LLVM::LLVMType>(); 691 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 692 return failure(); 693 694 int64_t offset; 695 SmallVector<int64_t, 4> strides; 696 auto successStrides = 697 getStridesAndOffset(sourceMemRefType, strides, offset); 698 bool isContiguous = (strides.back() == 1); 699 if (isContiguous) { 700 auto sizes = sourceMemRefType.getShape(); 701 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 702 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 703 isContiguous = false; 704 break; 705 } 706 } 707 } 708 // Only contiguous source tensors supported atm. 
709 if (failed(successStrides) || !isContiguous) 710 return failure(); 711 712 auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); 713 714 // Create descriptor. 715 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 716 Type llvmTargetElementTy = desc.getElementType(); 717 // Set allocated ptr. 718 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 719 allocated = 720 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 721 desc.setAllocatedPtr(rewriter, loc, allocated); 722 // Set aligned ptr. 723 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 724 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 725 desc.setAlignedPtr(rewriter, loc, ptr); 726 // Fill offset 0. 727 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 728 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 729 desc.setOffset(rewriter, loc, zero); 730 731 // Fill size and stride descriptors in memref. 
    // Sizes and strides of the target memref are static; materialize them as
    // i64 constants in the descriptor.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

// Emits the masked memory operation that terminates the lowering of a 1-D
// vector.transfer_read / vector.transfer_write. Specialized below per op.
template <typename ConcreteOp>
LogicalResult replaceTransferOp(ConversionPatternRewriter &rewriter,
                                LLVMTypeConverter &typeConverter, Location loc,
                                Operation *op, ArrayRef<Value> operands,
                                Value dataPtr, Value mask);

// Converts `type` and reports the preferred alignment of the resulting LLVM
// IR type according to the LLVM module's data layout. Fails when the type
// cannot be converted.
LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
                                      Type type, LLVM::LLVMType &llvmType,
                                      unsigned &align) {
  auto convertedType = typeConverter.convertType(type);
  if (!convertedType)
    return failure();

  llvmType = convertedType.template cast<LLVM::LLVMType>();
  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
  align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
  return success();
}

// vector.transfer_read lowers to llvm.intr.masked.load, with the padding
// value splatted as the pass-through for the masked-off lanes.
template <>
LogicalResult replaceTransferOp<TransferReadOp>(
    ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
    Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
    Value mask) {
  auto xferOp = cast<TransferReadOp>(op);
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  LLVM::LLVMType vecTy;
  unsigned align;
  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
                                     vecTy, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      op, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

// vector.transfer_write lowers to llvm.intr.masked.store.
template <>
LogicalResult replaceTransferOp<TransferWriteOp>(
    ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter,
    Location loc, Operation *op, ArrayRef<Value> operands, Value dataPtr,
    Value mask) {
  auto adaptor = TransferWriteOpOperandAdaptor(operands);

  auto xferOp = cast<TransferWriteOp>(op);
  LLVM::LLVMType vecTy;
  unsigned align;
  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
                                     vecTy, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
  return success();
}

// Overload-based dispatch so VectorTransferConversion can build the right
// operand adaptor for either transfer op.
static TransferReadOpOperandAdaptor
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
  return TransferReadOpOperandAdaptor(operands);
}

static TransferWriteOpOperandAdaptor
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
  return TransferWriteOpOperandAdaptor(operands);
}

// Returns true if the last `rank` results of `map` are the minor identity,
// i.e. (d0 ... dn) -> (dn-rank+1 ... dn).
bool isMinorIdentity(AffineMap map, unsigned rank) {
  if (map.getNumResults() < rank)
    return false;
  unsigned startDim = map.getNumDims() - rank;
  for (unsigned i = 0; i < rank; ++i)
    if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
      return false;
  return true;
}

/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4.
Rewrite op as a masked read or write. 836 template <typename ConcreteOp> 837 class VectorTransferConversion : public ConvertToLLVMPattern { 838 public: 839 explicit VectorTransferConversion(MLIRContext *context, 840 LLVMTypeConverter &typeConv) 841 : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, 842 typeConv) {} 843 844 LogicalResult 845 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 846 ConversionPatternRewriter &rewriter) const override { 847 auto xferOp = cast<ConcreteOp>(op); 848 auto adaptor = getTransferOpAdapter(xferOp, operands); 849 850 if (xferOp.getVectorType().getRank() > 1 || 851 llvm::size(xferOp.indices()) == 0) 852 return failure(); 853 if (!isMinorIdentity(xferOp.permutation_map(), 854 xferOp.getVectorType().getRank())) 855 return failure(); 856 857 auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 858 859 Location loc = op->getLoc(); 860 Type i64Type = rewriter.getIntegerType(64); 861 MemRefType memRefType = xferOp.getMemRefType(); 862 863 // 1. Get the source/dst address as an LLVM vector pointer. 864 // The vector pointer would always be on address space 0, therefore 865 // addrspacecast shall be used when source/dst memrefs are not on 866 // address space 0. 867 // TODO: support alignment when possible. 868 Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), 869 adaptor.indices(), rewriter, getModule()); 870 auto vecTy = 871 toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); 872 Value vectorDataPtr; 873 if (memRefType.getMemorySpace() == 0) 874 vectorDataPtr = 875 rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr); 876 else 877 vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 878 loc, vecTy.getPointerTo(), dataPtr); 879 880 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 
881 unsigned vecWidth = vecTy.getVectorNumElements(); 882 VectorType vectorCmpType = VectorType::get(vecWidth, i64Type); 883 SmallVector<int64_t, 8> indices; 884 indices.reserve(vecWidth); 885 for (unsigned i = 0; i < vecWidth; ++i) 886 indices.push_back(i); 887 Value linearIndices = rewriter.create<ConstantOp>( 888 loc, vectorCmpType, 889 DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices))); 890 linearIndices = rewriter.create<LLVM::DialectCastOp>( 891 loc, toLLVMTy(vectorCmpType), linearIndices); 892 893 // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 894 // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last 895 // `k` dimensions here. 896 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 897 Value offsetIndex = *(xferOp.indices().begin() + lastIndex); 898 offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex); 899 Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex); 900 Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices); 901 902 // 4. Let dim the memref dimension, compute the vector comparison mask: 903 // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 904 Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex); 905 dim = rewriter.create<IndexCastOp>(loc, i64Type, dim); 906 dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim); 907 Value mask = 908 rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim); 909 mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()), 910 mask); 911 912 // 5. Rewrite as a masked read / write. 
913 return replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, 914 operands, vectorDataPtr, mask); 915 } 916 }; 917 918 class VectorPrintOpConversion : public ConvertToLLVMPattern { 919 public: 920 explicit VectorPrintOpConversion(MLIRContext *context, 921 LLVMTypeConverter &typeConverter) 922 : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, 923 typeConverter) {} 924 925 // Proof-of-concept lowering implementation that relies on a small 926 // runtime support library, which only needs to provide a few 927 // printing methods (single value for all data types, opening/closing 928 // bracket, comma, newline). The lowering fully unrolls a vector 929 // in terms of these elementary printing operations. The advantage 930 // of this approach is that the library can remain unaware of all 931 // low-level implementation details of vectors while still supporting 932 // output of any shaped and dimensioned vector. Due to full unrolling, 933 // this approach is less suited for very large vectors though. 934 // 935 // TODO(ajcbik): rely solely on libc in future? something else? 936 // 937 LogicalResult 938 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 939 ConversionPatternRewriter &rewriter) const override { 940 auto printOp = cast<vector::PrintOp>(op); 941 auto adaptor = vector::PrintOpOperandAdaptor(operands); 942 Type printType = printOp.getPrintType(); 943 944 if (typeConverter.convertType(printType) == nullptr) 945 return failure(); 946 947 // Make sure element type has runtime support (currently just Float/Double). 948 VectorType vectorType = printType.dyn_cast<VectorType>(); 949 Type eltType = vectorType ? vectorType.getElementType() : printType; 950 int64_t rank = vectorType ? 
vectorType.getRank() : 0; 951 Operation *printer; 952 if (eltType.isSignlessInteger(1)) 953 printer = getPrintI1(op); 954 else if (eltType.isSignlessInteger(32)) 955 printer = getPrintI32(op); 956 else if (eltType.isSignlessInteger(64)) 957 printer = getPrintI64(op); 958 else if (eltType.isF32()) 959 printer = getPrintFloat(op); 960 else if (eltType.isF64()) 961 printer = getPrintDouble(op); 962 else 963 return failure(); 964 965 // Unroll vector into elementary print calls. 966 emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); 967 emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 968 rewriter.eraseOp(op); 969 return success(); 970 } 971 972 private: 973 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 974 Value value, VectorType vectorType, Operation *printer, 975 int64_t rank) const { 976 Location loc = op->getLoc(); 977 if (rank == 0) { 978 emitCall(rewriter, loc, printer, value); 979 return; 980 } 981 982 emitCall(rewriter, loc, getPrintOpen(op)); 983 Operation *printComma = getPrintComma(op); 984 int64_t dim = vectorType.getDimSize(0); 985 for (int64_t d = 0; d < dim; ++d) { 986 auto reducedType = 987 rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 988 auto llvmType = typeConverter.convertType( 989 rank > 1 ? reducedType : vectorType.getElementType()); 990 Value nestedVal = 991 extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); 992 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); 993 if (d != dim - 1) 994 emitCall(rewriter, loc, printComma); 995 } 996 emitCall(rewriter, loc, getPrintClose(op)); 997 } 998 999 // Helper to emit a call. 1000 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1001 Operation *ref, ValueRange params = ValueRange()) { 1002 rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, 1003 rewriter.getSymbolRefAttr(ref), params); 1004 } 1005 1006 // Helper for printer method declaration (first hit) and lookup. 
1007 static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, 1008 StringRef name, ArrayRef<LLVM::LLVMType> params) { 1009 auto module = op->getParentOfType<ModuleOp>(); 1010 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1011 if (func) 1012 return func; 1013 OpBuilder moduleBuilder(module.getBodyRegion()); 1014 return moduleBuilder.create<LLVM::LLVMFuncOp>( 1015 op->getLoc(), name, 1016 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), 1017 params, /*isVarArg=*/false)); 1018 } 1019 1020 // Helpers for method names. 1021 Operation *getPrintI1(Operation *op) const { 1022 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1023 return getPrint(op, dialect, "print_i1", 1024 LLVM::LLVMType::getInt1Ty(dialect)); 1025 } 1026 Operation *getPrintI32(Operation *op) const { 1027 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1028 return getPrint(op, dialect, "print_i32", 1029 LLVM::LLVMType::getInt32Ty(dialect)); 1030 } 1031 Operation *getPrintI64(Operation *op) const { 1032 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1033 return getPrint(op, dialect, "print_i64", 1034 LLVM::LLVMType::getInt64Ty(dialect)); 1035 } 1036 Operation *getPrintFloat(Operation *op) const { 1037 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1038 return getPrint(op, dialect, "print_f32", 1039 LLVM::LLVMType::getFloatTy(dialect)); 1040 } 1041 Operation *getPrintDouble(Operation *op) const { 1042 LLVM::LLVMDialect *dialect = typeConverter.getDialect(); 1043 return getPrint(op, dialect, "print_f64", 1044 LLVM::LLVMType::getDoubleTy(dialect)); 1045 } 1046 Operation *getPrintOpen(Operation *op) const { 1047 return getPrint(op, typeConverter.getDialect(), "print_open", {}); 1048 } 1049 Operation *getPrintClose(Operation *op) const { 1050 return getPrint(op, typeConverter.getDialect(), "print_close", {}); 1051 } 1052 Operation *getPrintComma(Operation *op) const { 1053 return getPrint(op, typeConverter.getDialect(), "print_comma", 
{}); 1054 } 1055 Operation *getPrintNewline(Operation *op) const { 1056 return getPrint(op, typeConverter.getDialect(), "print_newline", {}); 1057 } 1058 }; 1059 1060 /// Progressive lowering of StridedSliceOp to either: 1061 /// 1. extractelement + insertelement for the 1-D case 1062 /// 2. extract + optional strided_slice + insert for the n-D case. 1063 class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> { 1064 public: 1065 using OpRewritePattern<StridedSliceOp>::OpRewritePattern; 1066 1067 LogicalResult matchAndRewrite(StridedSliceOp op, 1068 PatternRewriter &rewriter) const override { 1069 auto dstType = op.getResult().getType().cast<VectorType>(); 1070 1071 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 1072 1073 int64_t offset = 1074 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 1075 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 1076 int64_t stride = 1077 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 1078 1079 auto loc = op.getLoc(); 1080 auto elemType = dstType.getElementType(); 1081 assert(elemType.isSignlessIntOrIndexOrFloat()); 1082 Value zero = rewriter.create<ConstantOp>(loc, elemType, 1083 rewriter.getZeroAttr(elemType)); 1084 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 1085 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 1086 off += stride, ++idx) { 1087 Value extracted = extractOne(rewriter, loc, op.vector(), off); 1088 if (op.offsets().getValue().size() > 1) { 1089 extracted = rewriter.create<StridedSliceOp>( 1090 loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), 1091 getI64SubArray(op.sizes(), /* dropFront=*/1), 1092 getI64SubArray(op.strides(), /* dropFront=*/1)); 1093 } 1094 res = insertOne(rewriter, loc, extracted, res, idx); 1095 } 1096 rewriter.replaceOp(op, {res}); 1097 return success(); 1098 } 1099 /// This pattern creates recursive StridedSliceOp, but the recursion is 1100 /// 
bounded as the rank is strictly decreasing. 1101 bool hasBoundedRewriteRecursion() const final { return true; } 1102 }; 1103 1104 } // namespace 1105 1106 /// Populate the given list with patterns that convert from Vector to LLVM. 1107 void mlir::populateVectorToLLVMConversionPatterns( 1108 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1109 MLIRContext *ctx = converter.getDialect()->getContext(); 1110 // clang-format off 1111 patterns.insert<VectorFMAOpNDRewritePattern, 1112 VectorInsertStridedSliceOpDifferentRankRewritePattern, 1113 VectorInsertStridedSliceOpSameRankRewritePattern, 1114 VectorStridedSliceOpConversion>(ctx); 1115 patterns 1116 .insert<VectorReductionOpConversion, 1117 VectorShuffleOpConversion, 1118 VectorExtractElementOpConversion, 1119 VectorExtractOpConversion, 1120 VectorFMAOp1DConversion, 1121 VectorInsertElementOpConversion, 1122 VectorInsertOpConversion, 1123 VectorPrintOpConversion, 1124 VectorTransferConversion<TransferReadOp>, 1125 VectorTransferConversion<TransferWriteOp>, 1126 VectorTypeCastOpConversion>(ctx, converter); 1127 // clang-format on 1128 } 1129 1130 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1131 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1132 MLIRContext *ctx = converter.getDialect()->getContext(); 1133 patterns.insert<VectorMatmulOpConversion>(ctx, converter); 1134 } 1135 1136 namespace { 1137 struct LowerVectorToLLVMPass 1138 : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> { 1139 void runOnOperation() override; 1140 }; 1141 } // namespace 1142 1143 void LowerVectorToLLVMPass::runOnOperation() { 1144 // Perform progressive lowering of operations on slices and 1145 // all contraction operations. Also applies folding and DCE. 
1146 { 1147 OwningRewritePatternList patterns; 1148 populateVectorSlicesLoweringPatterns(patterns, &getContext()); 1149 populateVectorContractLoweringPatterns(patterns, &getContext()); 1150 applyPatternsAndFoldGreedily(getOperation(), patterns); 1151 } 1152 1153 // Convert to the LLVM IR dialect. 1154 LLVMTypeConverter converter(&getContext()); 1155 OwningRewritePatternList patterns; 1156 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1157 populateVectorToLLVMConversionPatterns(converter, patterns); 1158 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1159 populateStdToLLVMConversionPatterns(converter, patterns); 1160 1161 LLVMConversionTarget target(getContext()); 1162 target.addDynamicallyLegalOp<FuncOp>( 1163 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1164 if (failed(applyPartialConversion(getOperation(), target, patterns, 1165 &converter))) { 1166 signalPassFailure(); 1167 } 1168 } 1169 1170 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() { 1171 return std::make_unique<LowerVectorToLLVMPass>(); 1172 } 1173