1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements convenience types for working with super-vectorization 10 // operations, in particular super-vector loads and stores. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 16 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" 17 #include "mlir/Dialect/Arith/IR/Arith.h" 18 #include "mlir/Dialect/Arith/Utils/Utils.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Dialect/UB/IR/UBOps.h" 23 #include "mlir/Dialect/Utils/IndexingUtils.h" 24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/IR/Builders.h" 28 #include "mlir/IR/BuiltinAttributes.h" 29 #include "mlir/IR/BuiltinOps.h" 30 #include "mlir/IR/BuiltinTypes.h" 31 #include "mlir/IR/DialectImplementation.h" 32 #include "mlir/IR/IRMapping.h" 33 #include "mlir/IR/OpImplementation.h" 34 #include "mlir/IR/PatternMatch.h" 35 #include "mlir/IR/TypeUtilities.h" 36 #include "mlir/Interfaces/SubsetOpInterface.h" 37 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 38 #include "mlir/Support/LLVM.h" 39 #include "mlir/Transforms/InliningUtils.h" 40 #include "llvm/ADT/ArrayRef.h" 41 #include "llvm/ADT/STLExtras.h" 42 #include "llvm/ADT/SmallVector.h" 43 #include "llvm/ADT/StringSet.h" 44 #include "llvm/ADT/TypeSwitch.h" 45 #include "llvm/ADT/bit.h" 46 47 #include <cassert> 48 #include <cstdint> 49 #include <numeric> 50 51 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc" 52 // Pull in all enum type and utility function definitions. 53 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc" 54 55 using namespace mlir; 56 using namespace mlir::vector; 57 58 /// Helper enum to classify mask value. 59 enum class MaskFormat { 60 AllTrue = 0, 61 AllFalse = 1, 62 Unknown = 2, 63 }; 64 65 /// Helper method to classify a mask value. Currently, the method 66 /// looks "under the hood" of a constant value with dense attributes 67 /// and a constant mask operation (since the client may be called at 68 /// various stages during progressive lowering). 69 static MaskFormat getMaskFormat(Value mask) { 70 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) { 71 // Inspect constant dense values. We count up for bits that 72 // are set, count down for bits that are cleared, and bail 73 // when a mix is detected. 74 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) { 75 int64_t val = 0; 76 for (bool b : denseElts.getValues<bool>()) 77 if (b && val >= 0) 78 val++; 79 else if (!b && val <= 0) 80 val--; 81 else 82 return MaskFormat::Unknown; 83 if (val > 0) 84 return MaskFormat::AllTrue; 85 if (val < 0) 86 return MaskFormat::AllFalse; 87 } 88 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) { 89 // Inspect constant mask index. If the index exceeds the 90 // dimension size, all bits are set. If the index is zero 91 // or less, no bits are set. 
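    // For example (illustrative): vector.constant_mask [2, 3] on
    // vector<2x3xi1> classifies as AllTrue, [0, 0] classifies as AllFalse,
    // and a mixed case such as [1, 3] classifies as Unknown.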
92 ArrayRef<int64_t> masks = m.getMaskDimSizes(); 93 auto shape = m.getType().getShape(); 94 bool allTrue = true; 95 bool allFalse = true; 96 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) { 97 if (maskIdx < dimSize) 98 allTrue = false; 99 if (maskIdx > 0) 100 allFalse = false; 101 } 102 if (allTrue) 103 return MaskFormat::AllTrue; 104 if (allFalse) 105 return MaskFormat::AllFalse; 106 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) { 107 // Finds all-false create_masks. An all-true create_mask requires all 108 // dims to be constants, so that'll be folded to a constant_mask, then 109 // detected in the constant_mask case. 110 auto maskOperands = m.getOperands(); 111 for (Value operand : maskOperands) { 112 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) { 113 int64_t dimSize = 114 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt(); 115 if (dimSize <= 0) 116 return MaskFormat::AllFalse; 117 } 118 } 119 return MaskFormat::Unknown; 120 } 121 return MaskFormat::Unknown; 122 } 123 124 /// Default callback to build a region with a 'vector.yield' terminator with no 125 /// arguments. 126 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) { 127 builder.create<vector::YieldOp>(loc); 128 } 129 130 // Helper for verifying combining kinds in contractions and reductions. 131 static bool isSupportedCombiningKind(CombiningKind combiningKind, 132 Type elementType) { 133 switch (combiningKind) { 134 case CombiningKind::ADD: 135 case CombiningKind::MUL: 136 return elementType.isIntOrIndexOrFloat(); 137 case CombiningKind::MINUI: 138 case CombiningKind::MINSI: 139 case CombiningKind::MAXUI: 140 case CombiningKind::MAXSI: 141 case CombiningKind::AND: 142 case CombiningKind::OR: 143 case CombiningKind::XOR: 144 return elementType.isIntOrIndex(); 145 case CombiningKind::MINNUMF: 146 case CombiningKind::MAXNUMF: 147 case CombiningKind::MINIMUMF: 148 case CombiningKind::MAXIMUMF: 149 return llvm::isa<FloatType>(elementType); 150 } 151 return false; 152 } 153 154 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, 155 VectorType vectorType) { 156 int64_t elementVectorRank = 0; 157 VectorType elementVectorType = 158 llvm::dyn_cast<VectorType>(shapedType.getElementType()); 159 if (elementVectorType) 160 elementVectorRank += elementVectorType.getRank(); 161 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>. 162 // TODO: replace once we have 0-d vectors. 163 if (shapedType.getRank() == 0 && 164 vectorType.getShape() == ArrayRef<int64_t>{1}) 165 return AffineMap::get( 166 /*numDims=*/0, /*numSymbols=*/0, 167 getAffineConstantExpr(0, shapedType.getContext())); 168 return AffineMap::getMinorIdentityMap( 169 shapedType.getRank(), vectorType.getRank() - elementVectorRank, 170 shapedType.getContext()); 171 } 172 173 /// Check if `write` is of a constant splat and the masked `read` is padded with 174 /// the same splat value -- meaning it could be the same value as the initial 175 /// constant splat. 176 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, 177 vector::TransferReadOp read) { 178 auto readMask = read.getMask(); 179 auto writeMask = write.getMask(); 180 // Check if the masks are consistent. The splat value could be the same if the 181 // read is masked (and padded with the splat value), and the write is unmasked 182 // or has the same mask. 
Note this does not allow the case where the write is 183 // masked and the read is unmasked, as then the read could be of more elements 184 // than the write (which may not be the same value). 185 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask); 186 if (!couldBeSameSplat) 187 return false; 188 // Check for constant splat (as the source of the write). 189 DenseElementsAttr splatAttr; 190 if (!matchPattern(write.getVector(), 191 m_Constant<DenseElementsAttr>(&splatAttr)) || 192 !splatAttr.isSplat()) { 193 return false; 194 } 195 // The padding of the read and the constant splat value must be the same. 196 Attribute padAttr; 197 if (!matchPattern(read.getPadding(), m_Constant(&padAttr))) 198 return false; 199 return padAttr == splatAttr.getSplatValue<Attribute>(); 200 } 201 202 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite, 203 vector::TransferReadOp read) { 204 return !defWrite.hasOutOfBoundsDim() && 205 defWrite.getIndices() == read.getIndices() && 206 defWrite.getVectorType() == read.getVectorType() && 207 defWrite.getPermutationMap() == read.getPermutationMap() && 208 ((!defWrite.getMask() && !read.getMask()) || 209 isSplatWriteConsistentWithMaskedRead(defWrite, read)); 210 } 211 212 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write, 213 vector::TransferWriteOp priorWrite) { 214 return priorWrite.getIndices() == write.getIndices() && 215 priorWrite.getMask() == write.getMask() && 216 priorWrite.getVectorType() == write.getVectorType() && 217 priorWrite.getPermutationMap() == write.getPermutationMap(); 218 } 219 220 bool mlir::vector::isDisjointTransferIndices( 221 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, 222 bool testDynamicValueUsingBounds) { 223 // For simplicity only look at transfer of same type. 224 if (transferA.getVectorType() != transferB.getVectorType()) 225 return false; 226 unsigned rankOffset = transferA.getLeadingShapedRank(); 227 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) { 228 Value indexA = transferA.getIndices()[i]; 229 Value indexB = transferB.getIndices()[i]; 230 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA); 231 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB); 232 233 if (i < rankOffset) { 234 // For leading dimensions, if we can prove that index are different we 235 // know we are accessing disjoint slices. 236 if (cstIndexA.has_value() && cstIndexB.has_value()) { 237 if (*cstIndexA != *cstIndexB) 238 return true; 239 continue; 240 } 241 if (testDynamicValueUsingBounds) { 242 // First try to see if we can fully compose and simplify the affine 243 // expression as a fast track. 244 FailureOr<uint64_t> delta = 245 affine::fullyComposeAndComputeConstantDelta(indexA, indexB); 246 if (succeeded(delta) && *delta != 0) 247 return true; 248 249 FailureOr<bool> testEqual = 250 ValueBoundsConstraintSet::areEqual(indexA, indexB); 251 if (succeeded(testEqual) && !testEqual.value()) 252 return true; 253 } 254 } else { 255 // For this dimension, we slice a part of the memref we need to make sure 256 // the intervals accessed don't overlap. 
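      // For example (illustrative): two transfers of vector<4xf32> that use
      // constant indices 0 and 4 along this dimension access [0, 4) and
      // [4, 8) and are therefore disjoint, whereas indices 0 and 2 give
      // overlapping intervals.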
257 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset); 258 if (cstIndexA.has_value() && cstIndexB.has_value()) { 259 int64_t distance = std::abs(*cstIndexA - *cstIndexB); 260 if (distance >= vectorDim) 261 return true; 262 continue; 263 } 264 if (testDynamicValueUsingBounds) { 265 // First try to see if we can fully compose and simplify the affine 266 // expression as a fast track. 267 FailureOr<int64_t> delta = 268 affine::fullyComposeAndComputeConstantDelta(indexA, indexB); 269 if (succeeded(delta) && std::abs(*delta) >= vectorDim) 270 return true; 271 272 FailureOr<int64_t> computeDelta = 273 ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB); 274 if (succeeded(computeDelta)) { 275 if (std::abs(computeDelta.value()) >= vectorDim) 276 return true; 277 } 278 } 279 } 280 } 281 return false; 282 } 283 284 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA, 285 VectorTransferOpInterface transferB, 286 bool testDynamicValueUsingBounds) { 287 if (transferA.getSource() != transferB.getSource()) 288 return false; 289 return isDisjointTransferIndices(transferA, transferB, 290 testDynamicValueUsingBounds); 291 } 292 293 // Helper to iterate over n-D vector slice elements. Calculate the next 294 // `position` in the n-D vector of size `shape`, applying an offset `offsets`. 295 // Modifies the `position` in place. Returns a failure when `position` becomes 296 // the end position. 297 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position, 298 ArrayRef<int64_t> shape, 299 ArrayRef<int64_t> offsets) { 300 for (auto [posInDim, dimSize, offsetInDim] : 301 llvm::reverse(llvm::zip_equal(position, shape, offsets))) { 302 ++posInDim; 303 if (posInDim < dimSize + offsetInDim) 304 return success(); 305 306 // Carry the overflow to the next loop iteration. 307 posInDim = offsetInDim; 308 } 309 310 return failure(); 311 } 312 313 /// Returns the integer numbers in `values`. `values` are expected to be 314 /// constant operations. 315 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) { 316 SmallVector<int64_t> ints; 317 llvm::transform(values, std::back_inserter(ints), [](Value value) { 318 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>(); 319 assert(constOp && "Unexpected non-constant index"); 320 return constOp.value(); 321 }); 322 return ints; 323 } 324 325 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to 326 /// be constant operations. 327 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) { 328 SmallVector<int64_t> ints; 329 llvm::transform( 330 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) { 331 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index"); 332 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt(); 333 }); 334 return ints; 335 } 336 337 /// Convert `foldResults` into Values. Integer attributes are converted to 338 /// constant op. 
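/// For example (illustrative), an IntegerAttr holding 4 is materialized as the
/// result of an `arith.constant 4 : index` operation at `loc`.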
339 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc, 340 ArrayRef<OpFoldResult> foldResults) { 341 SmallVector<Value> values; 342 llvm::transform(foldResults, std::back_inserter(values), 343 [&](OpFoldResult foldResult) { 344 if (auto attr = foldResult.dyn_cast<Attribute>()) 345 return builder 346 .create<arith::ConstantIndexOp>( 347 loc, cast<IntegerAttr>(attr).getInt()) 348 .getResult(); 349 350 return cast<Value>(foldResult); 351 }); 352 return values; 353 } 354 355 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) { 356 if (value.getDefiningOp<vector::VectorScaleOp>()) 357 return 1; 358 auto mul = value.getDefiningOp<arith::MulIOp>(); 359 if (!mul) 360 return {}; 361 auto lhs = mul.getLhs(); 362 auto rhs = mul.getRhs(); 363 if (lhs.getDefiningOp<vector::VectorScaleOp>()) 364 return getConstantIntValue(rhs); 365 if (rhs.getDefiningOp<vector::VectorScaleOp>()) 366 return getConstantIntValue(lhs); 367 return {}; 368 } 369 370 //===----------------------------------------------------------------------===// 371 // CombiningKindAttr 372 //===----------------------------------------------------------------------===// 373 374 namespace mlir { 375 namespace vector { 376 namespace detail { 377 struct BitmaskEnumStorage : public AttributeStorage { 378 using KeyTy = uint64_t; 379 380 BitmaskEnumStorage(KeyTy val) : value(val) {} 381 382 bool operator==(const KeyTy &key) const { return value == key; } 383 384 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator, 385 const KeyTy &key) { 386 return new (allocator.allocate<BitmaskEnumStorage>()) 387 BitmaskEnumStorage(key); 388 } 389 390 KeyTy value = 0; 391 }; 392 } // namespace detail 393 } // namespace vector 394 } // namespace mlir 395 396 //===----------------------------------------------------------------------===// 397 // VectorDialect 398 //===----------------------------------------------------------------------===// 399 400 namespace { 401 /// This class defines the interface for handling inlining with vector dialect 402 /// operations. 403 struct VectorInlinerInterface : public DialectInlinerInterface { 404 using DialectInlinerInterface::DialectInlinerInterface; 405 406 /// All vector dialect ops can be inlined. 407 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 408 return true; 409 } 410 }; 411 } // namespace 412 413 void VectorDialect::initialize() { 414 addAttributes< 415 #define GET_ATTRDEF_LIST 416 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc" 417 >(); 418 419 addOperations< 420 #define GET_OP_LIST 421 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc" 422 >(); 423 424 addInterfaces<VectorInlinerInterface>(); 425 426 declarePromisedInterfaces<bufferization::BufferizableOpInterface, 427 TransferReadOp, TransferWriteOp, GatherOp, MaskOp, 428 YieldOp>(); 429 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp, 430 TransferWriteOp>(); 431 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>(); 432 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>(); 433 } 434 435 /// Materialize a single constant operation from a given attribute value with 436 /// the desired resultant type. 
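/// For example (illustrative), a folded dense splat attribute of vector type
/// would be re-materialized here as an `arith.constant` with that vector type.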
437 Operation *VectorDialect::materializeConstant(OpBuilder &builder, 438 Attribute value, Type type, 439 Location loc) { 440 return arith::ConstantOp::materialize(builder, value, type, loc); 441 } 442 443 IntegerType vector::getVectorSubscriptType(Builder &builder) { 444 return builder.getIntegerType(64); 445 } 446 447 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, 448 ArrayRef<int64_t> values) { 449 return builder.getI64ArrayAttr(values); 450 } 451 452 //===----------------------------------------------------------------------===// 453 // MultiDimReductionOp 454 //===----------------------------------------------------------------------===// 455 456 void vector::MultiDimReductionOp::build(OpBuilder &builder, 457 OperationState &result, Value source, 458 Value acc, ArrayRef<bool> reductionMask, 459 CombiningKind kind) { 460 SmallVector<int64_t> reductionDims; 461 for (const auto &en : llvm::enumerate(reductionMask)) 462 if (en.value()) 463 reductionDims.push_back(en.index()); 464 build(builder, result, kind, source, acc, reductionDims); 465 } 466 467 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) { 468 // Single parallel dim, this is a noop. 469 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0)) 470 return getSource(); 471 return {}; 472 } 473 474 std::optional<SmallVector<int64_t, 4>> 475 MultiDimReductionOp::getShapeForUnroll() { 476 return llvm::to_vector<4>(getSourceVectorType().getShape()); 477 } 478 479 LogicalResult MultiDimReductionOp::verify() { 480 SmallVector<int64_t> targetShape; 481 SmallVector<bool> scalableDims; 482 Type inferredReturnType; 483 auto sourceScalableDims = getSourceVectorType().getScalableDims(); 484 for (auto [dimIdx, dimSize] : 485 llvm::enumerate(getSourceVectorType().getShape())) 486 if (!llvm::any_of(getReductionDims(), 487 [dimIdx = dimIdx](int64_t reductionDimIdx) { 488 return reductionDimIdx == static_cast<int64_t>(dimIdx); 489 })) { 490 targetShape.push_back(dimSize); 491 scalableDims.push_back(sourceScalableDims[dimIdx]); 492 } 493 // TODO: update to also allow 0-d vectors when available. 494 if (targetShape.empty()) 495 inferredReturnType = getSourceVectorType().getElementType(); 496 else 497 inferredReturnType = VectorType::get( 498 targetShape, getSourceVectorType().getElementType(), scalableDims); 499 if (getType() != inferredReturnType) 500 return emitOpError() << "destination type " << getType() 501 << " is incompatible with source type " 502 << getSourceVectorType(); 503 504 return success(); 505 } 506 507 /// Returns the mask type expected by this operation. 508 Type MultiDimReductionOp::getExpectedMaskType() { 509 auto vecType = getSourceVectorType(); 510 return VectorType::get(vecType.getShape(), 511 IntegerType::get(vecType.getContext(), /*width=*/1), 512 vecType.getScalableDims()); 513 } 514 515 namespace { 516 // Only unit dimensions that are being reduced are folded. If the dimension is 517 // unit, but not reduced, it is not folded, thereby keeping the output type the 518 // same. If not all dimensions which are reduced are of unit dimension, this 519 // transformation does nothing. This is just a generalization of 520 // ElideSingleElementReduction for ReduceOp. 
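// For example (illustrative):
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// only reduces a unit dimension, so it can be rewritten as a shape_cast of %v
// to vector<4xf32> combined with %acc through an elementwise addition.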
521 struct ElideUnitDimsInMultiDimReduction 522 : public OpRewritePattern<MultiDimReductionOp> { 523 using OpRewritePattern::OpRewritePattern; 524 525 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, 526 PatternRewriter &rewriter) const override { 527 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape(); 528 for (const auto &dim : enumerate(shape)) { 529 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1) 530 return failure(); 531 } 532 533 // Vector mask setup. 534 OpBuilder::InsertionGuard guard(rewriter); 535 Operation *rootOp; 536 Value mask; 537 if (reductionOp.isMasked()) { 538 rewriter.setInsertionPoint(reductionOp.getMaskingOp()); 539 rootOp = reductionOp.getMaskingOp(); 540 mask = reductionOp.getMaskingOp().getMask(); 541 } else { 542 rootOp = reductionOp; 543 } 544 545 Location loc = reductionOp.getLoc(); 546 Value acc = reductionOp.getAcc(); 547 Value cast; 548 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) { 549 if (mask) { 550 VectorType newMaskType = 551 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(), 552 dstVecType.getScalableDims()); 553 mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask); 554 } 555 cast = rewriter.create<vector::ShapeCastOp>( 556 loc, reductionOp.getDestType(), reductionOp.getSource()); 557 } else { 558 // This means we are reducing all the dimensions, and all reduction 559 // dimensions are of size 1. So a simple extraction would do. 560 SmallVector<int64_t> zeroIdx(shape.size(), 0); 561 if (mask) 562 mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx); 563 cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(), 564 zeroIdx); 565 } 566 567 Value result = 568 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc, 569 cast, /*fastmath=*/nullptr, mask); 570 rewriter.replaceOp(rootOp, result); 571 return success(); 572 } 573 }; 574 } // namespace 575 576 void MultiDimReductionOp::getCanonicalizationPatterns( 577 RewritePatternSet &results, MLIRContext *context) { 578 results.add<ElideUnitDimsInMultiDimReduction>(context); 579 } 580 581 //===----------------------------------------------------------------------===// 582 // ReductionOp 583 //===----------------------------------------------------------------------===// 584 585 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, 586 CombiningKind kind, Value vector, 587 arith::FastMathFlags fastMathFlags) { 588 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags); 589 } 590 591 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, 592 CombiningKind kind, Value vector, Value acc, 593 arith::FastMathFlags fastMathFlags) { 594 build(builder, result, 595 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector, 596 acc, fastMathFlags); 597 } 598 599 LogicalResult ReductionOp::verify() { 600 // Verify for 0-D and 1-D vector. 601 int64_t rank = getSourceVectorType().getRank(); 602 if (rank > 1) 603 return emitOpError("unsupported reduction rank: ") << rank; 604 605 // Verify supported reduction kind. 606 Type eltType = getDest().getType(); 607 if (!isSupportedCombiningKind(getKind(), eltType)) 608 return emitOpError("unsupported reduction type '") 609 << eltType << "' for kind '" << stringifyCombiningKind(getKind()) 610 << "'"; 611 612 return success(); 613 } 614 615 // MaskableOpInterface methods. 616 617 /// Returns the mask type expected by this operation. 
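/// For example (illustrative), reducing a vector<[8]xf32> expects a mask of
/// type vector<[8]xi1>.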
618 Type ReductionOp::getExpectedMaskType() { 619 auto vecType = getSourceVectorType(); 620 return VectorType::get(vecType.getShape(), 621 IntegerType::get(vecType.getContext(), /*width=*/1), 622 vecType.getScalableDims()); 623 } 624 625 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, 626 OpBuilder &builder, Location loc, 627 Value vector) { 628 switch (op) { 629 case arith::AtomicRMWKind::addf: 630 case arith::AtomicRMWKind::addi: 631 return builder.create<vector::ReductionOp>(vector.getLoc(), 632 CombiningKind::ADD, vector); 633 case arith::AtomicRMWKind::mulf: 634 case arith::AtomicRMWKind::muli: 635 return builder.create<vector::ReductionOp>(vector.getLoc(), 636 CombiningKind::MUL, vector); 637 case arith::AtomicRMWKind::minimumf: 638 return builder.create<vector::ReductionOp>(vector.getLoc(), 639 CombiningKind::MINIMUMF, vector); 640 case arith::AtomicRMWKind::mins: 641 return builder.create<vector::ReductionOp>(vector.getLoc(), 642 CombiningKind::MINSI, vector); 643 case arith::AtomicRMWKind::minu: 644 return builder.create<vector::ReductionOp>(vector.getLoc(), 645 CombiningKind::MINUI, vector); 646 case arith::AtomicRMWKind::maximumf: 647 return builder.create<vector::ReductionOp>(vector.getLoc(), 648 CombiningKind::MAXIMUMF, vector); 649 case arith::AtomicRMWKind::maxs: 650 return builder.create<vector::ReductionOp>(vector.getLoc(), 651 CombiningKind::MAXSI, vector); 652 case arith::AtomicRMWKind::maxu: 653 return builder.create<vector::ReductionOp>(vector.getLoc(), 654 CombiningKind::MAXUI, vector); 655 case arith::AtomicRMWKind::andi: 656 return builder.create<vector::ReductionOp>(vector.getLoc(), 657 CombiningKind::AND, vector); 658 case arith::AtomicRMWKind::ori: 659 return builder.create<vector::ReductionOp>(vector.getLoc(), 660 CombiningKind::OR, vector); 661 // TODO: Add remaining reduction operations. 662 default: 663 (void)emitOptionalError(loc, "Reduction operation type not supported"); 664 break; 665 } 666 return nullptr; 667 } 668 669 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() { 670 return llvm::to_vector<4>(getSourceVectorType().getShape()); 671 } 672 673 namespace { 674 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> { 675 using OpRewritePattern::OpRewritePattern; 676 677 LogicalResult matchAndRewrite(ReductionOp reductionOp, 678 PatternRewriter &rewriter) const override { 679 // Vector mask setup. 
680 OpBuilder::InsertionGuard guard(rewriter); 681 auto maskableOp = 682 cast<vector::MaskableOpInterface>(reductionOp.getOperation()); 683 Operation *rootOp; 684 Value mask; 685 if (maskableOp.isMasked()) { 686 rewriter.setInsertionPoint(maskableOp.getMaskingOp()); 687 rootOp = maskableOp.getMaskingOp(); 688 mask = maskableOp.getMaskingOp().getMask(); 689 } else { 690 rootOp = reductionOp; 691 } 692 693 auto vectorType = reductionOp.getSourceVectorType(); 694 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1) 695 return failure(); 696 697 Location loc = reductionOp.getLoc(); 698 Value result; 699 if (vectorType.getRank() == 0) { 700 if (mask) 701 mask = rewriter.create<ExtractElementOp>(loc, mask); 702 result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector()); 703 } else { 704 if (mask) 705 mask = rewriter.create<ExtractOp>(loc, mask, 0); 706 result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0); 707 } 708 709 if (Value acc = reductionOp.getAcc()) 710 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), 711 result, acc, 712 reductionOp.getFastmathAttr(), mask); 713 714 rewriter.replaceOp(rootOp, result); 715 return success(); 716 } 717 }; 718 } // namespace 719 720 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results, 721 MLIRContext *context) { 722 results.add<ElideSingleElementReduction>(context); 723 } 724 725 //===----------------------------------------------------------------------===// 726 // ContractionOp 727 //===----------------------------------------------------------------------===// 728 729 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, 730 Value lhs, Value rhs, Value acc, 731 ArrayRef<ArrayRef<AffineExpr>> indexingExprs, 732 ArrayRef<IteratorType> iteratorTypes) { 733 result.addOperands({lhs, rhs, acc}); 734 result.addTypes(acc.getType()); 735 result.addAttribute( 736 getIndexingMapsAttrName(result.name), 737 builder.getAffineMapArrayAttr( 738 AffineMap::inferFromExprList(indexingExprs, builder.getContext()))); 739 result.addAttribute( 740 getIteratorTypesAttrName(result.name), 741 builder.getArrayAttr(llvm::to_vector(llvm::map_range( 742 iteratorTypes, [&](IteratorType t) -> mlir::Attribute { 743 return IteratorTypeAttr::get(builder.getContext(), t); 744 })))); 745 } 746 747 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, 748 Value lhs, Value rhs, Value acc, 749 ArrayAttr indexingMaps, 750 ArrayAttr iteratorTypes) { 751 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes, 752 ContractionOp::getDefaultKind()); 753 } 754 755 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, 756 Value lhs, Value rhs, Value acc, 757 ArrayAttr indexingMaps, 758 ArrayAttr iteratorTypes, CombiningKind kind) { 759 result.addOperands({lhs, rhs, acc}); 760 result.addTypes(acc.getType()); 761 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps); 762 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes); 763 result.addAttribute(getKindAttrName(result.name), 764 CombiningKindAttr::get(builder.getContext(), kind)); 765 } 766 767 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { 768 OpAsmParser::UnresolvedOperand lhsInfo; 769 OpAsmParser::UnresolvedOperand rhsInfo; 770 OpAsmParser::UnresolvedOperand accInfo; 771 SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo; 772 SmallVector<Type, 2> types; 773 Type resultType; 774 auto loc = 
parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
      parser.parseComma() || parser.parseOperand(rhsInfo) ||
      parser.parseComma() || parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed because tests still use the old format where the 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into their string representation. This is
      // needed because tests still use the old format where the
      // 'iterator_types' attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
849 SmallVector<Attribute> iteratorTypeNames = llvm::to_vector( 850 llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute { 851 return StringAttr::get(getContext(), stringifyIteratorType(t)); 852 })); 853 854 attrs.emplace_back(getIteratorTypesAttrName(), 855 ArrayAttr::get(getContext(), iteratorTypeNames)); 856 } else if (traitAttrsSet.count(attr.getName().strref()) > 0) 857 attrs.push_back(attr); 858 } 859 860 auto dictAttr = DictionaryAttr::get(getContext(), attrs); 861 p << " " << dictAttr << " " << getLhs() << ", "; 862 p << getRhs() << ", " << getAcc(); 863 864 p.printOptionalAttrDict((*this)->getAttrs(), attrNames); 865 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into " 866 << getResultType(); 867 } 868 869 static bool verifyDimMap(VectorType lhsType, VectorType rhsType, 870 const std::vector<std::pair<int64_t, int64_t>> &map) { 871 for (auto &dimPair : map) { 872 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() || 873 dimPair.second < 0 || dimPair.second >= rhsType.getRank() || 874 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second)) 875 return false; 876 } 877 return true; 878 } 879 880 static LogicalResult verifyOutputShape( 881 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, 882 Type resType, 883 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap, 884 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) { 885 DenseSet<int64_t> lhsContractingDimSet; 886 DenseSet<int64_t> rhsContractingDimSet; 887 for (auto &dimPair : contractingDimMap) { 888 lhsContractingDimSet.insert(dimPair.first); 889 rhsContractingDimSet.insert(dimPair.second); 890 } 891 DenseSet<int64_t> rhsBatchDimSet; 892 for (auto &dimPair : batchDimMap) 893 rhsBatchDimSet.insert(dimPair.second); 894 895 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'. 896 SmallVector<int64_t, 4> expectedResultDims; 897 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) { 898 if (lhsContractingDimSet.count(i) > 0) 899 continue; 900 expectedResultDims.push_back(lhsType.getDimSize(i)); 901 } 902 903 // Add free dimensions from 'rhsType' to 'expectedResultDims'. 904 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) { 905 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0) 906 continue; 907 expectedResultDims.push_back(rhsType.getDimSize(i)); 908 } 909 910 // Verify 'expectedResultDims'. 911 if (expectedResultDims.empty()) { 912 // No batch or free dimension implies a scalar result. 913 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType)) 914 return op.emitOpError("invalid accumulator/result vector shape"); 915 } else { 916 // At least one batch or free dimension implies a vector result. 917 auto resVectorType = llvm::dyn_cast<VectorType>(resType); 918 auto accVectorType = llvm::dyn_cast<VectorType>(accType); 919 if (!resVectorType || !accVectorType) 920 return op.emitOpError("invalid accumulator/result vector shape"); 921 922 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector 923 // types fully define the result vector type. This assumes the affine maps 924 // are well-formed, which must have been verified already. 
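    // For example (illustrative): with iterator dims (i, j, k), lhs map
    // (i, k) on vector<3x5xf32> and rhs map (k, j) on vector<5x7xf32>, the
    // extents are {i: 3, j: 7, k: 5}; composing the result map (i, j) with
    // these constants yields the expected type vector<3x7xf32>.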
925 MLIRContext *ctx = op.getContext(); 926 AffineMap lhsMap = op.getIndexingMapsArray()[0]; 927 AffineMap rhsMap = op.getIndexingMapsArray()[1]; 928 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) 929 return op.emitOpError( 930 "expected all dimensions to be either a LHS or a RHS dimension"); 931 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs()); 932 for (auto pair : 933 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) { 934 VectorType v = pair.first; 935 auto map = pair.second; 936 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) { 937 unsigned pos = map.getDimPosition(idx); 938 if (!extents[pos]) 939 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx); 940 } 941 } 942 if (!llvm::all_of(extents, [](AffineExpr e) { return e; })) 943 return op.emitOpError("expected all dimensions to get an extent as " 944 "either a LHS or a RHS dimension"); 945 946 AffineMap resMap = op.getIndexingMapsArray()[2]; 947 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(), 948 /*symbolCount=*/0, extents, ctx); 949 // Compose the resMap with the extentsMap, which is a constant map. 950 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap)); 951 assert(llvm::all_of(expectedMap.getResults(), 952 llvm::IsaPred<AffineConstantExpr>) && 953 "expected constant extent along all dimensions."); 954 // Extract the expected shape and build the type. 955 auto expectedShape = llvm::to_vector<4>( 956 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) { 957 return cast<AffineConstantExpr>(e).getValue(); 958 })); 959 auto expected = 960 VectorType::get(expectedShape, resVectorType.getElementType(), 961 resVectorType.getScalableDims()); 962 if (resVectorType != expected || accVectorType != expected) 963 return op.emitOpError( 964 "invalid accumulator/result vector shape, expected: ") 965 << expected; 966 } 967 return success(); 968 } 969 970 LogicalResult ContractionOp::verify() { 971 VectorType lhsType = getLhsType(); 972 VectorType rhsType = getRhsType(); 973 Type accType = getAccType(); 974 Type resType = getResultType(); 975 976 if (llvm::isa<IntegerType>(lhsType.getElementType())) { 977 if (!lhsType.getElementType().isSignlessInteger()) 978 return emitOpError("only supports signless integer types"); 979 } 980 981 // Verify that an indexing map was specified for each vector operand. 982 if (getIndexingMapsArray().size() != 3) 983 return emitOpError("expected an indexing map for each vector operand"); 984 985 // Verify that each index map has 'numIterators' inputs, no symbols, and 986 // that the number of map outputs equals the rank of its associated 987 // vector operand. 988 unsigned numIterators = getIteratorTypes().getValue().size(); 989 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) { 990 auto index = it.index(); 991 auto map = it.value(); 992 if (map.getNumSymbols() != 0) 993 return emitOpError("expected indexing map ") 994 << index << " to have no symbols"; 995 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType()); 996 unsigned rank = vectorType ? vectorType.getShape().size() : 0; 997 // Verify that the map has the right number of inputs, outputs, and indices. 998 // This also correctly accounts for (..) -> () for rank-0 results. 
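    // For example (illustrative), a scalar accumulator in a full reduction
    // uses a map such as (d0) -> (), which has one input and zero results.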
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
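  // For example (illustrative): with iterator dims (m, n, k), lhs map (m, k)
  // on vector<3x5xf32> and rhs map (k, n) on vector<5x7xf32>, the expected
  // mask type is vector<3x7x5xi1>.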
1056 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) { 1057 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize; 1058 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] = 1059 lhsType.getScalableDims()[dimIdx]; 1060 } 1061 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) { 1062 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize; 1063 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] = 1064 rhsType.getScalableDims()[dimIdx]; 1065 } 1066 1067 assert(!ShapedType::isDynamicShape(maskShape) && 1068 "Mask shape couldn't be computed"); 1069 1070 return VectorType::get(maskShape, 1071 IntegerType::get(lhsType.getContext(), /*width=*/1), 1072 maskShapeScalableDims); 1073 } 1074 1075 SmallVector<StringRef> ContractionOp::getTraitAttrNames() { 1076 return SmallVector<StringRef>{getIndexingMapsAttrName(), 1077 getIteratorTypesAttrName(), getKindAttrName()}; 1078 } 1079 1080 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) { 1081 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) 1082 if (targetExpr == map.getResult(i)) 1083 return i; 1084 return -1; 1085 } 1086 1087 static std::vector<std::pair<int64_t, int64_t>> 1088 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes, 1089 IteratorType targetIteratorType, MLIRContext *context) { 1090 std::vector<std::pair<int64_t, int64_t>> dimMap; 1091 for (const auto &it : llvm::enumerate(iteratorTypes)) { 1092 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue(); 1093 if (iteratorType != targetIteratorType) 1094 continue; 1095 // Search lhs/rhs map results for 'targetExpr'. 1096 auto targetExpr = getAffineDimExpr(it.index(), context); 1097 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr); 1098 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr); 1099 if (lhsDim >= 0 && rhsDim >= 0) 1100 dimMap.emplace_back(lhsDim, rhsDim); 1101 } 1102 return dimMap; 1103 } 1104 1105 void ContractionOp::getIterationBounds( 1106 SmallVectorImpl<int64_t> &iterationBounds) { 1107 auto lhsShape = getLhsType().getShape(); 1108 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType()); 1109 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray()); 1110 SmallVector<int64_t, 2> iterationShape; 1111 for (const auto &it : llvm::enumerate(getIteratorTypes())) { 1112 // Search lhs/rhs map results for 'targetExpr'. 1113 auto targetExpr = getAffineDimExpr(it.index(), getContext()); 1114 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue(); 1115 if (iteratorType == IteratorType::reduction) { 1116 // Get reduction dim size from lhs shape (same size in rhsShape). 1117 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr); 1118 assert(lhsDimIndex >= 0); 1119 iterationBounds.push_back(lhsShape[lhsDimIndex]); 1120 continue; 1121 } 1122 // Get parallel dimension size from result shape. 
1123 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr); 1124 assert(resDimIndex >= 0); 1125 assert(resVectorType != nullptr); 1126 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]); 1127 } 1128 } 1129 1130 void ContractionOp::getIterationIndexMap( 1131 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) { 1132 unsigned numMaps = getIndexingMapsArray().size(); 1133 iterationIndexMap.resize(numMaps); 1134 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) { 1135 auto index = it.index(); 1136 auto map = it.value(); 1137 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { 1138 auto dim = cast<AffineDimExpr>(map.getResult(i)); 1139 iterationIndexMap[index][dim.getPosition()] = i; 1140 } 1141 } 1142 } 1143 1144 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() { 1145 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray()); 1146 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction, 1147 getContext()); 1148 } 1149 1150 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() { 1151 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray()); 1152 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel, 1153 getContext()); 1154 } 1155 1156 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() { 1157 SmallVector<int64_t, 4> shape; 1158 getIterationBounds(shape); 1159 return shape; 1160 } 1161 1162 /// Return a fused vector::ContractionOp which represents a patterns such as: 1163 /// 1164 /// ```mlir 1165 /// %c0 = vector.constant 0: ... 1166 /// %c = vector.contract %a, %b, %c0: ... 1167 /// %e = add %c, %d: ... 1168 /// ``` 1169 /// 1170 /// by: 1171 /// 1172 /// ```mlir 1173 /// %e = vector.contract %a, %b, %d: ... 1174 /// ``` 1175 /// 1176 /// Return null if the canonicalization does not apply. 1177 // TODO: This should be a folding of Add into Contract in core but while they 1178 // live in different dialects, it is not possible without unnatural 1179 // dependencies. 1180 template <typename AddOpType> 1181 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> { 1182 using OpRewritePattern<AddOpType>::OpRewritePattern; 1183 1184 LogicalResult matchAndRewrite(AddOpType addOp, 1185 PatternRewriter &rewriter) const override { 1186 auto canonicalize = [&](Value maybeContraction, 1187 Value otherOperand) -> vector::ContractionOp { 1188 vector::ContractionOp contractionOp = 1189 dyn_cast_or_null<vector::ContractionOp>( 1190 maybeContraction.getDefiningOp()); 1191 if (!contractionOp) 1192 return vector::ContractionOp(); 1193 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>( 1194 contractionOp.getAcc().getDefiningOp())) { 1195 if (maybeZero.getValue() == 1196 rewriter.getZeroAttr(contractionOp.getAcc().getType())) { 1197 IRMapping bvm; 1198 bvm.map(contractionOp.getAcc(), otherOperand); 1199 auto newContraction = 1200 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm)); 1201 rewriter.replaceOp(addOp, newContraction.getResult()); 1202 return newContraction; 1203 } 1204 } 1205 return vector::ContractionOp(); 1206 }; 1207 1208 Value a = addOp->getOperand(0), b = addOp->getOperand(1); 1209 vector::ContractionOp contract = canonicalize(a, b); 1210 contract = contract ? contract : canonicalize(b, a); 1211 return contract ? 
success() : failure(); 1212 } 1213 }; 1214 1215 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, 1216 MLIRContext *context) { 1217 results.add<CanonicalizeContractAdd<arith::AddIOp>, 1218 CanonicalizeContractAdd<arith::AddFOp>>(context); 1219 } 1220 1221 //===----------------------------------------------------------------------===// 1222 // ExtractElementOp 1223 //===----------------------------------------------------------------------===// 1224 1225 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1226 SetIntRangeFn setResultRanges) { 1227 setResultRanges(getResult(), argRanges.front()); 1228 } 1229 1230 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, 1231 Value source) { 1232 result.addOperands({source}); 1233 result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType()); 1234 } 1235 1236 LogicalResult vector::ExtractElementOp::verify() { 1237 VectorType vectorType = getSourceVectorType(); 1238 if (vectorType.getRank() == 0) { 1239 if (getPosition()) 1240 return emitOpError("expected position to be empty with 0-D vector"); 1241 return success(); 1242 } 1243 if (vectorType.getRank() != 1) 1244 return emitOpError("unexpected >1 vector rank"); 1245 if (!getPosition()) 1246 return emitOpError("expected position for 1-D vector"); 1247 return success(); 1248 } 1249 1250 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { 1251 // Skip the 0-D vector here now. 1252 if (!adaptor.getPosition()) 1253 return {}; 1254 1255 // Fold extractelement (splat X) -> X. 1256 if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) 1257 return splat.getInput(); 1258 1259 // Fold extractelement(broadcast(X)) -> X. 1260 if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>()) 1261 if (!llvm::isa<VectorType>(broadcast.getSource().getType())) 1262 return broadcast.getSource(); 1263 1264 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector()); 1265 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); 1266 if (!pos || !src) 1267 return {}; 1268 1269 auto srcElements = src.getValues<Attribute>(); 1270 1271 uint64_t posIdx = pos.getInt(); 1272 if (posIdx >= srcElements.size()) 1273 return {}; 1274 1275 return srcElements[posIdx]; 1276 } 1277 1278 // Returns `true` if `index` is either within [0, maxIndex) or equal to 1279 // `poisonValue`. 
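// For example (illustrative), with maxIndex == 8 and poisonValue == -1,
// indices 3 and -1 are accepted while 8 and -2 are rejected.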
1280 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, 1281 int64_t maxIndex) { 1282 return index == poisonValue || (index >= 0 && index < maxIndex); 1283 } 1284 1285 //===----------------------------------------------------------------------===// 1286 // ExtractOp 1287 //===----------------------------------------------------------------------===// 1288 1289 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1290 SetIntRangeFn setResultRanges) { 1291 setResultRanges(getResult(), argRanges.front()); 1292 } 1293 1294 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, 1295 Value source, int64_t position) { 1296 build(builder, result, source, ArrayRef<int64_t>{position}); 1297 } 1298 1299 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, 1300 Value source, OpFoldResult position) { 1301 build(builder, result, source, ArrayRef<OpFoldResult>{position}); 1302 } 1303 1304 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, 1305 Value source, ArrayRef<int64_t> position) { 1306 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(), 1307 builder.getDenseI64ArrayAttr(position)); 1308 } 1309 1310 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, 1311 Value source, ArrayRef<OpFoldResult> position) { 1312 SmallVector<int64_t> staticPos; 1313 SmallVector<Value> dynamicPos; 1314 dispatchIndexOpFoldResults(position, dynamicPos, staticPos); 1315 build(builder, result, source, dynamicPos, 1316 builder.getDenseI64ArrayAttr(staticPos)); 1317 } 1318 1319 LogicalResult 1320 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>, 1321 ExtractOp::Adaptor adaptor, 1322 SmallVectorImpl<Type> &inferredReturnTypes) { 1323 auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType()); 1324 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) == 1325 vectorType.getRank()) { 1326 inferredReturnTypes.push_back(vectorType.getElementType()); 1327 } else { 1328 auto n = std::min<size_t>(adaptor.getStaticPosition().size(), 1329 vectorType.getRank()); 1330 inferredReturnTypes.push_back(VectorType::get( 1331 vectorType.getShape().drop_front(n), vectorType.getElementType(), 1332 vectorType.getScalableDims().drop_front(n))); 1333 } 1334 return success(); 1335 } 1336 1337 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1338 // Allow extracting 1-element vectors instead of scalars. 1339 auto isCompatible = [](TypeRange l, TypeRange r) { 1340 auto vectorType = llvm::dyn_cast<VectorType>(l.front()); 1341 return vectorType && vectorType.getShape().equals({1}) && 1342 vectorType.getElementType() == r.front(); 1343 }; 1344 if (l.size() == 1 && r.size() == 1 && 1345 (isCompatible(l, r) || isCompatible(r, l))) 1346 return true; 1347 return l == r; 1348 } 1349 1350 LogicalResult vector::ExtractOp::verify() { 1351 // Note: This check must come before getMixedPosition() to prevent a crash. 
1352 auto dynamicMarkersCount = 1353 llvm::count_if(getStaticPosition(), ShapedType::isDynamic); 1354 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size()) 1355 return emitOpError( 1356 "mismatch between dynamic and static positions (kDynamic marker but no " 1357 "corresponding dynamic position) -- this can only happen due to an " 1358 "incorrect fold/rewrite"); 1359 auto position = getMixedPosition(); 1360 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank())) 1361 return emitOpError( 1362 "expected position attribute of rank no greater than vector rank"); 1363 for (auto [idx, pos] : llvm::enumerate(position)) { 1364 if (auto attr = dyn_cast<Attribute>(pos)) { 1365 int64_t constIdx = cast<IntegerAttr>(attr).getInt(); 1366 if (!isValidPositiveIndexOrPoison( 1367 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) { 1368 return emitOpError("expected position attribute #") 1369 << (idx + 1) 1370 << " to be a non-negative integer smaller than the " 1371 "corresponding vector dimension or poison (-1)"; 1372 } 1373 } 1374 } 1375 return success(); 1376 } 1377 1378 template <typename IntType> 1379 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) { 1380 return llvm::to_vector<4>(llvm::map_range( 1381 arrayAttr.getAsRange<IntegerAttr>(), 1382 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); 1383 } 1384 1385 /// Fold the result of chains of ExtractOp in place by simply concatenating the 1386 /// positions. 1387 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { 1388 if (!extractOp.getVector().getDefiningOp<ExtractOp>()) 1389 return failure(); 1390 1391 // TODO: Canonicalization for dynamic position not implemented yet. 1392 if (extractOp.hasDynamicPosition()) 1393 return failure(); 1394 1395 SmallVector<int64_t> globalPosition; 1396 ExtractOp currentOp = extractOp; 1397 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition(); 1398 globalPosition.append(extrPos.rbegin(), extrPos.rend()); 1399 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) { 1400 currentOp = nextOp; 1401 // TODO: Canonicalization for dynamic position not implemented yet. 1402 if (currentOp.hasDynamicPosition()) 1403 return failure(); 1404 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition(); 1405 globalPosition.append(extrPos.rbegin(), extrPos.rend()); 1406 } 1407 extractOp.setOperand(0, currentOp.getVector()); 1408 // OpBuilder is only used as a helper to build an I64ArrayAttr. 1409 OpBuilder b(extractOp.getContext()); 1410 std::reverse(globalPosition.begin(), globalPosition.end()); 1411 extractOp.setStaticPosition(globalPosition); 1412 return success(); 1413 } 1414 1415 namespace { 1416 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. 1417 /// Walk back a chain of InsertOp/TransposeOp until we hit a match. 1418 /// Compose TransposeOp permutations as we walk back. 1419 /// This helper class keeps an updated extraction position `extractPosition` 1420 /// with extra trailing sentinels. 1421 /// The sentinels encode the internal transposition status of the result vector. 1422 /// As we iterate, extractPosition is permuted and updated. 1423 class ExtractFromInsertTransposeChainState { 1424 public: 1425 ExtractFromInsertTransposeChainState(ExtractOp e); 1426 1427 /// Iterate over producing insert and transpose ops until we find a fold. 
1428 Value fold(); 1429 1430 private: 1431 /// Return true if the vector at position `a` is contained within the vector 1432 /// at position `b`. Under insert/extract semantics, this is the same as `a` 1433 /// is a prefix of `b`. 1434 template <typename ContainerA, typename ContainerB> 1435 bool isContainedWithin(const ContainerA &a, const ContainerB &b) { 1436 return a.size() <= b.size() && 1437 std::equal(a.begin(), a.begin() + a.size(), b.begin()); 1438 } 1439 1440 /// Return true if the vector at position `a` intersects the vector at 1441 /// position `b`. Under insert/extract semantics, this is the same as equality 1442 /// of all entries of `a` that are >=0 with the corresponding entries of b. 1443 /// Comparison is on the common prefix (i.e. zip). 1444 template <typename ContainerA, typename ContainerB> 1445 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) { 1446 for (auto [elemA, elemB] : llvm::zip(a, b)) { 1447 if (elemA < 0 || elemB < 0) 1448 continue; 1449 if (elemA != elemB) 1450 return false; 1451 } 1452 return true; 1453 } 1454 1455 /// Folding is only possible in the absence of an internal permutation in the 1456 /// result vector. 1457 bool canFold() { 1458 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank)); 1459 } 1460 1461 // Helper to get the next defining op of interest. 1462 void updateStateForNextIteration(Value v) { 1463 nextInsertOp = v.getDefiningOp<vector::InsertOp>(); 1464 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>(); 1465 }; 1466 1467 // Case 1. If we hit a transpose, just compose the map and iterate. 1468 // Invariant: insert + transpose do not change rank, we can always compose. 1469 LogicalResult handleTransposeOp(); 1470 1471 // Case 2: the insert position matches extractPosition exactly, early return. 1472 LogicalResult handleInsertOpWithMatchingPos(Value &res); 1473 1474 /// Case 3: if the insert position is a prefix of extractPosition, extract a 1475 /// portion of the source of the insert. 1476 /// Example: 1477 /// ``` 1478 /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5> 1479 /// // extractPosition == [1, 2, 3] 1480 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5> 1481 /// // can fold to vector.extract %source[0, 3] 1482 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6> 1483 /// ``` 1484 /// To traverse through %source, we need to set the leading dims to 0 and 1485 /// drop the extra leading dims. 1486 /// This method updates the internal state. 1487 LogicalResult handleInsertOpWithPrefixPos(Value &res); 1488 1489 /// Try to fold in place to extract(source, extractPosition) and return the 1490 /// folded result. Return null if folding is not possible (e.g. due to an 1491 /// internal transposition in the result). 1492 Value tryToFoldExtractOpInPlace(Value source); 1493 1494 ExtractOp extractOp; 1495 int64_t vectorRank; 1496 int64_t extractedRank; 1497 1498 InsertOp nextInsertOp; 1499 TransposeOp nextTransposeOp; 1500 1501 /// Sentinel values that encode the internal permutation status of the result. 1502 /// They are set to (-1, ... , -k) at the beginning and appended to 1503 /// `extractPosition`. 1504 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to 1505 /// ensure that there is no internal transposition. 1506 /// Internal transposition cannot be accounted for with a folding pattern. 
1507 // TODO: We could relax the internal transposition with an extra transposition 1508 // operation in a future canonicalizer. 1509 SmallVector<int64_t> sentinels; 1510 SmallVector<int64_t> extractPosition; 1511 }; 1512 } // namespace 1513 1514 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( 1515 ExtractOp e) 1516 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()), 1517 extractedRank(extractOp.getNumIndices()) { 1518 assert(vectorRank >= extractedRank && "Extracted position overflow"); 1519 sentinels.reserve(vectorRank - extractedRank); 1520 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) 1521 sentinels.push_back(-(i + 1)); 1522 extractPosition.assign(extractOp.getStaticPosition().begin(), 1523 extractOp.getStaticPosition().end()); 1524 llvm::append_range(extractPosition, sentinels); 1525 } 1526 1527 // Case 1. If we hit a transpose, just compose the map and iterate. 1528 // Invariant: insert + transpose do not change rank, we can always compose. 1529 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { 1530 // TODO: Canonicalization for dynamic position not implemented yet. 1531 if (extractOp.hasDynamicPosition()) 1532 return failure(); 1533 1534 if (!nextTransposeOp) 1535 return failure(); 1536 AffineMap m = inversePermutation(AffineMap::getPermutationMap( 1537 nextTransposeOp.getPermutation(), extractOp.getContext())); 1538 extractPosition = applyPermutationMap(m, ArrayRef(extractPosition)); 1539 return success(); 1540 } 1541 1542 // Case 2: the insert position matches extractPosition exactly, early return. 1543 LogicalResult 1544 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( 1545 Value &res) { 1546 // TODO: Canonicalization for dynamic position not implemented yet. 1547 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition()) 1548 return failure(); 1549 1550 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition(); 1551 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank)) 1552 return failure(); 1553 // Case 2.a. early-exit fold. 1554 res = nextInsertOp.getSource(); 1555 // Case 2.b. if internal transposition is present, canFold will be false. 1556 return success(canFold()); 1557 } 1558 1559 /// Case 3: if inserted position is a prefix of extractPosition, 1560 /// extract a portion of the source of the insertion. 1561 /// This method updates the internal state. 1562 LogicalResult 1563 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { 1564 // TODO: Canonicalization for dynamic position not implemented yet. 1565 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition()) 1566 return failure(); 1567 1568 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition(); 1569 if (!isContainedWithin(insertedPos, extractPosition)) 1570 return failure(); 1571 // Set leading dims to zero. 1572 std::fill_n(extractPosition.begin(), insertedPos.size(), 0); 1573 // Drop extra leading dims. 1574 extractPosition.erase(extractPosition.begin(), 1575 extractPosition.begin() + insertedPos.size()); 1576 extractedRank = extractPosition.size() - sentinels.size(); 1577 // Case 3.a. early-exit fold (break and delegate to post-while path). 1578 res = nextInsertOp.getSource(); 1579 // Case 3.b. if internal transposition is present, canFold will be false. 1580 return success(); 1581 } 1582 1583 /// Try to fold in place to extract(source, extractPosition) and return the 1584 /// folded result. 
Return null if folding is not possible (e.g. due to an 1585 /// internal transposition in the result). 1586 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace( 1587 Value source) { 1588 // TODO: Canonicalization for dynamic position not implemented yet. 1589 if (extractOp.hasDynamicPosition()) 1590 return Value(); 1591 1592 // If we can't fold (either internal transposition, or nothing to fold), bail. 1593 bool nothingToFold = (source == extractOp.getVector()); 1594 if (nothingToFold || !canFold()) 1595 return Value(); 1596 1597 // Otherwise, fold by updating the op inplace and return its result. 1598 OpBuilder b(extractOp.getContext()); 1599 extractOp.setStaticPosition( 1600 ArrayRef(extractPosition).take_front(extractedRank)); 1601 extractOp.getVectorMutable().assign(source); 1602 return extractOp.getResult(); 1603 } 1604 1605 /// Iterate over producing insert and transpose ops until we find a fold. 1606 Value ExtractFromInsertTransposeChainState::fold() { 1607 // TODO: Canonicalization for dynamic position not implemented yet. 1608 if (extractOp.hasDynamicPosition()) 1609 return Value(); 1610 1611 Value valueToExtractFrom = extractOp.getVector(); 1612 updateStateForNextIteration(valueToExtractFrom); 1613 while (nextInsertOp || nextTransposeOp) { 1614 // Case 1. If we hit a transpose, just compose the map and iterate. 1615 // Invariant: insert + transpose do not change rank, we can always compose. 1616 if (succeeded(handleTransposeOp())) { 1617 valueToExtractFrom = nextTransposeOp.getVector(); 1618 updateStateForNextIteration(valueToExtractFrom); 1619 continue; 1620 } 1621 1622 Value result; 1623 // Case 2: the position match exactly. 1624 if (succeeded(handleInsertOpWithMatchingPos(result))) 1625 return result; 1626 1627 // Case 3: if the inserted position is a prefix of extractPosition, we can 1628 // just extract a portion of the source of the insert. 1629 if (succeeded(handleInsertOpWithPrefixPos(result))) 1630 return tryToFoldExtractOpInPlace(result); 1631 1632 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel 1633 // values. This is a more difficult case and we bail. 1634 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition(); 1635 if (isContainedWithin(extractPosition, insertedPos) || 1636 intersectsWhereNonNegative(extractPosition, insertedPos)) 1637 return Value(); 1638 1639 // Case 5: No intersection, we forward the extract to insertOp.dest(). 1640 valueToExtractFrom = nextInsertOp.getDest(); 1641 updateStateForNextIteration(valueToExtractFrom); 1642 } 1643 // If after all this we can fold, go for it. 1644 return tryToFoldExtractOpInPlace(valueToExtractFrom); 1645 } 1646 1647 /// Returns true if the operation has a 0-D vector type operand or result. 1648 static bool hasZeroDimVectors(Operation *op) { 1649 auto hasZeroDimVectorType = [](Type type) -> bool { 1650 auto vecType = dyn_cast<VectorType>(type); 1651 return vecType && vecType.getRank() == 0; 1652 }; 1653 1654 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) || 1655 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); 1656 } 1657 1658 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. 1659 static Value foldExtractFromBroadcast(ExtractOp extractOp) { 1660 // TODO: Canonicalization for dynamic position not implemented yet. 
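  // Illustrative sketch (hypothetical values) of the simplest case handled
  // here: when the extracted slice has the same type as the broadcast source,
  //
  //   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
  //   %e = vector.extract %b[1, 2] : vector<4xf32> from vector<2x3x4xf32>
  //
  // folds directly to %src.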
  if (extractOp.hasDynamicPosition())
    return Value();

  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();

  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                       : 0;
  };

  // If splat or broadcast from a scalar, just return the source scalar.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result haven't been broadcasted.
  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();

  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

  // Detect all the positions that come from "dim-1" broadcasting.
  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
  // extract position to `0` when extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
      extractPos[i] = 0;
  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
  // matching extract position when extracting from the source operand.
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setOperand(0, source);
  extractOp.setStaticPosition(extractPos);
  return extractOp.getResult();
}

/// Fold extractOp coming from ShuffleOp.
///
/// Example:
///
///   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
///     : vector<8xf32>, vector<8xf32>
///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
///   ->
///   %extract = vector.extract %b[7] : f32 from vector<8xf32>
///
static Value foldExtractFromShuffle(ExtractOp extractOp) {
  // Dynamic positions are not folded as the resulting code would be more
  // complex than the input code.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // TODO: 0-D or multi-dimensional vectors not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Find the shuffled vector to extract from based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated with the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1817 OpBuilder b(extractOp.getContext()); 1818 extractOp.setStaticPosition(newPosition); 1819 extractOp.setOperand(0, shapeCastOp.getSource()); 1820 return extractOp.getResult(); 1821 } 1822 1823 /// Fold an ExtractOp from ExtractStridedSliceOp. 1824 static Value foldExtractFromExtractStrided(ExtractOp extractOp) { 1825 // TODO: Canonicalization for dynamic position not implemented yet. 1826 if (extractOp.hasDynamicPosition()) 1827 return Value(); 1828 1829 auto extractStridedSliceOp = 1830 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>(); 1831 if (!extractStridedSliceOp) 1832 return Value(); 1833 1834 // 0-D vectors not supported. 1835 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported"); 1836 if (hasZeroDimVectors(extractStridedSliceOp)) 1837 return Value(); 1838 1839 // Return if 'extractStridedSliceOp' has non-unit strides. 1840 if (extractStridedSliceOp.hasNonUnitStrides()) 1841 return Value(); 1842 1843 // Trim offsets for dimensions fully extracted. 1844 auto sliceOffsets = 1845 extractVector<int64_t>(extractStridedSliceOp.getOffsets()); 1846 while (!sliceOffsets.empty()) { 1847 size_t lastOffset = sliceOffsets.size() - 1; 1848 if (sliceOffsets.back() != 0 || 1849 extractStridedSliceOp.getType().getDimSize(lastOffset) != 1850 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset)) 1851 break; 1852 sliceOffsets.pop_back(); 1853 } 1854 unsigned destinationRank = 0; 1855 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType())) 1856 destinationRank = vecType.getRank(); 1857 // The dimensions of the result need to be untouched by the 1858 // extractStridedSlice op. 1859 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() - 1860 sliceOffsets.size()) 1861 return Value(); 1862 1863 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition()); 1864 assert(extractedPos.size() >= sliceOffsets.size()); 1865 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) 1866 extractedPos[i] = extractedPos[i] + sliceOffsets[i]; 1867 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); 1868 1869 // OpBuilder is only used as a helper to build an I64ArrayAttr. 1870 OpBuilder b(extractOp.getContext()); 1871 extractOp.setStaticPosition(extractedPos); 1872 return extractOp.getResult(); 1873 } 1874 1875 /// Fold extract_op fed from a chain of insertStridedSlice ops. 1876 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) { 1877 // TODO: Canonicalization for dynamic position not implemented yet. 1878 if (extractOp.hasDynamicPosition()) 1879 return Value(); 1880 1881 int64_t destinationRank = 1882 llvm::isa<VectorType>(extractOp.getType()) 1883 ? llvm::cast<VectorType>(extractOp.getType()).getRank() 1884 : 0; 1885 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 1886 if (!insertOp) 1887 return Value(); 1888 1889 // 0-D vectors not supported. 
1890 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported"); 1891 if (hasZeroDimVectors(insertOp)) 1892 return Value(); 1893 1894 while (insertOp) { 1895 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() - 1896 insertOp.getSourceVectorType().getRank(); 1897 if (destinationRank > insertOp.getSourceVectorType().getRank()) 1898 return Value(); 1899 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets()); 1900 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition(); 1901 1902 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { 1903 return llvm::cast<IntegerAttr>(attr).getInt() != 1; 1904 })) 1905 return Value(); 1906 bool disjoint = false; 1907 SmallVector<int64_t, 4> offsetDiffs; 1908 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) { 1909 int64_t start = insertOffsets[dim]; 1910 int64_t size = 1911 (dim < insertRankDiff) 1912 ? 1 1913 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff); 1914 int64_t end = start + size; 1915 int64_t offset = extractOffsets[dim]; 1916 // Check if the start of the extract offset is in the interval inserted. 1917 if (start <= offset && offset < end) { 1918 if (dim >= insertRankDiff) 1919 offsetDiffs.push_back(offset - start); 1920 continue; 1921 } 1922 disjoint = true; 1923 break; 1924 } 1925 // The extract element chunk overlap with the vector inserted. 1926 if (!disjoint) { 1927 // If any of the inner dimensions are only partially inserted we have a 1928 // partial overlap. 1929 int64_t srcRankDiff = 1930 insertOp.getSourceVectorType().getRank() - destinationRank; 1931 for (int64_t i = 0; i < destinationRank; i++) { 1932 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) != 1933 insertOp.getDestVectorType().getDimSize(i + srcRankDiff + 1934 insertRankDiff)) 1935 return Value(); 1936 } 1937 extractOp.getVectorMutable().assign(insertOp.getSource()); 1938 // OpBuilder is only used as a helper to build an I64ArrayAttr. 1939 OpBuilder b(extractOp.getContext()); 1940 extractOp.setStaticPosition(offsetDiffs); 1941 return extractOp.getResult(); 1942 } 1943 // If the chunk extracted is disjoint from the chunk inserted, keep 1944 // looking in the insert chain. 1945 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>(); 1946 } 1947 return Value(); 1948 } 1949 1950 /// Try to fold the extraction of a scalar from a vector defined by 1951 /// vector.from_elements. E.g.: 1952 /// 1953 /// %0 = vector.from_elements %a, %b : vector<2xf32> 1954 /// %1 = vector.extract %0[0] : f32 from vector<2xf32> 1955 /// ==> fold to %a 1956 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { 1957 // Dynamic extractions cannot be folded. 1958 if (extractOp.hasDynamicPosition()) 1959 return {}; 1960 1961 // Look for extract(from_elements). 1962 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>(); 1963 if (!fromElementsOp) 1964 return {}; 1965 1966 // Scalable vectors are not supported. 1967 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType()); 1968 if (vecType.isScalable()) 1969 return {}; 1970 1971 // Only extractions of scalars are supported. 1972 int64_t rank = vecType.getRank(); 1973 ArrayRef<int64_t> indices = extractOp.getStaticPosition(); 1974 if (extractOp.getType() != vecType.getElementType()) 1975 return {}; 1976 assert(static_cast<int64_t>(indices.size()) == rank && 1977 "unexpected number of indices"); 1978 1979 // Compute flattened/linearized index and fold to operand. 
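  // For example (illustrative): extracting position [1, 2] from a
  // vector<3x4xf32> built by vector.from_elements yields
  // flatIndex = 1 * 4 + 2 = 6, i.e. the 7th element operand.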
1980 int flatIndex = 0; 1981 int stride = 1; 1982 for (int i = rank - 1; i >= 0; --i) { 1983 flatIndex += indices[i] * stride; 1984 stride *= vecType.getDimSize(i); 1985 } 1986 return fromElementsOp.getElements()[flatIndex]; 1987 } 1988 1989 /// Fold an insert or extract operation into an poison value when a poison index 1990 /// is found at any dimension of the static position. 1991 static ub::PoisonAttr 1992 foldPoisonIndexInsertExtractOp(MLIRContext *context, 1993 ArrayRef<int64_t> staticPos, int64_t poisonVal) { 1994 if (!llvm::is_contained(staticPos, poisonVal)) 1995 return ub::PoisonAttr(); 1996 1997 return ub::PoisonAttr::get(context); 1998 } 1999 2000 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { 2001 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v. 2002 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type 2003 // mismatch). 2004 if (getNumIndices() == 0 && getVector().getType() == getResult().getType()) 2005 return getVector(); 2006 if (auto res = foldPoisonIndexInsertExtractOp( 2007 getContext(), adaptor.getStaticPosition(), kPoisonIndex)) 2008 return res; 2009 if (succeeded(foldExtractOpFromExtractChain(*this))) 2010 return getResult(); 2011 if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) 2012 return res; 2013 if (auto res = foldExtractFromBroadcast(*this)) 2014 return res; 2015 if (auto res = foldExtractFromShuffle(*this)) 2016 return res; 2017 if (auto res = foldExtractFromShapeCast(*this)) 2018 return res; 2019 if (auto val = foldExtractFromExtractStrided(*this)) 2020 return val; 2021 if (auto val = foldExtractStridedOpFromInsertChain(*this)) 2022 return val; 2023 if (auto val = foldScalarExtractFromFromElements(*this)) 2024 return val; 2025 return OpFoldResult(); 2026 } 2027 2028 namespace { 2029 2030 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. 2031 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { 2032 public: 2033 using OpRewritePattern::OpRewritePattern; 2034 2035 LogicalResult matchAndRewrite(ExtractOp extractOp, 2036 PatternRewriter &rewriter) const override { 2037 Operation *defOp = extractOp.getVector().getDefiningOp(); 2038 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) 2039 return failure(); 2040 2041 Value source = defOp->getOperand(0); 2042 if (extractOp.getType() == source.getType()) 2043 return failure(); 2044 auto getRank = [](Type type) { 2045 return llvm::isa<VectorType>(type) 2046 ? llvm::cast<VectorType>(type).getRank() 2047 : 0; 2048 }; 2049 unsigned broadcastSrcRank = getRank(source.getType()); 2050 unsigned extractResultRank = getRank(extractOp.getType()); 2051 // We only consider the case where the rank of the source is less than or 2052 // equal to the rank of the extract dst. The other cases are handled in the 2053 // folding patterns. 2054 if (extractResultRank < broadcastSrcRank) 2055 return failure(); 2056 2057 // Special case if broadcast src is a 0D vector. 2058 if (extractResultRank == 0) { 2059 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType())); 2060 rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source); 2061 return success(); 2062 } 2063 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 2064 extractOp, extractOp.getType(), source); 2065 return success(); 2066 } 2067 }; 2068 2069 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp. 
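// For example (illustrative):
//
//   %cst = arith.constant dense<1.0> : vector<4x8xf32>
//   %0 = vector.extract %cst[2] : vector<8xf32> from vector<4x8xf32>
//
// rewrites to
//
//   %0 = arith.constant dense<1.0> : vector<8xf32>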
2070 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> { 2071 public: 2072 using OpRewritePattern::OpRewritePattern; 2073 2074 LogicalResult matchAndRewrite(ExtractOp extractOp, 2075 PatternRewriter &rewriter) const override { 2076 // Return if 'ExtractOp' operand is not defined by a splat vector 2077 // ConstantOp. 2078 Value sourceVector = extractOp.getVector(); 2079 Attribute vectorCst; 2080 if (!matchPattern(sourceVector, m_Constant(&vectorCst))) 2081 return failure(); 2082 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst); 2083 if (!splat) 2084 return failure(); 2085 TypedAttr newAttr = splat.getSplatValue<TypedAttr>(); 2086 if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType())) 2087 newAttr = DenseElementsAttr::get(vecDstType, newAttr); 2088 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr); 2089 return success(); 2090 } 2091 }; 2092 2093 // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp. 2094 class ExtractOpNonSplatConstantFolder final 2095 : public OpRewritePattern<ExtractOp> { 2096 public: 2097 using OpRewritePattern::OpRewritePattern; 2098 2099 LogicalResult matchAndRewrite(ExtractOp extractOp, 2100 PatternRewriter &rewriter) const override { 2101 // TODO: Canonicalization for dynamic position not implemented yet. 2102 if (extractOp.hasDynamicPosition()) 2103 return failure(); 2104 2105 // Return if 'ExtractOp' operand is not defined by a compatible vector 2106 // ConstantOp. 2107 Value sourceVector = extractOp.getVector(); 2108 Attribute vectorCst; 2109 if (!matchPattern(sourceVector, m_Constant(&vectorCst))) 2110 return failure(); 2111 2112 auto vecTy = llvm::cast<VectorType>(sourceVector.getType()); 2113 if (vecTy.isScalable()) 2114 return failure(); 2115 2116 // The splat case is handled by `ExtractOpSplatConstantFolder`. 2117 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst); 2118 if (!dense || dense.isSplat()) 2119 return failure(); 2120 2121 // Calculate the linearized position of the continuous chunk of elements to 2122 // extract. 2123 llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0); 2124 copy(extractOp.getStaticPosition(), completePositions.begin()); 2125 int64_t elemBeginPosition = 2126 linearize(completePositions, computeStrides(vecTy.getShape())); 2127 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition; 2128 2129 TypedAttr newAttr; 2130 if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) { 2131 SmallVector<Attribute> elementValues( 2132 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements()); 2133 newAttr = DenseElementsAttr::get(resVecTy, elementValues); 2134 } else { 2135 newAttr = *denseValuesBegin; 2136 } 2137 2138 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr); 2139 return success(); 2140 } 2141 }; 2142 2143 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask. 
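// For example (illustrative), with %c2 = arith.constant 2 : index:
//
//   %mask = vector.create_mask %c2, %dim : vector<4x8xi1>
//   %0 = vector.extract %mask[1] : vector<8xi1> from vector<4x8xi1>
//
// rewrites to vector.create_mask %dim : vector<8xi1>, while extracting at
// position [3] (outside the first mask bound) rewrites to an all-false
// constant instead.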
2144 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> { 2145 public: 2146 using OpRewritePattern::OpRewritePattern; 2147 2148 LogicalResult matchAndRewrite(ExtractOp extractOp, 2149 PatternRewriter &rewriter) const override { 2150 auto createMaskOp = 2151 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); 2152 if (!createMaskOp) 2153 return failure(); 2154 2155 VectorType extractedMaskType = 2156 llvm::dyn_cast<VectorType>(extractOp.getResult().getType()); 2157 2158 if (!extractedMaskType) 2159 return failure(); 2160 2161 auto maskOperands = createMaskOp.getOperands(); 2162 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition(); 2163 VectorType maskType = createMaskOp.getVectorType(); 2164 2165 bool containsUnknownDims = false; 2166 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse; 2167 2168 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size(); 2169 dimIdx++) { 2170 int64_t pos = extractOpPos[dimIdx]; 2171 Value operand = maskOperands[dimIdx]; 2172 auto constantOp = operand.getDefiningOp<arith::ConstantOp>(); 2173 if (!constantOp) { 2174 // Bounds of this dim unknown. 2175 containsUnknownDims = true; 2176 continue; 2177 } 2178 2179 int64_t createMaskBound = 2180 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt(); 2181 2182 if (pos != ShapedType::kDynamic) { 2183 // If any position is outside the range from the `create_mask`, then the 2184 // extracted mask will be all-false. 2185 allFalse |= pos >= createMaskBound; 2186 } else if (createMaskBound < maskType.getDimSize(dimIdx)) { 2187 // This dim is not all-true and since this is a dynamic index we don't 2188 // know if the extraction is within the true or false region. 2189 // Note: Zero dims have already handled via getMaskFormat(). 2190 containsUnknownDims = true; 2191 } 2192 } 2193 2194 if (allFalse) { 2195 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 2196 extractOp, DenseElementsAttr::get(extractedMaskType, false)); 2197 } else if (!containsUnknownDims) { 2198 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>( 2199 extractOp, extractedMaskType, 2200 maskOperands.drop_front(extractOpPos.size())); 2201 } else { 2202 return failure(); 2203 } 2204 return success(); 2205 } 2206 }; 2207 2208 // Folds extract(shape_cast(..)) into shape_cast when the total element count 2209 // does not change. 2210 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp, 2211 PatternRewriter &rewriter) { 2212 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>(); 2213 if (!castOp) 2214 return failure(); 2215 2216 VectorType sourceType = castOp.getSourceVectorType(); 2217 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType()); 2218 if (!targetType) 2219 return failure(); 2220 2221 if (sourceType.getNumElements() != targetType.getNumElements()) 2222 return failure(); 2223 2224 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType, 2225 castOp.getSource()); 2226 return success(); 2227 } 2228 2229 /// Try to canonicalize the extraction of a subvector from a vector defined by 2230 /// vector.from_elements. E.g.: 2231 /// 2232 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32> 2233 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32> 2234 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32> 2235 LogicalResult foldExtractFromFromElements(ExtractOp extractOp, 2236 PatternRewriter &rewriter) { 2237 // Dynamic positions are not supported. 
2238 if (extractOp.hasDynamicPosition()) 2239 return failure(); 2240 2241 // Scalar extracts are handled by the folder. 2242 auto resultType = dyn_cast<VectorType>(extractOp.getType()); 2243 if (!resultType) 2244 return failure(); 2245 2246 // Look for extracts from a from_elements op. 2247 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>(); 2248 if (!fromElementsOp) 2249 return failure(); 2250 VectorType inputType = fromElementsOp.getType(); 2251 2252 // Scalable vectors are not supported. 2253 if (resultType.isScalable() || inputType.isScalable()) 2254 return failure(); 2255 2256 // Compute the position of first extracted element and flatten/linearize the 2257 // position. 2258 SmallVector<int64_t> firstElementPos = 2259 llvm::to_vector(extractOp.getStaticPosition()); 2260 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0); 2261 int flatIndex = 0; 2262 int stride = 1; 2263 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) { 2264 flatIndex += firstElementPos[i] * stride; 2265 stride *= inputType.getDimSize(i); 2266 } 2267 2268 // Replace the op with a smaller from_elements op. 2269 rewriter.replaceOpWithNewOp<FromElementsOp>( 2270 extractOp, resultType, 2271 fromElementsOp.getElements().slice(flatIndex, 2272 resultType.getNumElements())); 2273 return success(); 2274 } 2275 2276 /// Fold an insert or extract operation into an poison value when a poison index 2277 /// is found at any dimension of the static position. 2278 template <typename OpTy> 2279 LogicalResult 2280 canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) { 2281 if (auto poisonAttr = foldPoisonIndexInsertExtractOp( 2282 op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) { 2283 rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr); 2284 return success(); 2285 } 2286 2287 return failure(); 2288 } 2289 2290 } // namespace 2291 2292 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, 2293 MLIRContext *context) { 2294 results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder, 2295 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); 2296 results.add(foldExtractFromShapeCastToShapeCast); 2297 results.add(foldExtractFromFromElements); 2298 results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>); 2299 } 2300 2301 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 2302 SmallVectorImpl<int64_t> &results) { 2303 for (auto attr : arrayAttr) 2304 results.push_back(llvm::cast<IntegerAttr>(attr).getInt()); 2305 } 2306 2307 //===----------------------------------------------------------------------===// 2308 // FmaOp 2309 //===----------------------------------------------------------------------===// 2310 2311 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() { 2312 return llvm::to_vector<4>(getVectorType().getShape()); 2313 } 2314 2315 //===----------------------------------------------------------------------===// 2316 // FromElementsOp 2317 //===----------------------------------------------------------------------===// 2318 2319 /// Rewrite a vector.from_elements into a vector.splat if all elements are the 2320 /// same SSA value. 
/// E.g.:
///
///   %0 = vector.from_elements %a, %a, %a : vector<3xf32>
///   ==> rewrite to vector.splat %a : vector<3xf32>
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
                                                PatternRewriter &rewriter) {
  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
                                       fromElementsOp.getElements().front());
  return success();
}

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add(rewriteFromElementsAsSplat);
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source vector and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
                           ArrayRef<int64_t> dstShape) {
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected dim-1 broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
}

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // Scalar broadcast is without any unit dim broadcast.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (!srcVectorType)
    return {};
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
                                      getResultVectorType().getShape());
}

/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
/// broadcasting.
/// Since vector.broadcast only allows expanding leading dimensions, an extra
/// vector.transpose may be inserted to make the broadcast possible.
/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
/// the helper will assert. This means:
///   1. `dstShape` must not be empty.
///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)].
///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
///      must match the `value` shape.
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness check.
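  // For example (illustrative): with dstShape = 1x2x3x4x5 and
  // broadcastedDims = {0, 2, 4}, checkShape collects the remaining dims 2x4,
  // which must equal the shape of `value`.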
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 2. If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. Since vector.broadcast only allows creating leading dims,
  // vector -> dstShape broadcast may require a transpose.
  // Traverse the dims in order and construct:
  //   1. The leading entries of the broadcastShape that are guaranteed to be
  //      achievable by a simple broadcast.
  //   2. The induced permutation for the subsequent vector.transpose that will
  //      bring us from `broadcastShape` back to the desired `dstShape`.
  // If the induced permutation is not the identity, create a vector.transpose.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  // Consider the example:
  //   srcShape        = 2x4
  //   dstShape        = 1x2x3x4x5
  //   broadcastedDims = [0, 2, 4]
  //
  // We want to build:
  //   broadcastShape = 1x3x5x2x4
  //   permutation    = [0, 2, 4,        1, 3]
  //                      ---V---       --V--
  //        leading broadcast part      src shape part
  //
  // Note that the trailing dims of broadcastShape are exactly the srcShape
  // by construction.
  // nextSrcShapeDim is used to keep track of where in the permutation the
  // "src shape part" occurs.
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
      // bring it to the head of the broadcastShape.
      // It will need to be permuted back from `broadcastShape.size() - 1` into
      // position `i`.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
      // shape and needs to be permuted into position `i`.
      // Don't touch `broadcastShape` here, the whole srcShape will be
      // appended after.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 3.c. Append the srcShape.
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no dim-1 broadcasts.
2464 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape) 2465 .empty() && 2466 "unexpected dim-1 broadcast"); 2467 2468 VectorType broadcastType = VectorType::get(broadcastShape, elementType); 2469 assert(vector::isBroadcastableTo(value.getType(), broadcastType) == 2470 vector::BroadcastableToResult::Success && 2471 "must be broadcastable"); 2472 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value); 2473 // Step 4. If we find any dimension that indeed needs to be permuted, 2474 // immediately return a new vector.transpose. 2475 for (int64_t i = 0, e = permutation.size(); i < e; ++i) 2476 if (permutation[i] != i) 2477 return b.createOrFold<vector::TransposeOp>(loc, res, permutation); 2478 // Otherwise return res. 2479 return res; 2480 } 2481 2482 BroadcastableToResult mlir::vector::isBroadcastableTo( 2483 Type srcType, VectorType dstVectorType, 2484 std::pair<VectorDim, VectorDim> *mismatchingDims) { 2485 // Broadcast scalar to vector of the same element type. 2486 if (srcType.isIntOrIndexOrFloat() && dstVectorType && 2487 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) 2488 return BroadcastableToResult::Success; 2489 // From now on, only vectors broadcast. 2490 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType); 2491 if (!srcVectorType) 2492 return BroadcastableToResult::SourceTypeNotAVector; 2493 2494 int64_t srcRank = srcVectorType.getRank(); 2495 int64_t dstRank = dstVectorType.getRank(); 2496 if (srcRank > dstRank) 2497 return BroadcastableToResult::SourceRankHigher; 2498 // Source has an exact match or singleton value for all trailing dimensions 2499 // (all leading dimensions are simply duplicated). 2500 int64_t lead = dstRank - srcRank; 2501 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) { 2502 // Have mismatching dims (in the sense of vector.broadcast semantics) been 2503 // encountered? 2504 bool foundMismatchingDims = false; 2505 2506 // Check fixed-width dims. 2507 int64_t srcDim = srcVectorType.getDimSize(dimIdx); 2508 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx); 2509 if (srcDim != 1 && srcDim != dstDim) 2510 foundMismatchingDims = true; 2511 2512 // Check scalable flags. 
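    // For example (illustrative): vector<[4]xf32> -> vector<2x[4]xf32> is
    // broadcastable, while vector<4xf32> -> vector<2x[4]xf32> and
    // vector<[4]xf32> -> vector<2x4xf32> mismatch on the scalable flag.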
2513 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx]; 2514 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx]; 2515 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) || 2516 // 1 -> [N] is fine, everything else should be rejected when mixing 2517 // fixed-width and scalable dims 2518 (srcDimScalableFlag != dstDimScalableFlag && 2519 (srcDim != 1 || srcDimScalableFlag))) 2520 foundMismatchingDims = true; 2521 2522 if (foundMismatchingDims) { 2523 if (mismatchingDims != nullptr) { 2524 mismatchingDims->first.dim = srcDim; 2525 mismatchingDims->first.isScalable = srcDimScalableFlag; 2526 2527 mismatchingDims->second.dim = dstDim; 2528 mismatchingDims->second.isScalable = dstDimScalableFlag; 2529 } 2530 return BroadcastableToResult::DimensionMismatch; 2531 } 2532 } 2533 2534 return BroadcastableToResult::Success; 2535 } 2536 2537 LogicalResult BroadcastOp::verify() { 2538 std::pair<VectorDim, VectorDim> mismatchingDims; 2539 BroadcastableToResult res = isBroadcastableTo( 2540 getSourceType(), getResultVectorType(), &mismatchingDims); 2541 if (res == BroadcastableToResult::Success) 2542 return success(); 2543 if (res == BroadcastableToResult::SourceRankHigher) 2544 return emitOpError("source rank higher than destination rank"); 2545 if (res == BroadcastableToResult::DimensionMismatch) { 2546 return emitOpError("dimension mismatch (") 2547 << (mismatchingDims.first.isScalable ? "[" : "") 2548 << mismatchingDims.first.dim 2549 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. " 2550 << (mismatchingDims.second.isScalable ? "[" : "") 2551 << mismatchingDims.second.dim 2552 << (mismatchingDims.second.isScalable ? "]" : "") << ")"; 2553 } 2554 if (res == BroadcastableToResult::SourceTypeNotAVector) 2555 return emitOpError("source type is not a vector"); 2556 llvm_unreachable("unexpected vector.broadcast op error"); 2557 } 2558 2559 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { 2560 if (getSourceType() == getResultVectorType()) 2561 return getSource(); 2562 if (!adaptor.getSource()) 2563 return {}; 2564 auto vectorType = getResultVectorType(); 2565 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) { 2566 if (vectorType.getElementType() != attr.getType()) 2567 return {}; 2568 return DenseElementsAttr::get(vectorType, attr); 2569 } 2570 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) { 2571 if (vectorType.getElementType() != attr.getType()) 2572 return {}; 2573 return DenseElementsAttr::get(vectorType, attr); 2574 } 2575 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource())) 2576 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()); 2577 return {}; 2578 } 2579 2580 namespace { 2581 2582 // Fold broadcast1(broadcast2(x)) into broadcast1(x). 
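// For example (illustrative):
//
//   %0 = vector.broadcast %src : f32 to vector<4xf32>
//   %1 = vector.broadcast %0 : vector<4xf32> to vector<2x4xf32>
//
// rewrites to a single broadcast:
//
//   %1 = vector.broadcast %src : f32 to vector<2x4xf32>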
2583 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { 2584 using OpRewritePattern::OpRewritePattern; 2585 2586 LogicalResult matchAndRewrite(BroadcastOp broadcastOp, 2587 PatternRewriter &rewriter) const override { 2588 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>(); 2589 if (!srcBroadcast) 2590 return failure(); 2591 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, 2592 broadcastOp.getResultVectorType(), 2593 srcBroadcast.getSource()); 2594 return success(); 2595 } 2596 }; 2597 } // namespace 2598 2599 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, 2600 MLIRContext *context) { 2601 // BroadcastToShapeCast is not a default canonicalization, it is opt-in by 2602 // calling `populateCastAwayVectorLeadingOneDimPatterns` 2603 results.add<BroadcastFolder>(context); 2604 } 2605 2606 //===----------------------------------------------------------------------===// 2607 // ShuffleOp 2608 //===----------------------------------------------------------------------===// 2609 2610 LogicalResult ShuffleOp::verify() { 2611 VectorType resultType = getResultVectorType(); 2612 VectorType v1Type = getV1VectorType(); 2613 VectorType v2Type = getV2VectorType(); 2614 // Verify ranks. 2615 int64_t resRank = resultType.getRank(); 2616 int64_t v1Rank = v1Type.getRank(); 2617 int64_t v2Rank = v2Type.getRank(); 2618 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1; 2619 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank; 2620 if (!wellFormed0DCase && !wellFormedNDCase) 2621 return emitOpError("rank mismatch"); 2622 2623 // Verify all but leading dimension sizes. 2624 for (int64_t r = 1; r < v1Rank; ++r) { 2625 int64_t resDim = resultType.getDimSize(r); 2626 int64_t v1Dim = v1Type.getDimSize(r); 2627 int64_t v2Dim = v2Type.getDimSize(r); 2628 if (resDim != v1Dim || v1Dim != v2Dim) 2629 return emitOpError("dimension mismatch"); 2630 } 2631 // Verify mask length. 2632 ArrayRef<int64_t> mask = getMask(); 2633 int64_t maskLength = mask.size(); 2634 if (maskLength <= 0) 2635 return emitOpError("invalid mask length"); 2636 if (maskLength != resultType.getDimSize(0)) 2637 return emitOpError("mask length mismatch"); 2638 // Verify all indices. 2639 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + 2640 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); 2641 for (auto [idx, maskPos] : llvm::enumerate(mask)) { 2642 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize)) 2643 return emitOpError("mask index #") << (idx + 1) << " out of range"; 2644 } 2645 return success(); 2646 } 2647 2648 LogicalResult 2649 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>, 2650 ShuffleOp::Adaptor adaptor, 2651 SmallVectorImpl<Type> &inferredReturnTypes) { 2652 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType()); 2653 auto v1Rank = v1Type.getRank(); 2654 // Construct resulting type: leading dimension matches mask 2655 // length, all trailing dimensions match the operands. 2656 SmallVector<int64_t, 4> shape; 2657 shape.reserve(v1Rank); 2658 shape.push_back(std::max<size_t>(1, adaptor.getMask().size())); 2659 // In the 0-D case there is no trailing shape to append. 
2660 if (v1Rank > 0) 2661 llvm::append_range(shape, v1Type.getShape().drop_front()); 2662 inferredReturnTypes.push_back( 2663 VectorType::get(shape, v1Type.getElementType())); 2664 return success(); 2665 } 2666 2667 template <typename T> 2668 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) { 2669 T expected = begin; 2670 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) { 2671 return value == expected++; 2672 }); 2673 } 2674 2675 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) { 2676 VectorType v1Type = getV1VectorType(); 2677 // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding 2678 // but must be a canonicalization into a vector.broadcast. 2679 if (v1Type.getRank() == 0) 2680 return {}; 2681 2682 // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1 2683 if (!v1Type.isScalable() && 2684 isStepIndexArray(getMask(), 0, v1Type.getDimSize(0))) 2685 return getV1(); 2686 // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2 2687 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() && 2688 isStepIndexArray(getMask(), getV1VectorType().getDimSize(0), 2689 getV2VectorType().getDimSize(0))) 2690 return getV2(); 2691 2692 Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2(); 2693 if (!lhs || !rhs) 2694 return {}; 2695 2696 auto lhsType = 2697 llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType()); 2698 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr 2699 // manipulation. 2700 if (lhsType.getRank() != 1) 2701 return {}; 2702 int64_t lhsSize = lhsType.getDimSize(0); 2703 2704 SmallVector<Attribute> results; 2705 auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>(); 2706 auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>(); 2707 for (int64_t i : this->getMask()) { 2708 if (i >= lhsSize) { 2709 results.push_back(rhsElements[i - lhsSize]); 2710 } else { 2711 results.push_back(lhsElements[i]); 2712 } 2713 } 2714 2715 return DenseElementsAttr::get(getResultVectorType(), results); 2716 } 2717 2718 namespace { 2719 2720 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector 2721 // to a broadcast. 2722 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { 2723 using OpRewritePattern::OpRewritePattern; 2724 2725 LogicalResult matchAndRewrite(ShuffleOp shuffleOp, 2726 PatternRewriter &rewriter) const override { 2727 VectorType v1VectorType = shuffleOp.getV1VectorType(); 2728 ArrayRef<int64_t> mask = shuffleOp.getMask(); 2729 if (v1VectorType.getRank() > 0) 2730 return failure(); 2731 if (mask.size() != 1) 2732 return failure(); 2733 VectorType resType = VectorType::Builder(v1VectorType).setShape({1}); 2734 if (mask[0] == 0) 2735 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType, 2736 shuffleOp.getV1()); 2737 else 2738 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType, 2739 shuffleOp.getV2()); 2740 return success(); 2741 } 2742 }; 2743 2744 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp. 
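/// For example (illustrative):
///
///   %s0 = vector.splat %x : vector<4xf32>
///   %s1 = vector.splat %x : vector<4xf32>
///   %0 = vector.shuffle %s0, %s1 [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
///
/// rewrites to vector.splat %x : vector<4xf32>, since every lane is %x anyway.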
2745 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { 2746 public: 2747 using OpRewritePattern::OpRewritePattern; 2748 2749 LogicalResult matchAndRewrite(ShuffleOp op, 2750 PatternRewriter &rewriter) const override { 2751 auto v1Splat = op.getV1().getDefiningOp<SplatOp>(); 2752 auto v2Splat = op.getV2().getDefiningOp<SplatOp>(); 2753 2754 if (!v1Splat || !v2Splat) 2755 return failure(); 2756 2757 if (v1Splat.getInput() != v2Splat.getInput()) 2758 return failure(); 2759 2760 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput()); 2761 return success(); 2762 } 2763 }; 2764 2765 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to 2766 /// vector.interleave. 2767 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> { 2768 public: 2769 using OpRewritePattern::OpRewritePattern; 2770 2771 LogicalResult matchAndRewrite(ShuffleOp op, 2772 PatternRewriter &rewriter) const override { 2773 VectorType resultType = op.getResultVectorType(); 2774 if (resultType.isScalable()) 2775 return rewriter.notifyMatchFailure( 2776 op, "ShuffleOp can't represent a scalable interleave"); 2777 2778 if (resultType.getRank() != 1) 2779 return rewriter.notifyMatchFailure( 2780 op, "ShuffleOp can't represent an n-D interleave"); 2781 2782 VectorType sourceType = op.getV1VectorType(); 2783 if (sourceType != op.getV2VectorType() || 2784 sourceType.getNumElements() * 2 != resultType.getNumElements()) { 2785 return rewriter.notifyMatchFailure( 2786 op, "ShuffleOp types don't match an interleave"); 2787 } 2788 2789 ArrayRef<int64_t> shuffleMask = op.getMask(); 2790 int64_t resultVectorSize = resultType.getNumElements(); 2791 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) { 2792 int64_t maskValueA = shuffleMask[i * 2]; 2793 int64_t maskValueB = shuffleMask[(i * 2) + 1]; 2794 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i) 2795 return rewriter.notifyMatchFailure(op, 2796 "ShuffleOp mask not interleaving"); 2797 } 2798 2799 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2()); 2800 return success(); 2801 } 2802 }; 2803 2804 } // namespace 2805 2806 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, 2807 MLIRContext *context) { 2808 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>( 2809 context); 2810 } 2811 2812 //===----------------------------------------------------------------------===// 2813 // InsertElementOp 2814 //===----------------------------------------------------------------------===// 2815 2816 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 2817 SetIntRangeFn setResultRanges) { 2818 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); 2819 } 2820 2821 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 2822 Value source, Value dest) { 2823 build(builder, result, source, dest, {}); 2824 } 2825 2826 LogicalResult InsertElementOp::verify() { 2827 auto dstVectorType = getDestVectorType(); 2828 if (dstVectorType.getRank() == 0) { 2829 if (getPosition()) 2830 return emitOpError("expected position to be empty with 0-D vector"); 2831 return success(); 2832 } 2833 if (dstVectorType.getRank() != 1) 2834 return emitOpError("unexpected >1 vector rank"); 2835 if (!getPosition()) 2836 return emitOpError("expected position for 1-D vector"); 2837 return success(); 2838 } 2839 2840 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { 2841 // Skip the 0-D vector here. 
2842 if (!adaptor.getPosition()) 2843 return {}; 2844 2845 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource()); 2846 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest()); 2847 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); 2848 if (!src || !dst || !pos) 2849 return {}; 2850 2851 if (src.getType() != getDestVectorType().getElementType()) 2852 return {}; 2853 2854 auto dstElements = dst.getValues<Attribute>(); 2855 2856 SmallVector<Attribute> results(dstElements); 2857 2858 uint64_t posIdx = pos.getInt(); 2859 if (posIdx >= results.size()) 2860 return {}; 2861 results[posIdx] = src; 2862 2863 return DenseElementsAttr::get(getDestVectorType(), results); 2864 } 2865 2866 //===----------------------------------------------------------------------===// 2867 // InsertOp 2868 //===----------------------------------------------------------------------===// 2869 2870 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 2871 SetIntRangeFn setResultRanges) { 2872 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); 2873 } 2874 2875 void vector::InsertOp::build(OpBuilder &builder, OperationState &result, 2876 Value source, Value dest, int64_t position) { 2877 build(builder, result, source, dest, ArrayRef<int64_t>{position}); 2878 } 2879 2880 void vector::InsertOp::build(OpBuilder &builder, OperationState &result, 2881 Value source, Value dest, OpFoldResult position) { 2882 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position}); 2883 } 2884 2885 void vector::InsertOp::build(OpBuilder &builder, OperationState &result, 2886 Value source, Value dest, 2887 ArrayRef<int64_t> position) { 2888 SmallVector<OpFoldResult> posVals; 2889 posVals.reserve(position.size()); 2890 llvm::transform(position, std::back_inserter(posVals), 2891 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); }); 2892 build(builder, result, source, dest, posVals); 2893 } 2894 2895 void vector::InsertOp::build(OpBuilder &builder, OperationState &result, 2896 Value source, Value dest, 2897 ArrayRef<OpFoldResult> position) { 2898 SmallVector<int64_t> staticPos; 2899 SmallVector<Value> dynamicPos; 2900 dispatchIndexOpFoldResults(position, dynamicPos, staticPos); 2901 build(builder, result, source, dest, dynamicPos, 2902 builder.getDenseI64ArrayAttr(staticPos)); 2903 } 2904 2905 LogicalResult InsertOp::verify() { 2906 SmallVector<OpFoldResult> position = getMixedPosition(); 2907 auto destVectorType = getDestVectorType(); 2908 if (position.size() > static_cast<unsigned>(destVectorType.getRank())) 2909 return emitOpError( 2910 "expected position attribute of rank no greater than dest vector rank"); 2911 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType()); 2912 if (srcVectorType && 2913 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() != 2914 static_cast<unsigned>(destVectorType.getRank()))) 2915 return emitOpError("expected position attribute rank + source rank to " 2916 "match dest vector rank"); 2917 if (!srcVectorType && 2918 (position.size() != static_cast<unsigned>(destVectorType.getRank()))) 2919 return emitOpError( 2920 "expected position attribute rank to match the dest vector rank"); 2921 for (auto [idx, pos] : llvm::enumerate(position)) { 2922 if (auto attr = pos.dyn_cast<Attribute>()) { 2923 int64_t constIdx = cast<IntegerAttr>(attr).getInt(); 2924 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex, 2925 destVectorType.getDimSize(idx))) { 2926 return emitOpError("expected position attribute #") 
2927 << (idx + 1)
2928 << " to be a non-negative integer smaller than the "
2929 "corresponding "
2930 "dest vector dimension";
2931 }
2932 }
2933 }
2934 return success();
2935 }
2936 
2937 namespace {
2938 
2939 // If insertOp is only inserting unit dimensions it can be transformed to a
2940 // broadcast.
2941 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2942 public:
2943 using OpRewritePattern::OpRewritePattern;
2944 
2945 LogicalResult matchAndRewrite(InsertOp insertOp,
2946 PatternRewriter &rewriter) const override {
2947 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2948 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2949 srcVecType.getNumElements())
2950 return failure();
2951 rewriter.replaceOpWithNewOp<BroadcastOp>(
2952 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2953 return success();
2954 }
2955 };
2956 
2957 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2958 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2959 public:
2960 using OpRewritePattern::OpRewritePattern;
2961 
2962 LogicalResult matchAndRewrite(InsertOp op,
2963 PatternRewriter &rewriter) const override {
2964 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2965 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2966 
2967 if (!srcSplat || !dstSplat)
2968 return failure();
2969 
2970 if (srcSplat.getInput() != dstSplat.getInput())
2971 return failure();
2972 
2973 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2974 return success();
2975 }
2976 };
2977 
2978 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
2979 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2980 public:
2981 using OpRewritePattern::OpRewritePattern;
2982 
2983 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2984 // unless the source vector constant has a single use.
2985 static constexpr int64_t vectorSizeFoldThreshold = 256;
2986 
2987 LogicalResult matchAndRewrite(InsertOp op,
2988 PatternRewriter &rewriter) const override {
2989 // TODO: Canonicalization for dynamic position not implemented yet.
2990 if (op.hasDynamicPosition())
2991 return failure();
2992 
2993 // Return if 'InsertOp' operand is not defined by a compatible vector
2994 // ConstantOp.
2995 TypedValue<VectorType> destVector = op.getDest();
2996 Attribute vectorDestCst;
2997 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2998 return failure();
2999 auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3000 if (!denseDest)
3001 return failure();
3002 
3003 VectorType destTy = destVector.getType();
3004 if (destTy.isScalable())
3005 return failure();
3006 
3007 // Make sure we do not create too many large constants.
3008 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3009 !destVector.hasOneUse())
3010 return failure();
3011 
3012 Value sourceValue = op.getSource();
3013 Attribute sourceCst;
3014 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3015 return failure();
3016 
3017 // Calculate the linearized position of the contiguous chunk of elements to
3018 // insert.
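// For example (an illustrative case): inserting a vector<3xf32> into a
// vector<2x3xf32> at position [1] pads the position to [1, 0]; with row-major
// strides [3, 1] the chunk starts at linearized index 3.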
3019 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0); 3020 copy(op.getStaticPosition(), completePositions.begin()); 3021 int64_t insertBeginPosition = 3022 linearize(completePositions, computeStrides(destTy.getShape())); 3023 3024 SmallVector<Attribute> insertedValues; 3025 Type destEltType = destTy.getElementType(); 3026 3027 // The `convertIntegerAttr` method specifically handles the case 3028 // for `llvm.mlir.constant` which can hold an attribute with a 3029 // different type than the return type. 3030 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) { 3031 for (auto value : denseSource.getValues<Attribute>()) 3032 insertedValues.push_back(convertIntegerAttr(value, destEltType)); 3033 } else { 3034 insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType)); 3035 } 3036 3037 auto allValues = llvm::to_vector(denseDest.getValues<Attribute>()); 3038 copy(insertedValues, allValues.begin() + insertBeginPosition); 3039 auto newAttr = DenseElementsAttr::get(destTy, allValues); 3040 3041 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr); 3042 return success(); 3043 } 3044 3045 private: 3046 /// Converts the expected type to an IntegerAttr if there's 3047 /// a mismatch. 3048 Attribute convertIntegerAttr(Attribute attr, Type expectedType) const { 3049 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) { 3050 if (intAttr.getType() != expectedType) 3051 return IntegerAttr::get(expectedType, intAttr.getInt()); 3052 } 3053 return attr; 3054 } 3055 }; 3056 3057 } // namespace 3058 3059 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, 3060 MLIRContext *context) { 3061 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat, 3062 InsertOpConstantFolder>(context); 3063 results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>); 3064 } 3065 3066 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { 3067 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to 3068 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>" 3069 // (type mismatch). 3070 if (getNumIndices() == 0 && getSourceType() == getType()) 3071 return getSource(); 3072 if (auto res = foldPoisonIndexInsertExtractOp( 3073 getContext(), adaptor.getStaticPosition(), kPoisonIndex)) 3074 return res; 3075 3076 return {}; 3077 } 3078 3079 //===----------------------------------------------------------------------===// 3080 // InsertStridedSliceOp 3081 //===----------------------------------------------------------------------===// 3082 3083 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result, 3084 Value source, Value dest, 3085 ArrayRef<int64_t> offsets, 3086 ArrayRef<int64_t> strides) { 3087 result.addOperands({source, dest}); 3088 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 3089 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 3090 result.addTypes(dest.getType()); 3091 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name), 3092 offsetsAttr); 3093 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name), 3094 stridesAttr); 3095 } 3096 3097 // TODO: Should be moved to Tablegen ConfinedAttr attributes. 
3098 template <typename OpType>
3099 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3100 ArrayAttr arrayAttr,
3101 ArrayRef<int64_t> shape,
3102 StringRef attrName) {
3103 if (arrayAttr.size() > shape.size())
3104 return op.emitOpError("expected ")
3105 << attrName << " attribute of rank no greater than vector rank";
3106 return success();
3107 }
3108 
3109 // Checks that all integers in `arrayAttr` are confined to the given range.
3110 // If `halfOpen` is true the admissible interval is [min, max); otherwise it
3111 // is [min, max].
3112 template <typename OpType>
3113 static LogicalResult
3114 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3115 int64_t max, StringRef attrName,
3116 bool halfOpen = true) {
3117 for (auto attr : arrayAttr) {
3118 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3119 auto upper = max;
3120 if (!halfOpen)
3121 upper += 1;
3122 if (val < min || val >= upper)
3123 return op.emitOpError("expected ") << attrName << " to be confined to ["
3124 << min << ", " << upper << ")";
3125 }
3126 return success();
3127 }
3128 
3129 // Checks that each integer in `arrayAttr` is confined to the range determined
3130 // by the corresponding dimension of `shape`. If `halfOpen` is true the
3131 // admissible interval is [min, shape[i]); otherwise it is [min, shape[i]].
3132 template <typename OpType>
3133 static LogicalResult
3134 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3135 ArrayRef<int64_t> shape, StringRef attrName,
3136 bool halfOpen = true, int64_t min = 0) {
3137 for (auto [index, attrDimPair] :
3138 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3139 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3140 int64_t max = std::get<1>(attrDimPair);
3141 if (!halfOpen)
3142 max += 1;
3143 if (val < min || val >= max)
3144 return op.emitOpError("expected ")
3145 << attrName << " dimension " << index << " to be confined to ["
3146 << min << ", " << max << ")";
3147 }
3148 return success();
3149 }
3150 
3151 // Checks that, for all indices i = 0..shape.size()-1, the sum
3152 // `arrayAttr1[i]` + `arrayAttr2[i]`
3153 // is confined to the range determined by the corresponding dimension of
3154 // `shape`. If `halfOpen` is true the admissible interval is [min, shape[i]);
3155 // otherwise it is [min, shape[i]].
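// (Used by the strided-slice verifiers below to check that offsets plus sizes,
// or offsets plus the source shape, stay within the destination shape.)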
3156 template <typename OpType> 3157 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape( 3158 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, 3159 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2, 3160 bool halfOpen = true, int64_t min = 1) { 3161 assert(arrayAttr1.size() <= shape.size()); 3162 assert(arrayAttr2.size() <= shape.size()); 3163 for (auto [index, it] : 3164 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) { 3165 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt(); 3166 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt(); 3167 int64_t max = std::get<2>(it); 3168 if (!halfOpen) 3169 max += 1; 3170 if (val1 + val2 < 0 || val1 + val2 >= max) 3171 return op.emitOpError("expected sum(") 3172 << attrName1 << ", " << attrName2 << ") dimension " << index 3173 << " to be confined to [" << min << ", " << max << ")"; 3174 } 3175 return success(); 3176 } 3177 3178 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values, 3179 MLIRContext *context) { 3180 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { 3181 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); 3182 }); 3183 return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); 3184 } 3185 3186 LogicalResult InsertStridedSliceOp::verify() { 3187 auto sourceVectorType = getSourceVectorType(); 3188 auto destVectorType = getDestVectorType(); 3189 auto offsets = getOffsetsAttr(); 3190 auto strides = getStridesAttr(); 3191 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank())) 3192 return emitOpError( 3193 "expected offsets of same size as destination vector rank"); 3194 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank())) 3195 return emitOpError("expected strides of same size as source vector rank"); 3196 if (sourceVectorType.getRank() > destVectorType.getRank()) 3197 return emitOpError( 3198 "expected source rank to be no greater than destination rank"); 3199 3200 auto sourceShape = sourceVectorType.getShape(); 3201 auto destShape = destVectorType.getShape(); 3202 SmallVector<int64_t, 4> sourceShapeAsDestShape( 3203 destShape.size() - sourceShape.size(), 0); 3204 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end()); 3205 auto offName = InsertStridedSliceOp::getOffsetsAttrName(); 3206 auto stridesName = InsertStridedSliceOp::getStridesAttrName(); 3207 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape, 3208 offName)) || 3209 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1, 3210 /*max=*/1, stridesName, 3211 /*halfOpen=*/false)) || 3212 failed(isSumOfIntegerArrayAttrConfinedToShape( 3213 *this, offsets, 3214 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape, 3215 offName, "source vector shape", 3216 /*halfOpen=*/false, /*min=*/1))) 3217 return failure(); 3218 3219 unsigned rankDiff = destShape.size() - sourceShape.size(); 3220 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) { 3221 if (sourceVectorType.getScalableDims()[idx] != 3222 destVectorType.getScalableDims()[idx + rankDiff]) { 3223 return emitOpError("mismatching scalable flags (at source vector idx=") 3224 << idx << ")"; 3225 } 3226 if (sourceVectorType.getScalableDims()[idx]) { 3227 auto sourceSize = sourceShape[idx]; 3228 auto destSize = destShape[idx + rankDiff]; 3229 if (sourceSize != destSize) { 3230 return emitOpError("expected size at idx=") 3231 << idx 3232 << (" to match the corresponding base size from the input " 3233 "vector (") 3234 << sourceSize << (" vs 
") << destSize << (")"); 3235 } 3236 } 3237 } 3238 3239 return success(); 3240 } 3241 3242 namespace { 3243 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type, 3244 /// SplatOp(X):dst_type) to SplatOp(X):dst_type. 3245 class FoldInsertStridedSliceSplat final 3246 : public OpRewritePattern<InsertStridedSliceOp> { 3247 public: 3248 using OpRewritePattern::OpRewritePattern; 3249 3250 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, 3251 PatternRewriter &rewriter) const override { 3252 auto srcSplatOp = 3253 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>(); 3254 auto destSplatOp = 3255 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>(); 3256 3257 if (!srcSplatOp || !destSplatOp) 3258 return failure(); 3259 3260 if (srcSplatOp.getInput() != destSplatOp.getInput()) 3261 return failure(); 3262 3263 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); 3264 return success(); 3265 } 3266 }; 3267 3268 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) 3269 /// to dst. 3270 class FoldInsertStridedSliceOfExtract final 3271 : public OpRewritePattern<InsertStridedSliceOp> { 3272 public: 3273 using OpRewritePattern::OpRewritePattern; 3274 3275 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, 3276 PatternRewriter &rewriter) const override { 3277 auto extractStridedSliceOp = 3278 insertStridedSliceOp.getSource() 3279 .getDefiningOp<vector::ExtractStridedSliceOp>(); 3280 3281 if (!extractStridedSliceOp) 3282 return failure(); 3283 3284 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest()) 3285 return failure(); 3286 3287 // Check if have the same strides and offsets. 3288 if (extractStridedSliceOp.getStrides() != 3289 insertStridedSliceOp.getStrides() || 3290 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets()) 3291 return failure(); 3292 3293 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); 3294 return success(); 3295 } 3296 }; 3297 3298 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) -> 3299 // ConstantOp. 3300 class InsertStridedSliceConstantFolder final 3301 : public OpRewritePattern<InsertStridedSliceOp> { 3302 public: 3303 using OpRewritePattern::OpRewritePattern; 3304 3305 // Do not create constants with more than `vectorSizeFoldThreashold` elements, 3306 // unless the source vector constant has a single use. 3307 static constexpr int64_t vectorSizeFoldThreshold = 256; 3308 3309 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 3310 PatternRewriter &rewriter) const override { 3311 // Return if 'InsertOp' operand is not defined by a compatible vector 3312 // ConstantOp. 3313 TypedValue<VectorType> destVector = op.getDest(); 3314 Attribute vectorDestCst; 3315 if (!matchPattern(destVector, m_Constant(&vectorDestCst))) 3316 return failure(); 3317 3318 VectorType destTy = destVector.getType(); 3319 if (destTy.isScalable()) 3320 return failure(); 3321 3322 // Make sure we do not create too many large constants. 3323 if (destTy.getNumElements() > vectorSizeFoldThreshold && 3324 !destVector.hasOneUse()) 3325 return failure(); 3326 3327 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst); 3328 3329 TypedValue<VectorType> sourceValue = op.getSource(); 3330 Attribute sourceCst; 3331 if (!matchPattern(sourceValue, m_Constant(&sourceCst))) 3332 return failure(); 3333 3334 // TODO: Handle non-unit strides when they become available. 
3335 if (op.hasNonUnitStrides())
3336 return failure();
3337 
3338 VectorType sliceVecTy = sourceValue.getType();
3339 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3340 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3341 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3342 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3343 
3344 // Calculate the destination element indices by enumerating all slice
3345 // positions within the destination and linearizing them. The enumeration
3346 // order is lexicographic, which yields a sequence of monotonically
3347 // increasing linearized position indices.
3348 // Because the destination may have higher dimensionality than the slice,
3349 // we keep track of two overlapping sets of positions and offsets.
3350 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3351 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3352 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3353 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3354 MutableArrayRef<int64_t> currSlicePosition(
3355 currDestPosition.begin() + rankDifference, currDestPosition.end());
3356 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3357 offsets.end());
3358 do {
3359 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3360 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3361 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3362 "Invalid slice element");
3363 newValues[linearizedPosition] = *sliceValuesIt;
3364 ++sliceValuesIt;
3365 } while (succeeded(
3366 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3367 
3368 auto newAttr = DenseElementsAttr::get(destTy, newValues);
3369 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3370 return success();
3371 }
3372 };
3373 
3374 } // namespace
3375 
3376 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3377 RewritePatternSet &results, MLIRContext *context) {
3378 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3379 InsertStridedSliceConstantFolder>(context);
3380 }
3381 
3382 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3383 if (getSourceVectorType() == getDestVectorType())
3384 return getSource();
3385 return {};
3386 }
3387 
3388 //===----------------------------------------------------------------------===//
3389 // OuterProductOp
3390 //===----------------------------------------------------------------------===//
3391 
3392 /// Build an op without a mask, using the type of `acc` as the return type.
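/// A minimal usage sketch (illustrative values, assuming 1-D `lhs` and `rhs`
/// and an `acc` of the matching 2-D result type):
///   auto prod = builder.create<vector::OuterProductOp>(loc, lhs, rhs, acc);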
3393 void OuterProductOp::build(OpBuilder &builder, OperationState &result, 3394 Value lhs, Value rhs, Value acc) { 3395 result.addOperands({lhs, rhs, acc}); 3396 result.addTypes(acc.getType()); 3397 } 3398 3399 void OuterProductOp::print(OpAsmPrinter &p) { 3400 p << " " << getLhs() << ", " << getRhs(); 3401 if (getAcc()) { 3402 p << ", " << getAcc(); 3403 p.printOptionalAttrDict((*this)->getAttrs()); 3404 } 3405 p << " : " << getLhs().getType() << ", " << getRhs().getType(); 3406 } 3407 3408 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { 3409 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo; 3410 Type tLHS, tRHS; 3411 if (parser.parseOperandList(operandsInfo) || 3412 parser.parseOptionalAttrDict(result.attributes) || 3413 parser.parseColonType(tLHS) || parser.parseComma() || 3414 parser.parseType(tRHS)) 3415 return failure(); 3416 if (operandsInfo.size() < 2) 3417 return parser.emitError(parser.getNameLoc(), 3418 "expected at least 2 operands"); 3419 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS); 3420 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS); 3421 if (!vLHS) 3422 return parser.emitError(parser.getNameLoc(), 3423 "expected vector type for operand #1"); 3424 3425 VectorType resType; 3426 if (vRHS) { 3427 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0], 3428 vRHS.getScalableDims()[0]}; 3429 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, 3430 vLHS.getElementType(), scalableDimsRes); 3431 } else { 3432 // Scalar RHS operand 3433 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]}; 3434 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(), 3435 scalableDimsRes); 3436 } 3437 3438 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) { 3439 result.attributes.append( 3440 OuterProductOp::getKindAttrName(result.name), 3441 CombiningKindAttr::get(result.getContext(), 3442 OuterProductOp::getDefaultKind())); 3443 } 3444 3445 return failure( 3446 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || 3447 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || 3448 (operandsInfo.size() > 2 && 3449 parser.resolveOperand(operandsInfo[2], resType, result.operands)) || 3450 parser.addTypeToList(resType, result.types)); 3451 } 3452 3453 LogicalResult OuterProductOp::verify() { 3454 Type tRHS = getOperandTypeRHS(); 3455 VectorType vLHS = getOperandVectorTypeLHS(), 3456 vRHS = llvm::dyn_cast<VectorType>(tRHS), 3457 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType(); 3458 3459 if (vLHS.getRank() != 1) 3460 return emitOpError("expected 1-d vector for operand #1"); 3461 3462 if (vRHS) { 3463 // Proper OUTER operation. 3464 if (vRHS.getRank() != 1) 3465 return emitOpError("expected 1-d vector for operand #2"); 3466 if (vRES.getRank() != 2) 3467 return emitOpError("expected 2-d vector result"); 3468 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 3469 return emitOpError("expected #1 operand dim to match result dim #1"); 3470 if (vRHS.getDimSize(0) != vRES.getDimSize(1)) 3471 return emitOpError("expected #2 operand dim to match result dim #2"); 3472 if (vLHS.isScalable() && !vRHS.isScalable()) { 3473 // This restriction reflects what's currently supported in terms of 3474 // scalable vectors. However, we could relax this if there's a use case. 3475 return emitOpError( 3476 "expected either both or only #2 operand dim to be scalable"); 3477 } 3478 } else { 3479 // An AXPY operation. 
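// (The RHS is a scalar here, so the result stays 1-D with the same shape as
// the LHS vector.)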
3480 if (vRES.getRank() != 1) 3481 return emitOpError("expected 1-d vector result"); 3482 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 3483 return emitOpError("expected #1 operand dim to match result dim #1"); 3484 } 3485 3486 if (vACC && vACC != vRES) 3487 return emitOpError("expected operand #3 of same type as result type"); 3488 3489 // Verify supported combining kind. 3490 if (!isSupportedCombiningKind(getKind(), vRES.getElementType())) 3491 return emitOpError("unsupported outerproduct type"); 3492 3493 return success(); 3494 } 3495 3496 // MaskableOpInterface methods. 3497 3498 /// Returns the mask type expected by this operation. Mostly used for 3499 /// verification purposes. It requires the operation to be vectorized." 3500 Type OuterProductOp::getExpectedMaskType() { 3501 auto vecType = this->getResultVectorType(); 3502 return VectorType::get(vecType.getShape(), 3503 IntegerType::get(vecType.getContext(), /*width=*/1), 3504 vecType.getScalableDims()); 3505 } 3506 3507 //===----------------------------------------------------------------------===// 3508 // ExtractStridedSliceOp 3509 //===----------------------------------------------------------------------===// 3510 3511 // Inference works as follows: 3512 // 1. Add 'sizes' from prefix of dims in 'offsets'. 3513 // 2. Add sizes from 'vectorType' for remaining dims. 3514 // Scalable flags are inherited from 'vectorType'. 3515 static Type inferStridedSliceOpResultType(VectorType vectorType, 3516 ArrayAttr offsets, ArrayAttr sizes, 3517 ArrayAttr strides) { 3518 assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); 3519 SmallVector<int64_t, 4> shape; 3520 shape.reserve(vectorType.getRank()); 3521 unsigned idx = 0; 3522 for (unsigned e = offsets.size(); idx < e; ++idx) 3523 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt()); 3524 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) 3525 shape.push_back(vectorType.getShape()[idx]); 3526 3527 return VectorType::get(shape, vectorType.getElementType(), 3528 vectorType.getScalableDims()); 3529 } 3530 3531 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, 3532 Value source, ArrayRef<int64_t> offsets, 3533 ArrayRef<int64_t> sizes, 3534 ArrayRef<int64_t> strides) { 3535 result.addOperands(source); 3536 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 3537 auto sizesAttr = getVectorSubscriptAttr(builder, sizes); 3538 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 3539 result.addTypes( 3540 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()), 3541 offsetsAttr, sizesAttr, stridesAttr)); 3542 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name), 3543 offsetsAttr); 3544 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name), 3545 sizesAttr); 3546 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name), 3547 stridesAttr); 3548 } 3549 3550 LogicalResult ExtractStridedSliceOp::verify() { 3551 auto type = getSourceVectorType(); 3552 auto offsets = getOffsetsAttr(); 3553 auto sizes = getSizesAttr(); 3554 auto strides = getStridesAttr(); 3555 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) 3556 return emitOpError( 3557 "expected offsets, sizes and strides attributes of same size"); 3558 3559 auto shape = type.getShape(); 3560 auto offName = getOffsetsAttrName(); 3561 auto sizesName = getSizesAttrName(); 3562 auto stridesName = getStridesAttrName(); 3563 if (failed( 3564 
isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3565 failed(
3566 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3567 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3568 stridesName)) ||
3569 failed(
3570 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3571 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3572 /*halfOpen=*/false,
3573 /*min=*/1)) ||
3574 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3575 /*max=*/1, stridesName,
3576 /*halfOpen=*/false)) ||
3577 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3578 shape, offName, sizesName,
3579 /*halfOpen=*/false)))
3580 return failure();
3581 
3582 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3583 offsets, sizes, strides);
3584 if (getResult().getType() != resultType)
3585 return emitOpError("expected result type to be ") << resultType;
3586 
3587 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3588 if (type.getScalableDims()[idx]) {
3589 auto inputDim = type.getShape()[idx];
3590 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3591 if (inputDim != inputSize)
3592 return emitOpError("expected size at idx=")
3593 << idx
3594 << (" to match the corresponding base size from the input "
3595 "vector (")
3596 << inputSize << (" vs ") << inputDim << (")");
3597 }
3598 }
3599 
3600 return success();
3601 }
3602 
3603 // When the source of an ExtractStridedSliceOp comes from a chain of
3604 // InsertStridedSliceOps, try to use the source of one of those inserts when we
3605 // can detect that the extracted vector is a subset of an inserted vector.
3606 static LogicalResult
3607 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3608 // Helper to extract an integer out of an ArrayAttr.
3609 auto getElement = [](ArrayAttr array, int idx) {
3610 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3611 };
3612 ArrayAttr extractOffsets = op.getOffsets();
3613 ArrayAttr extractStrides = op.getStrides();
3614 ArrayAttr extractSizes = op.getSizes();
3615 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3616 while (insertOp) {
3617 if (op.getSourceVectorType().getRank() !=
3618 insertOp.getSourceVectorType().getRank())
3619 return failure();
3620 ArrayAttr insertOffsets = insertOp.getOffsets();
3621 ArrayAttr insertStrides = insertOp.getStrides();
3622 // If the rank of the extract is greater than the rank of the insert, we are
3623 // likely extracting a partial chunk of the inserted vector.
3624 if (extractOffsets.size() > insertOffsets.size())
3625 return failure();
3626 bool partialOverlap = false;
3627 bool disjoint = false;
3628 SmallVector<int64_t, 4> offsetDiffs;
3629 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3630 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3631 return failure();
3632 int64_t start = getElement(insertOffsets, dim);
3633 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3634 int64_t offset = getElement(extractOffsets, dim);
3635 int64_t size = getElement(extractSizes, dim);
3636 // Check if the start of the extract offset is in the inserted interval.
3637 if (start <= offset && offset < end) {
3638 // If the extract interval overlaps but is not fully included, we may
3639 // have a partial overlap that will prevent any folding.
3640 if (offset + size > end)
3641 partialOverlap = true;
3642 offsetDiffs.push_back(offset - start);
3643 continue;
3644 }
3645 disjoint = true;
3646 break;
3647 }
3648 // The extracted element chunk is a subset of the inserted chunk.
3649 if (!disjoint && !partialOverlap) {
3650 op.setOperand(insertOp.getSource());
3651 // OpBuilder is only used as a helper to build an I64ArrayAttr.
3652 OpBuilder b(op.getContext());
3653 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3654 return success();
3655 }
3656 // If the chunk extracted is disjoint from the chunk inserted, keep looking
3657 // in the insert chain.
3658 if (disjoint)
3659 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3660 else {
3661 // The extracted vector partially overlaps the inserted vector; we cannot
3662 // fold.
3663 return failure();
3664 }
3665 }
3666 return failure();
3667 }
3668 
3669 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3670 if (getSourceVectorType() == getResult().getType())
3671 return getVector();
3672 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3673 return getResult();
3674 return {};
3675 }
3676 
3677 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3678 populateFromInt64AttrArray(getOffsets(), results);
3679 }
3680 
3681 namespace {
3682 
3683 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3684 // ConstantMaskOp.
3685 class StridedSliceConstantMaskFolder final
3686 : public OpRewritePattern<ExtractStridedSliceOp> {
3687 public:
3688 using OpRewritePattern::OpRewritePattern;
3689 
3690 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3691 PatternRewriter &rewriter) const override {
3692 // Return if 'extractStridedSliceOp' operand is not defined by a
3693 // ConstantMaskOp.
3694 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3695 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3696 if (!constantMaskOp)
3697 return failure();
3698 // Return if 'extractStridedSliceOp' has non-unit strides.
3699 if (extractStridedSliceOp.hasNonUnitStrides())
3700 return failure();
3701 // Gather constant mask dimension sizes.
3702 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3703 // Gather strided slice offsets and sizes.
3704 SmallVector<int64_t, 4> sliceOffsets;
3705 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3706 sliceOffsets);
3707 SmallVector<int64_t, 4> sliceSizes;
3708 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3709 
3710 // Compute slice of vector mask region.
3711 SmallVector<int64_t, 4> sliceMaskDimSizes;
3712 sliceMaskDimSizes.reserve(maskDimSizes.size());
3713 for (auto [maskDimSize, sliceOffset, sliceSize] :
3714 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3715 int64_t sliceMaskDimSize = std::max(
3716 static_cast<int64_t>(0),
3717 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3718 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3719 }
3720 // Add unchanged dimensions.
3721 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3722 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3723 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3724 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3725 // region is a conjunction of mask dim intervals).
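// E.g., a sliced mask with dim sizes [2, 0] selects no elements, so it is
// canonicalized to the all-false mask [0, 0].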
3726 if (llvm::is_contained(sliceMaskDimSizes, 0)) 3727 sliceMaskDimSizes.assign(maskDimSizes.size(), 0); 3728 3729 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask 3730 // region. 3731 rewriter.replaceOpWithNewOp<ConstantMaskOp>( 3732 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), 3733 sliceMaskDimSizes); 3734 return success(); 3735 } 3736 }; 3737 3738 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. 3739 class StridedSliceSplatConstantFolder final 3740 : public OpRewritePattern<ExtractStridedSliceOp> { 3741 public: 3742 using OpRewritePattern::OpRewritePattern; 3743 3744 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 3745 PatternRewriter &rewriter) const override { 3746 // Return if 'ExtractStridedSliceOp' operand is not defined by a splat 3747 // ConstantOp. 3748 Value sourceVector = extractStridedSliceOp.getVector(); 3749 Attribute vectorCst; 3750 if (!matchPattern(sourceVector, m_Constant(&vectorCst))) 3751 return failure(); 3752 3753 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst); 3754 if (!splat) 3755 return failure(); 3756 3757 auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(), 3758 splat.getSplatValue<Attribute>()); 3759 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp, 3760 newAttr); 3761 return success(); 3762 } 3763 }; 3764 3765 // Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) -> 3766 // ConstantOp. 3767 class StridedSliceNonSplatConstantFolder final 3768 : public OpRewritePattern<ExtractStridedSliceOp> { 3769 public: 3770 using OpRewritePattern::OpRewritePattern; 3771 3772 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 3773 PatternRewriter &rewriter) const override { 3774 // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat 3775 // ConstantOp. 3776 Value sourceVector = extractStridedSliceOp.getVector(); 3777 Attribute vectorCst; 3778 if (!matchPattern(sourceVector, m_Constant(&vectorCst))) 3779 return failure(); 3780 3781 // The splat case is handled by `StridedSliceSplatConstantFolder`. 3782 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst); 3783 if (!dense || dense.isSplat()) 3784 return failure(); 3785 3786 // TODO: Handle non-unit strides when they become available. 3787 if (extractStridedSliceOp.hasNonUnitStrides()) 3788 return failure(); 3789 3790 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType()); 3791 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape(); 3792 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape); 3793 3794 VectorType sliceVecTy = extractStridedSliceOp.getType(); 3795 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape(); 3796 int64_t sliceRank = sliceVecTy.getRank(); 3797 3798 // Expand offsets and sizes to match the vector rank. 3799 SmallVector<int64_t, 4> offsets(sliceRank, 0); 3800 copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin()); 3801 3802 SmallVector<int64_t, 4> sizes(sourceShape); 3803 copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin()); 3804 3805 // Calculate the slice elements by enumerating all slice positions and 3806 // linearizing them. The enumeration order is lexicographic which yields a 3807 // sequence of monotonically increasing linearized position indices. 
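// For example (illustrative): a 1x2 slice at offsets [1, 1] of a 2x3 source
// visits positions [1, 1] and [1, 2], i.e. linearized indices 4 and 5 given
// row-major strides [3, 1].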
3808 auto denseValuesBegin = dense.value_begin<Attribute>();
3809 SmallVector<Attribute> sliceValues;
3810 sliceValues.reserve(sliceVecTy.getNumElements());
3811 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3812 do {
3813 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3814 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3815 "Invalid index");
3816 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3817 } while (
3818 succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3819 
3820 assert(static_cast<int64_t>(sliceValues.size()) ==
3821 sliceVecTy.getNumElements() &&
3822 "Invalid number of slice elements");
3823 auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3824 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3825 newAttr);
3826 return success();
3827 }
3828 };
3829 
3830 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3831 // BroadcastOp(ExtractStridedSliceOp).
3832 class StridedSliceBroadcast final
3833 : public OpRewritePattern<ExtractStridedSliceOp> {
3834 public:
3835 using OpRewritePattern::OpRewritePattern;
3836 
3837 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3838 PatternRewriter &rewriter) const override {
3839 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3840 if (!broadcast)
3841 return failure();
3842 auto srcVecType =
3843 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3844 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3845 auto dstVecType = llvm::cast<VectorType>(op.getType());
3846 unsigned dstRank = dstVecType.getRank();
3847 unsigned rankDiff = dstRank - srcRank;
3848 // Check if the innermost dimensions of the broadcast's source match the
3849 // corresponding dimensions of the extract's destination. If so, we can just
3850 // use a broadcast, as the original dimensions are untouched.
3851 bool lowerDimMatch = true;
3852 for (unsigned i = 0; i < srcRank; i++) {
3853 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3854 lowerDimMatch = false;
3855 break;
3856 }
3857 }
3858 Value source = broadcast.getSource();
3859 // If the inner dimensions don't match, it means we need to extract from the
3860 // source of the original broadcast and then broadcast the extracted value.
3861 // We also need to handle degenerate cases where the source is effectively
3862 // just a single scalar.
3863 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3864 if (!lowerDimMatch && !isScalarSrc) {
3865 source = rewriter.create<ExtractStridedSliceOp>(
3866 op->getLoc(), source,
3867 getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3868 getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3869 getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3870 }
3871 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3872 return success();
3873 }
3874 };
3875 
3876 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
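/// For example (illustrative, with %x : f32):
///   %0 = vector.splat %x : vector<4x4xf32>
///   %1 = vector.extract_strided_slice %0
///        {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x4xf32> to vector<2x2xf32>
/// becomes
///   %1 = vector.splat %x : vector<2x2xf32>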
3877 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { 3878 public: 3879 using OpRewritePattern::OpRewritePattern; 3880 3881 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 3882 PatternRewriter &rewriter) const override { 3883 auto splat = op.getVector().getDefiningOp<SplatOp>(); 3884 if (!splat) 3885 return failure(); 3886 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput()); 3887 return success(); 3888 } 3889 }; 3890 3891 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the 3892 /// slice is contiguous, into extract and shape_cast. 3893 /// 3894 /// Example: 3895 /// Before: 3896 /// %1 = vector.extract_strided_slice %arg0 { 3897 /// offsets = [0, 0, 0, 0, 0], 3898 /// sizes = [1, 1, 1, 1, 8], 3899 /// strides = [1, 1, 1, 1, 1] 3900 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8> 3901 /// After: 3902 /// %0 = vector.extract %arg0[0, 0, 0, 0] 3903 /// : vector<8xi8> from vector<8x1x1x2x8xi8> 3904 /// %1 = vector.shape_cast %0 3905 /// : vector<8xi8> to vector<1x1x1x1x8xi8> 3906 /// 3907 class ContiguousExtractStridedSliceToExtract final 3908 : public OpRewritePattern<ExtractStridedSliceOp> { 3909 public: 3910 using OpRewritePattern::OpRewritePattern; 3911 3912 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 3913 PatternRewriter &rewriter) const override { 3914 if (op.hasNonUnitStrides()) 3915 return failure(); 3916 Value source = op.getOperand(); 3917 auto sourceType = cast<VectorType>(source.getType()); 3918 if (sourceType.isScalable() || sourceType.getRank() == 0) 3919 return failure(); 3920 3921 // Compute the number of offsets to pass to ExtractOp::build. That is the 3922 // difference between the source rank and the desired slice rank. We walk 3923 // the dimensions from innermost out, and stop when the next slice dimension 3924 // is not full-size. 3925 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes()); 3926 int numOffsets; 3927 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) { 3928 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) 3929 break; 3930 } 3931 3932 // If the created extract op would have no offsets, then this whole 3933 // extract_strided_slice is the identity and should have been handled by 3934 // other canonicalizations. 3935 if (numOffsets == 0) 3936 return failure(); 3937 3938 // If not even the inner-most dimension is full-size, this op can't be 3939 // rewritten as an ExtractOp. 3940 if (numOffsets == sourceType.getRank() && 3941 static_cast<int>(sizes.size()) == sourceType.getRank()) 3942 return failure(); 3943 3944 // The outer dimensions must have unit size. 3945 for (int i = 0; i < numOffsets; ++i) { 3946 if (sizes[i] != 1) 3947 return failure(); 3948 } 3949 3950 // Avoid generating slices that have leading unit dimensions. The shape_cast 3951 // op that we create below would take bad generic fallback patterns 3952 // (ShapeCastOpRewritePattern). 
3953 while (sizes[numOffsets] == 1 && 3954 numOffsets < static_cast<int>(sizes.size()) - 1) { 3955 ++numOffsets; 3956 } 3957 3958 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets()); 3959 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets); 3960 Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source, 3961 extractOffsets); 3962 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract); 3963 return success(); 3964 } 3965 }; 3966 3967 } // namespace 3968 3969 void ExtractStridedSliceOp::getCanonicalizationPatterns( 3970 RewritePatternSet &results, MLIRContext *context) { 3971 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> 3972 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. 3973 results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder, 3974 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast, 3975 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>( 3976 context); 3977 } 3978 3979 //===----------------------------------------------------------------------===// 3980 // TransferReadOp 3981 //===----------------------------------------------------------------------===// 3982 3983 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). 3984 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 3985 VectorType vectorType, Value source, 3986 ValueRange indices, AffineMapAttr permutationMapAttr, 3987 /*optional*/ ArrayAttr inBoundsAttr) { 3988 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType(); 3989 Value padding = builder.create<arith::ConstantOp>( 3990 result.location, elemType, builder.getZeroAttr(elemType)); 3991 build(builder, result, vectorType, source, indices, permutationMapAttr, 3992 padding, /*mask=*/Value(), inBoundsAttr); 3993 } 3994 3995 /// 2. Builder that sets padding to zero an empty mask (variant without attrs). 3996 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 3997 VectorType vectorType, Value source, 3998 ValueRange indices, AffineMap permutationMap, 3999 std::optional<ArrayRef<bool>> inBounds) { 4000 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 4001 auto inBoundsAttr = (inBounds && !inBounds.value().empty()) 4002 ? builder.getBoolArrayAttr(inBounds.value()) 4003 : builder.getBoolArrayAttr( 4004 SmallVector<bool>(vectorType.getRank(), false)); 4005 build(builder, result, vectorType, source, indices, permutationMapAttr, 4006 inBoundsAttr); 4007 } 4008 4009 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. 4010 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 4011 VectorType vectorType, Value source, 4012 ValueRange indices, Value padding, 4013 std::optional<ArrayRef<bool>> inBounds) { 4014 AffineMap permutationMap = getTransferMinorIdentityMap( 4015 llvm::cast<ShapedType>(source.getType()), vectorType); 4016 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 4017 auto inBoundsAttr = (inBounds && !inBounds.value().empty()) 4018 ? builder.getBoolArrayAttr(inBounds.value()) 4019 : builder.getBoolArrayAttr( 4020 SmallVector<bool>(vectorType.getRank(), false)); 4021 build(builder, result, vectorType, source, indices, permutationMapAttr, 4022 padding, 4023 /*mask=*/Value(), inBoundsAttr); 4024 } 4025 4026 /// 4. Builder that sets padding to zero and permutation map to 4027 /// 'getMinorIdentityMap'. 
4028 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 4029 VectorType vectorType, Value source, 4030 ValueRange indices, 4031 std::optional<ArrayRef<bool>> inBounds) { 4032 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType(); 4033 Value padding = builder.create<arith::ConstantOp>( 4034 result.location, elemType, builder.getZeroAttr(elemType)); 4035 build(builder, result, vectorType, source, indices, padding, inBounds); 4036 } 4037 4038 template <typename EmitFun> 4039 static LogicalResult verifyPermutationMap(AffineMap permutationMap, 4040 EmitFun emitOpError) { 4041 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); 4042 for (auto expr : permutationMap.getResults()) { 4043 auto dim = dyn_cast<AffineDimExpr>(expr); 4044 auto zero = dyn_cast<AffineConstantExpr>(expr); 4045 if (zero) { 4046 if (zero.getValue() != 0) { 4047 return emitOpError( 4048 "requires a projected permutation_map (at most one dim or the zero " 4049 "constant can appear in each result)"); 4050 } 4051 continue; 4052 } 4053 if (!dim) { 4054 return emitOpError("requires a projected permutation_map (at most one " 4055 "dim or the zero constant can appear in each result)"); 4056 } 4057 if (seen[dim.getPosition()]) { 4058 return emitOpError( 4059 "requires a permutation_map that is a permutation (found one dim " 4060 "used more than once)"); 4061 } 4062 seen[dim.getPosition()] = true; 4063 } 4064 return success(); 4065 } 4066 4067 static LogicalResult 4068 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, 4069 VectorType vectorType, VectorType maskType, 4070 VectorType inferredMaskType, AffineMap permutationMap, 4071 ArrayAttr inBounds) { 4072 if (op->hasAttr("masked")) { 4073 return op->emitOpError("masked attribute has been removed. " 4074 "Use in_bounds instead."); 4075 } 4076 4077 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType)) 4078 return op->emitOpError( 4079 "requires source to be a memref or ranked tensor type"); 4080 4081 auto elementType = shapedType.getElementType(); 4082 DataLayout dataLayout = DataLayout::closest(op); 4083 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) { 4084 // Memref or tensor has vector element type. 4085 unsigned sourceVecSize = 4086 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * 4087 vectorElementType.getShape().back(); 4088 unsigned resultVecSize = 4089 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * 4090 vectorType.getShape().back(); 4091 if (resultVecSize % sourceVecSize != 0) 4092 return op->emitOpError( 4093 "requires the bitwidth of the minor 1-D vector to be an integral " 4094 "multiple of the bitwidth of the minor 1-D vector of the source"); 4095 4096 unsigned sourceVecEltRank = vectorElementType.getRank(); 4097 unsigned resultVecRank = vectorType.getRank(); 4098 if (sourceVecEltRank > resultVecRank) 4099 return op->emitOpError( 4100 "requires source vector element and vector result ranks to match."); 4101 unsigned rankOffset = resultVecRank - sourceVecEltRank; 4102 // Check that permutation map results match 'rankOffset' of vector type. 4103 if (permutationMap.getNumResults() != rankOffset) 4104 return op->emitOpError("requires a permutation_map with result dims of " 4105 "the same rank as the vector type"); 4106 4107 if (maskType) 4108 return op->emitOpError("does not support masks with vector element type"); 4109 } else { 4110 // Memref or tensor has scalar element type. 4111 unsigned minorSize = 4112 vectorType.getRank() == 0 ? 
1 : vectorType.getShape().back(); 4113 unsigned resultVecSize = 4114 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize; 4115 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0) 4116 return op->emitOpError( 4117 "requires the bitwidth of the minor 1-D vector to be an integral " 4118 "multiple of the bitwidth of the source element type"); 4119 4120 // Check that permutation map results match rank of vector type. 4121 if (permutationMap.getNumResults() != vectorType.getRank()) 4122 return op->emitOpError("requires a permutation_map with result dims of " 4123 "the same rank as the vector type"); 4124 } 4125 4126 if (permutationMap.getNumSymbols() != 0) 4127 return op->emitOpError("requires permutation_map without symbols"); 4128 4129 if (permutationMap.getNumInputs() != shapedType.getRank()) 4130 return op->emitOpError("requires a permutation_map with input dims of the " 4131 "same rank as the source type"); 4132 4133 if (maskType && maskType != inferredMaskType) 4134 return op->emitOpError("inferred mask type (") 4135 << inferredMaskType << ") and mask operand type (" << maskType 4136 << ") don't match"; 4137 4138 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size())) 4139 return op->emitOpError("expects the in_bounds attr of same rank " 4140 "as permutation_map results: ") 4141 << AffineMapAttr::get(permutationMap) 4142 << " vs inBounds of size: " << inBounds.size(); 4143 4144 return success(); 4145 } 4146 4147 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { 4148 SmallVector<StringRef, 3> elidedAttrs; 4149 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); 4150 if (op.getPermutationMap().isMinorIdentity()) 4151 elidedAttrs.push_back(op.getPermutationMapAttrName()); 4152 // Elide in_bounds attribute if all dims are out-of-bounds. 4153 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; })) 4154 elidedAttrs.push_back(op.getInBoundsAttrName()); 4155 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); 4156 } 4157 4158 void TransferReadOp::print(OpAsmPrinter &p) { 4159 p << " " << getSource() << "[" << getIndices() << "], " << getPadding(); 4160 if (getMask()) 4161 p << ", " << getMask(); 4162 printTransferAttrs(p, *this); 4163 p << " : " << getShapedType() << ", " << getVectorType(); 4164 } 4165 4166 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType, 4167 AffineMap permMap) { 4168 auto i1Type = IntegerType::get(permMap.getContext(), 1); 4169 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap)); 4170 assert(invPermMap && "Inversed permutation map couldn't be computed"); 4171 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape()); 4172 4173 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a 4174 // 0-D mask into a single-element 1-D mask. 
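// (E.g., the mask inferred for a transfer of vector<f32> is vector<1xi1>
// rather than vector<i1>.)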
4175 if (maskShape.empty()) 4176 maskShape.push_back(1); 4177 4178 SmallVector<bool> scalableDims = 4179 applyPermutationMap(invPermMap, vecType.getScalableDims()); 4180 4181 return VectorType::get(maskShape, i1Type, scalableDims); 4182 } 4183 4184 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { 4185 auto &builder = parser.getBuilder(); 4186 SMLoc typesLoc; 4187 OpAsmParser::UnresolvedOperand sourceInfo; 4188 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo; 4189 OpAsmParser::UnresolvedOperand paddingInfo; 4190 SmallVector<Type, 2> types; 4191 OpAsmParser::UnresolvedOperand maskInfo; 4192 // Parsing with support for paddingValue. 4193 if (parser.parseOperand(sourceInfo) || 4194 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 4195 parser.parseComma() || parser.parseOperand(paddingInfo)) 4196 return failure(); 4197 ParseResult hasMask = parser.parseOptionalComma(); 4198 if (hasMask.succeeded()) { 4199 if (parser.parseOperand(maskInfo)) 4200 return failure(); 4201 } 4202 if (parser.parseOptionalAttrDict(result.attributes) || 4203 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 4204 return failure(); 4205 if (types.size() != 2) 4206 return parser.emitError(typesLoc, "requires two types"); 4207 auto indexType = builder.getIndexType(); 4208 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]); 4209 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType)) 4210 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 4211 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]); 4212 if (!vectorType) 4213 return parser.emitError(typesLoc, "requires vector type"); 4214 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name); 4215 Attribute permMapAttr = result.attributes.get(permMapAttrName); 4216 AffineMap permMap; 4217 if (!permMapAttr) { 4218 permMap = getTransferMinorIdentityMap(shapedType, vectorType); 4219 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); 4220 } else { 4221 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue(); 4222 } 4223 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name); 4224 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName); 4225 if (!inBoundsAttr) { 4226 result.addAttribute(inBoundsAttrName, 4227 builder.getBoolArrayAttr( 4228 SmallVector<bool>(permMap.getNumResults(), false))); 4229 } 4230 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || 4231 parser.resolveOperands(indexInfo, indexType, result.operands) || 4232 parser.resolveOperand(paddingInfo, shapedType.getElementType(), 4233 result.operands)) 4234 return failure(); 4235 if (hasMask.succeeded()) { 4236 if (llvm::dyn_cast<VectorType>(shapedType.getElementType())) 4237 return parser.emitError( 4238 maskInfo.location, "does not support masks with vector element type"); 4239 if (vectorType.getRank() != permMap.getNumResults()) { 4240 return parser.emitError(typesLoc, 4241 "expected the same rank for the vector and the " 4242 "results of the permutation map"); 4243 } 4244 // Instead of adding the mask type as an op type, compute it based on the 4245 // vector type and the permutation map (to keep the type signature small). 
4246 auto maskType = inferTransferOpMaskType(vectorType, permMap); 4247 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 4248 return failure(); 4249 } 4250 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(), 4251 builder.getDenseI32ArrayAttr( 4252 {1, static_cast<int32_t>(indexInfo.size()), 1, 4253 static_cast<int32_t>(hasMask.succeeded())})); 4254 return parser.addTypeToList(vectorType, result.types); 4255 } 4256 4257 LogicalResult TransferReadOp::verify() { 4258 // Consistency of elemental types in source and vector. 4259 ShapedType shapedType = getShapedType(); 4260 VectorType vectorType = getVectorType(); 4261 VectorType maskType = getMaskType(); 4262 auto paddingType = getPadding().getType(); 4263 auto permutationMap = getPermutationMap(); 4264 VectorType inferredMaskType = 4265 maskType ? inferTransferOpMaskType(vectorType, permutationMap) 4266 : VectorType(); 4267 auto sourceElementType = shapedType.getElementType(); 4268 4269 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank()) 4270 return emitOpError("requires ") << shapedType.getRank() << " indices"; 4271 4272 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 4273 shapedType, vectorType, maskType, 4274 inferredMaskType, permutationMap, getInBounds()))) 4275 return failure(); 4276 4277 if (auto sourceVectorElementType = 4278 llvm::dyn_cast<VectorType>(sourceElementType)) { 4279 // Source has vector element type. 4280 // Check that 'sourceVectorElementType' and 'paddingType' types match. 4281 if (sourceVectorElementType != paddingType) 4282 return emitOpError( 4283 "requires source element type and padding type to match."); 4284 4285 } else { 4286 // Check that 'paddingType' is valid to store in a vector type. 4287 if (!VectorType::isValidElementType(paddingType)) 4288 return emitOpError("requires valid padding vector elemental type"); 4289 4290 // Check that padding type and vector element types match. 4291 if (paddingType != sourceElementType) 4292 return emitOpError( 4293 "requires formal padding and source of the same elemental type"); 4294 } 4295 4296 return verifyPermutationMap(permutationMap, 4297 [&](Twine t) { return emitOpError(t); }); 4298 } 4299 4300 // MaskableOpInterface methods. 4301 4302 /// Returns the mask type expected by this operation. Mostly used for 4303 /// verification purposes. It requires the operation to be vectorized." 4304 Type TransferReadOp::getExpectedMaskType() { 4305 return inferTransferOpMaskType(getVectorType(), getPermutationMap()); 4306 } 4307 4308 template <typename TransferOp> 4309 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { 4310 // TODO: support more aggressive createOrFold on: 4311 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx) 4312 if (op.getShapedType().isDynamicDim(indicesIdx)) 4313 return false; 4314 Value index = op.getIndices()[indicesIdx]; 4315 std::optional<int64_t> cstOp = getConstantIntValue(index); 4316 if (!cstOp.has_value()) 4317 return false; 4318 4319 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); 4320 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); 4321 4322 return cstOp.value() + vectorSize <= sourceSize; 4323 } 4324 4325 template <typename TransferOp> 4326 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { 4327 // TODO: support 0-d corner case. 4328 // TODO: Be less conservative. 
4329 if (op.getTransferRank() == 0) 4330 return failure(); 4331 AffineMap permutationMap = op.getPermutationMap(); 4332 bool changed = false; 4333 SmallVector<bool, 4> newInBounds; 4334 newInBounds.reserve(op.getTransferRank()); 4335 // Idxs of non-bcast dims - used when analysing bcast dims. 4336 SmallVector<unsigned> nonBcastDims; 4337 4338 // 1. Process non-broadcast dims 4339 for (unsigned i = 0; i < op.getTransferRank(); ++i) { 4340 // 1.1. Already marked as in-bounds, nothing to see here. 4341 if (op.isDimInBounds(i)) { 4342 newInBounds.push_back(true); 4343 continue; 4344 } 4345 // 1.2. Currently out-of-bounds, check whether we can statically determine 4346 // it is inBounds. 4347 bool inBounds = false; 4348 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i)); 4349 if (dimExpr) { 4350 inBounds = isInBounds(op, /*resultIdx=*/i, 4351 /*indicesIdx=*/dimExpr.getPosition()); 4352 nonBcastDims.push_back(i); 4353 } 4354 4355 newInBounds.push_back(inBounds); 4356 // We commit the pattern if it is "more inbounds". 4357 changed |= inBounds; 4358 } 4359 4360 // 2. Handle broadcast dims 4361 // If all non-broadcast dims are "in bounds", then all bcast dims should be 4362 // "in bounds" as well. 4363 bool allNonBcastDimsInBounds = llvm::all_of( 4364 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; }); 4365 if (allNonBcastDimsInBounds) { 4366 for (size_t idx : permutationMap.getBroadcastDims()) { 4367 changed |= !newInBounds[idx]; 4368 newInBounds[idx] = true; 4369 } 4370 } 4371 4372 if (!changed) 4373 return failure(); 4374 // OpBuilder is only used as a helper to build an I64ArrayAttr. 4375 OpBuilder b(op.getContext()); 4376 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds)); 4377 return success(); 4378 } 4379 4380 template <typename TransferOp> 4381 static LogicalResult foldTransferFullMask(TransferOp op) { 4382 auto mask = op.getMask(); 4383 if (!mask) 4384 return failure(); 4385 4386 if (getMaskFormat(mask) != MaskFormat::AllTrue) 4387 return failure(); 4388 4389 op.getMaskMutable().clear(); 4390 return success(); 4391 } 4392 4393 /// ``` 4394 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} 4395 /// : vector<1x4xf32>, tensor<4x4xf32> 4396 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} 4397 /// : tensor<4x4xf32>, vector<1x4xf32> 4398 /// ``` 4399 /// -> Folds into 4400 /// ``` 4401 /// %v0 4402 /// ``` 4403 static Value foldRAW(TransferReadOp readOp) { 4404 if (!llvm::isa<RankedTensorType>(readOp.getShapedType())) 4405 return {}; 4406 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>(); 4407 while (defWrite) { 4408 if (checkSameValueRAW(defWrite, readOp)) 4409 return defWrite.getVector(); 4410 if (!isDisjointTransferIndices( 4411 cast<VectorTransferOpInterface>(defWrite.getOperation()), 4412 cast<VectorTransferOpInterface>(readOp.getOperation()))) 4413 break; 4414 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>(); 4415 } 4416 return {}; 4417 } 4418 4419 OpFoldResult TransferReadOp::fold(FoldAdaptor) { 4420 if (Value vec = foldRAW(*this)) 4421 return vec; 4422 /// transfer_read(memrefcast) -> transfer_read 4423 if (succeeded(foldTransferInBoundsAttribute(*this))) 4424 return getResult(); 4425 if (succeeded(foldTransferFullMask(*this))) 4426 return getResult(); 4427 if (succeeded(memref::foldMemRefCast(*this))) 4428 return getResult(); 4429 if (succeeded(tensor::foldTensorCast(*this))) 4430 return getResult(); 4431 return OpFoldResult(); 4432 } 4433 
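/// Returns the shape used when unrolling this op, i.e. the shape of the
/// result vector.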
4434 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4435 return llvm::to_vector<4>(getVectorType().getShape());
4436 }
4437
4438 void TransferReadOp::getEffects(
4439 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4440 &effects) {
4441 if (llvm::isa<MemRefType>(getShapedType()))
4442 effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4443 SideEffects::DefaultResource::get());
4444 }
4445
4446 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4447 if (hasPureTensorSemantics())
4448 return Speculation::Speculatable;
4449 return Speculation::NotSpeculatable;
4450 }
4451
4452 namespace {
4453 /// Store to load forwarding for transfer operations with permutation maps.
4454 /// Even if the permutation maps are different we can still propagate the store
4455 /// into the load if the size of the dimensions read and written match. Then we
4456 /// can replace the transfer_read + transfer_write by vector.broadcast and
4457 /// vector.transpose.
4458 /// Example:
4459 /// ```
4460 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4461 /// {in_bounds = [true, true],
4462 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4463 /// vector<4x1xf32>, tensor<4x4x4xf32>
4464 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4465 /// {in_bounds = [true, true, true, true],
4466 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4467 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4468 /// ```
4469 /// To:
4470 /// ```
4471 /// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
4472 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4473 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4474 /// ```
4475 struct TransferReadAfterWriteToBroadcast
4476 : public OpRewritePattern<TransferReadOp> {
4477 using OpRewritePattern::OpRewritePattern;
4478
4479 LogicalResult matchAndRewrite(TransferReadOp readOp,
4480 PatternRewriter &rewriter) const override {
4481 if (readOp.hasOutOfBoundsDim() ||
4482 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4483 return failure();
4484 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4485 if (!defWrite)
4486 return failure();
4487 // TODO: If the written transfer chunk is a superset of the read transfer
4488 // chunk we could do an extract_strided_slice.
4489 if (readOp.getTransferChunkAccessed() !=
4490 defWrite.getTransferChunkAccessed())
4491 return failure();
4492 // TODO: Support cases where a dim is explicitly written but implicitly
4493 // read (i.e., a unit dim that is rank reduced).
4494 if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4495 getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4496 return failure();
4497 if (readOp.getIndices() != defWrite.getIndices() ||
4498 readOp.getMask() != defWrite.getMask())
4499 return failure();
4500 Value vec = defWrite.getVector();
4501 // TODO: loop through the chain of transfer_write if we can prove that they
4502 // don't overlap with the transfer_read. This requires improving
4503 // `isDisjointTransferIndices` helper.
4504 AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4505 AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4506 AffineMap map = readMap.compose(writeMap);
4507 if (map.getNumResults() == 0)
4508 return failure();
4509 // Calculate the permutation to apply to go from the vector stored to the
4510 // vector read.
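// `permutation[i]` identifies the dimension of the broadcasted intermediate
// vector that supplies dimension `i` of the read vector; it is reused below
// as the vector.transpose permutation.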
4511 SmallVector<unsigned> permutation; 4512 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) 4513 return failure(); 4514 4515 Location loc = readOp.getLoc(); 4516 // Calculate the broadcast shape by applying the reverse permutation to the 4517 // final shape we want. 4518 ArrayRef<int64_t> destShape = readOp.getVectorType().getShape(); 4519 SmallVector<int64_t> broadcastShape(destShape.size()); 4520 SmallVector<bool> broadcastScalableFlags(destShape.size()); 4521 for (const auto &pos : llvm::enumerate(permutation)) { 4522 broadcastShape[pos.value()] = destShape[pos.index()]; 4523 broadcastScalableFlags[pos.value()] = 4524 readOp.getVectorType().getScalableDims()[pos.index()]; 4525 } 4526 VectorType broadcastedType = VectorType::get( 4527 broadcastShape, defWrite.getVectorType().getElementType(), 4528 broadcastScalableFlags); 4529 vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec); 4530 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end()); 4531 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, 4532 transposePerm); 4533 return success(); 4534 } 4535 }; 4536 } // namespace 4537 4538 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, 4539 MLIRContext *context) { 4540 results.add<TransferReadAfterWriteToBroadcast>(context); 4541 } 4542 4543 //===----------------------------------------------------------------------===// 4544 // TransferWriteOp 4545 //===----------------------------------------------------------------------===// 4546 4547 /// 1. Builder with type inference. 4548 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 4549 Value vector, Value dest, ValueRange indices, 4550 AffineMapAttr permutationMapAttr, 4551 /*optional*/ Value mask, 4552 /*optional*/ ArrayAttr inBoundsAttr) { 4553 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType()); 4554 build(builder, result, resultType, vector, dest, indices, permutationMapAttr, 4555 mask, inBoundsAttr); 4556 } 4557 4558 /// 2. Builder with type inference that sets an empty mask (variant with attrs). 4559 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 4560 Value vector, Value dest, ValueRange indices, 4561 AffineMapAttr permutationMapAttr, 4562 /*optional*/ ArrayAttr inBoundsAttr) { 4563 build(builder, result, vector, dest, indices, permutationMapAttr, 4564 /*mask=*/Value(), inBoundsAttr); 4565 } 4566 4567 /// 3. Builder with type inference that sets an empty mask (variant without 4568 /// attrs) 4569 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 4570 Value vector, Value dest, ValueRange indices, 4571 AffineMap permutationMap, 4572 std::optional<ArrayRef<bool>> inBounds) { 4573 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 4574 auto inBoundsAttr = 4575 (inBounds && !inBounds.value().empty()) 4576 ? builder.getBoolArrayAttr(inBounds.value()) 4577 : builder.getBoolArrayAttr(SmallVector<bool>( 4578 llvm::cast<VectorType>(vector.getType()).getRank(), false)); 4579 build(builder, result, vector, dest, indices, permutationMapAttr, 4580 /*mask=*/Value(), inBoundsAttr); 4581 } 4582 4583 /// 4. Builder with type inference that sets an empty mask and sets permutation 4584 /// map to 'getMinorIdentityMap'. 
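/// For example (illustrative): writing a vector<4x8xf32> into a
/// memref<2x4x8xf32> yields permutation_map = (d0, d1, d2) -> (d1, d2).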
4585 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 4586 Value vector, Value dest, ValueRange indices, 4587 std::optional<ArrayRef<bool>> inBounds) { 4588 auto vectorType = llvm::cast<VectorType>(vector.getType()); 4589 AffineMap permutationMap = getTransferMinorIdentityMap( 4590 llvm::cast<ShapedType>(dest.getType()), vectorType); 4591 build(builder, result, vector, dest, indices, permutationMap, inBounds); 4592 } 4593 4594 ParseResult TransferWriteOp::parse(OpAsmParser &parser, 4595 OperationState &result) { 4596 auto &builder = parser.getBuilder(); 4597 SMLoc typesLoc; 4598 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo; 4599 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo; 4600 SmallVector<Type, 2> types; 4601 OpAsmParser::UnresolvedOperand maskInfo; 4602 if (parser.parseOperand(vectorInfo) || parser.parseComma() || 4603 parser.parseOperand(sourceInfo) || 4604 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) 4605 return failure(); 4606 ParseResult hasMask = parser.parseOptionalComma(); 4607 if (hasMask.succeeded() && parser.parseOperand(maskInfo)) 4608 return failure(); 4609 if (parser.parseOptionalAttrDict(result.attributes) || 4610 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 4611 return failure(); 4612 if (types.size() != 2) 4613 return parser.emitError(typesLoc, "requires two types"); 4614 auto indexType = builder.getIndexType(); 4615 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]); 4616 if (!vectorType) 4617 return parser.emitError(typesLoc, "requires vector type"); 4618 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]); 4619 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType)) 4620 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 4621 auto permMapAttrName = 4622 TransferWriteOp::getPermutationMapAttrName(result.name); 4623 auto permMapAttr = result.attributes.get(permMapAttrName); 4624 AffineMap permMap; 4625 if (!permMapAttr) { 4626 permMap = getTransferMinorIdentityMap(shapedType, vectorType); 4627 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap)); 4628 } else { 4629 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue(); 4630 } 4631 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name); 4632 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName); 4633 if (!inBoundsAttr) { 4634 result.addAttribute(inBoundsAttrName, 4635 builder.getBoolArrayAttr( 4636 SmallVector<bool>(permMap.getNumResults(), false))); 4637 } 4638 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || 4639 parser.resolveOperand(sourceInfo, shapedType, result.operands) || 4640 parser.resolveOperands(indexInfo, indexType, result.operands)) 4641 return failure(); 4642 if (hasMask.succeeded()) { 4643 if (llvm::dyn_cast<VectorType>(shapedType.getElementType())) 4644 return parser.emitError( 4645 maskInfo.location, "does not support masks with vector element type"); 4646 if (vectorType.getRank() != permMap.getNumResults()) { 4647 return parser.emitError(typesLoc, 4648 "expected the same rank for the vector and the " 4649 "results of the permutation map"); 4650 } 4651 auto maskType = inferTransferOpMaskType(vectorType, permMap); 4652 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 4653 return failure(); 4654 } 4655 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(), 4656 builder.getDenseI32ArrayAttr( 4657 {1, 1, static_cast<int32_t>(indexInfo.size()), 4658 
static_cast<int32_t>(hasMask.succeeded())})); 4659 return failure(llvm::isa<RankedTensorType>(shapedType) && 4660 parser.addTypeToList(shapedType, result.types)); 4661 } 4662 4663 void TransferWriteOp::print(OpAsmPrinter &p) { 4664 p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]"; 4665 if (getMask()) 4666 p << ", " << getMask(); 4667 printTransferAttrs(p, *this); 4668 p << " : " << getVectorType() << ", " << getShapedType(); 4669 } 4670 4671 LogicalResult TransferWriteOp::verify() { 4672 // Consistency of elemental types in shape and vector. 4673 ShapedType shapedType = getShapedType(); 4674 VectorType vectorType = getVectorType(); 4675 VectorType maskType = getMaskType(); 4676 auto permutationMap = getPermutationMap(); 4677 VectorType inferredMaskType = 4678 maskType ? inferTransferOpMaskType(vectorType, permutationMap) 4679 : VectorType(); 4680 4681 if (llvm::size(getIndices()) != shapedType.getRank()) 4682 return emitOpError("requires ") << shapedType.getRank() << " indices"; 4683 4684 // We do not allow broadcast dimensions on TransferWriteOps for the moment, 4685 // as the semantics is unclear. This can be revisited later if necessary. 4686 if (hasBroadcastDim()) 4687 return emitOpError("should not have broadcast dimensions"); 4688 4689 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 4690 shapedType, vectorType, maskType, 4691 inferredMaskType, permutationMap, getInBounds()))) 4692 return failure(); 4693 4694 return verifyPermutationMap(permutationMap, 4695 [&](Twine t) { return emitOpError(t); }); 4696 } 4697 4698 // MaskableOpInterface methods. 4699 4700 /// Returns the mask type expected by this operation. Mostly used for 4701 /// verification purposes. 4702 Type TransferWriteOp::getExpectedMaskType() { 4703 return inferTransferOpMaskType(getVectorType(), getPermutationMap()); 4704 } 4705 4706 /// Fold: 4707 /// ``` 4708 /// %t1 = ... 4709 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} : 4710 /// tensor<static_sizesxf32>, vector<static_sizesxf32> 4711 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} : 4712 /// vector<static_sizesxf32>, tensor<static_sizesxf32> 4713 /// ``` 4714 /// 4715 /// into: 4716 /// 4717 /// ``` 4718 /// %t0 4719 /// ``` 4720 /// 4721 /// The producer of t1 may or may not be DCE'd depending on whether it is a 4722 /// block argument or has side effects. 4723 static LogicalResult foldReadInitWrite(TransferWriteOp write, 4724 ArrayRef<Attribute>, 4725 SmallVectorImpl<OpFoldResult> &results) { 4726 // TODO: support 0-d corner case. 4727 if (write.getTransferRank() == 0) 4728 return failure(); 4729 auto rankedTensorType = 4730 llvm::dyn_cast<RankedTensorType>(write.getSource().getType()); 4731 // If not operating on tensors, bail. 4732 if (!rankedTensorType) 4733 return failure(); 4734 // If no read, bail. 4735 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>(); 4736 if (!read) 4737 return failure(); 4738 // TODO: support 0-d corner case. 4739 if (read.getTransferRank() == 0) 4740 return failure(); 4741 // For now, only accept minor identity. Future: composition is minor identity. 4742 if (!read.getPermutationMap().isMinorIdentity() || 4743 !write.getPermutationMap().isMinorIdentity()) 4744 return failure(); 4745 // Bail on mismatching ranks. 4746 if (read.getTransferRank() != write.getTransferRank()) 4747 return failure(); 4748 // Bail on potential out-of-bounds accesses. 
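// (Out-of-bounds lanes would be filled with the padding value, so the read
// vector might not reproduce the original tensor contents.)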
4749 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4750 return failure();
4751 // Tensor types must be the same.
4752 if (read.getSource().getType() != rankedTensorType)
4753 return failure();
4754 // Vector types must be the same.
4755 if (read.getVectorType() != write.getVectorType())
4756 return failure();
4757 // Vector and Tensor shapes must match.
4758 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4759 return failure();
4760 // If any index is nonzero.
4761 auto isNotConstantZero = [](Value v) {
4762 auto cstOp = getConstantIntValue(v);
4763 return !cstOp.has_value() || cstOp.value() != 0;
4764 };
4765 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4766 llvm::any_of(write.getIndices(), isNotConstantZero))
4767 return failure();
4768 // Success.
4769 results.push_back(read.getSource());
4770 return success();
4771 }
4772
4773 static bool checkSameValueWAR(vector::TransferReadOp read,
4774 vector::TransferWriteOp write) {
4775 return read.getSource() == write.getSource() &&
4776 read.getIndices() == write.getIndices() &&
4777 read.getPermutationMap() == write.getPermutationMap() &&
4778 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4779 !write.getMask();
4780 }
4781 /// Fold transfer_write write after read:
4782 /// ```
4783 /// %t0 = ...
4784 /// %v = vector.transfer_read %t0[%c0...] :
4785 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4786 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4787 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4788 /// ```
4789 ///
4790 /// into:
4791 ///
4792 /// ```
4793 /// %t0
4794 /// ```
4795 static LogicalResult foldWAR(TransferWriteOp write,
4796 SmallVectorImpl<OpFoldResult> &results) {
4797 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4798 return failure();
4799 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4800 if (!read)
4801 return failure();
4802
4803 if (!checkSameValueWAR(read, write))
4804 return failure();
4805 results.push_back(read.getSource());
4806 return success();
4807 }
4808
4809 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4810 SmallVectorImpl<OpFoldResult> &results) {
4811 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4812 return success();
4813 if (succeeded(foldWAR(*this, results)))
4814 return success();
4815 if (succeeded(foldTransferInBoundsAttribute(*this)))
4816 return success();
4817 if (succeeded(foldTransferFullMask(*this)))
4818 return success();
4819 return memref::foldMemRefCast(*this);
4820 }
4821
4822 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4823 return llvm::to_vector<4>(getVectorType().getShape());
4824 }
4825
4826 void TransferWriteOp::getEffects(
4827 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4828 &effects) {
4829 if (llvm::isa<MemRefType>(getShapedType()))
4830 effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4831 SideEffects::DefaultResource::get());
4832 }
4833
4834 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4835 if (hasPureTensorSemantics())
4836 return Speculation::Speculatable;
4837 return Speculation::NotSpeculatable;
4838 }
4839
4840 namespace {
4841 /// Remove dead transfer write from the SSA chain so that it can be eliminated by
4842 /// DCE.
4843 /// ```
4844 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4845 /// : vector<1x4xf32>, tensor<4x4xf32>
4846 /// %w1 = vector.transfer_write %v0,
%w0[%c2, %c0] {in_bounds = [true, true]}
4847 /// : vector<1x4xf32>, tensor<4x4xf32>
4848 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4849 /// : vector<1x4xf32>, tensor<4x4xf32>
4850 /// ```
4851 ///
4852 /// into:
4853 ///
4854 /// ```
4855 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4856 /// : vector<1x4xf32>, tensor<4x4xf32>
4857 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4858 /// : vector<1x4xf32>, tensor<4x4xf32>
4859 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4860 /// : vector<1x4xf32>, tensor<4x4xf32>
4861 /// ```
4862 ///
4863 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4864 /// any other uses.
4865 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4866 public:
4867 using OpRewritePattern::OpRewritePattern;
4868 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4869 PatternRewriter &rewriter) const override {
4870 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4871 return failure();
4872 vector::TransferWriteOp writeToModify = writeOp;
4873
4874 auto defWrite =
4875 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4876 while (defWrite) {
4877 if (checkSameValueWAW(writeOp, defWrite)) {
4878 rewriter.modifyOpInPlace(writeToModify, [&]() {
4879 writeToModify.getSourceMutable().assign(defWrite.getSource());
4880 });
4881 return success();
4882 }
4883 if (!isDisjointTransferIndices(
4884 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4885 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4886 break;
4887 // If the previous write op doesn't have any other use we can safely look
4888 // at the previous store to see if it can be removed.
4889 if (!defWrite->hasOneUse())
4890 break;
4891 writeToModify = defWrite;
4892 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4893 }
4894 return failure();
4895 }
4896 };
4897
4898 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4899 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4900 /// overwritten and inserted into another tensor. After this rewrite, the
4901 /// operations bufferize in-place since all of them work on the same slice.
4902 ///
4903 /// For example:
4904 /// ```mlir
4905 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4906 /// : vector<8x16xf32>, tensor<8x16xf32>
4907 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4908 /// : tensor<8x16xf32> to tensor<?x?xf32>
4909 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4910 /// : tensor<?x?xf32> into tensor<27x37xf32>
4911 /// ```
4912 /// folds to
4913 /// ```mlir
4914 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4915 /// : tensor<27x37xf32> to tensor<?x?xf32>
4916 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4917 /// : vector<8x16xf32>, tensor<?x?xf32>
4918 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4919 /// : tensor<?x?xf32> into tensor<27x37xf32>
4920 /// ```
4921 struct SwapExtractSliceOfTransferWrite
4922 : public OpRewritePattern<tensor::InsertSliceOp> {
4923 public:
4924 using OpRewritePattern::OpRewritePattern;
4925
4926 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4927 PatternRewriter &rewriter) const override {
4928 if (!insertOp.hasUnitStride())
4929 return failure();
4930 auto extractOp =
4931 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4932 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4933 return failure();
4934 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4935 if (!transferOp || !transferOp->hasOneUse())
4936 return failure();
4937
4938 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4939 // rank-reducing.
4940 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4941 return rewriter.notifyMatchFailure(insertOp,
4942 "use-def chain is rank-reducing");
4943 }
4944
4945 // Fail if tensor::ExtractSliceOp has non-zero offset.
4946 if (!extractOp.hasZeroOffset()) {
4947 return rewriter.notifyMatchFailure(insertOp,
4948 "ExtractSliceOp has non-zero offset");
4949 }
4950
4951 // Fail if vector::TransferWriteOp has non-zero offset.
4952 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4953 return getConstantIntValue(value) == static_cast<int64_t>(0);
4954 })) {
4955 return rewriter.notifyMatchFailure(insertOp,
4956 "TransferWriteOp has non-zero offset");
4957 }
4958
4959 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4960 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4961 return rewriter.notifyMatchFailure(
4962 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4963 }
4964
4965 for (auto [insertSize, extractSize] :
4966 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4967 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4968 return rewriter.notifyMatchFailure(
4969 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4970 }
4971 }
4972
4973 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4974 assert(transferOp.getVectorType().hasStaticShape() &&
4975 "expected vector to have a static shape");
4976 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4977 SmallVector<int64_t> resultShape = applyPermutationMap(
4978 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4979 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4980 return rewriter.notifyMatchFailure(
4981 insertOp, "TransferWriteOp may not write the full tensor.");
4982 }
4983
4984 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
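// (The write will now target the freshly extracted slice, whose bounds are
// not necessarily those of the original destination.)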
4985 // Set all in_bounds to false and let the folder infer them. 4986 SmallVector<bool> newInBounds(vectorShape.size(), false); 4987 auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>( 4988 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(), 4989 insertOp.getMixedOffsets(), insertOp.getMixedSizes(), 4990 insertOp.getMixedStrides()); 4991 auto newTransferWriteOp = rewriter.create<TransferWriteOp>( 4992 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(), 4993 transferOp.getIndices(), transferOp.getPermutationMapAttr(), 4994 rewriter.getBoolArrayAttr(newInBounds)); 4995 rewriter.modifyOpInPlace(insertOp, [&]() { 4996 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult()); 4997 }); 4998 return success(); 4999 } 5000 }; 5001 5002 } // namespace 5003 5004 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, 5005 MLIRContext *context) { 5006 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context); 5007 } 5008 5009 //===----------------------------------------------------------------------===// 5010 // LoadOp 5011 //===----------------------------------------------------------------------===// 5012 5013 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, 5014 VectorType vecTy, 5015 MemRefType memRefTy) { 5016 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't 5017 // need any strides limitations. 5018 if (!vecTy.isScalable() && 5019 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1)) 5020 return success(); 5021 5022 if (!memRefTy.isLastDimUnitStride()) 5023 return op->emitOpError("most minor memref dim must have unit stride"); 5024 return success(); 5025 } 5026 5027 LogicalResult vector::LoadOp::verify() { 5028 VectorType resVecTy = getVectorType(); 5029 MemRefType memRefTy = getMemRefType(); 5030 5031 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy))) 5032 return failure(); 5033 5034 // Checks for vector memrefs. 5035 Type memElemTy = memRefTy.getElementType(); 5036 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) { 5037 if (memVecTy != resVecTy) 5038 return emitOpError("base memref and result vector types should match"); 5039 memElemTy = memVecTy.getElementType(); 5040 } 5041 5042 if (resVecTy.getElementType() != memElemTy) 5043 return emitOpError("base and result element types should match"); 5044 if (llvm::size(getIndices()) != memRefTy.getRank()) 5045 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 5046 return success(); 5047 } 5048 5049 OpFoldResult LoadOp::fold(FoldAdaptor) { 5050 if (succeeded(memref::foldMemRefCast(*this))) 5051 return getResult(); 5052 return OpFoldResult(); 5053 } 5054 5055 //===----------------------------------------------------------------------===// 5056 // StoreOp 5057 //===----------------------------------------------------------------------===// 5058 5059 LogicalResult vector::StoreOp::verify() { 5060 VectorType valueVecTy = getVectorType(); 5061 MemRefType memRefTy = getMemRefType(); 5062 5063 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy))) 5064 return failure(); 5065 5066 // Checks for vector memrefs. 
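// (A memref with a vector element type, e.g. memref<16xvector<8xf32>>, stores
// whole vectors as elements; the value being stored must then match that
// element type exactly.)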
5067 Type memElemTy = memRefTy.getElementType(); 5068 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) { 5069 if (memVecTy != valueVecTy) 5070 return emitOpError( 5071 "base memref and valueToStore vector types should match"); 5072 memElemTy = memVecTy.getElementType(); 5073 } 5074 5075 if (valueVecTy.getElementType() != memElemTy) 5076 return emitOpError("base and valueToStore element type should match"); 5077 if (llvm::size(getIndices()) != memRefTy.getRank()) 5078 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 5079 return success(); 5080 } 5081 5082 LogicalResult StoreOp::fold(FoldAdaptor adaptor, 5083 SmallVectorImpl<OpFoldResult> &results) { 5084 return memref::foldMemRefCast(*this); 5085 } 5086 5087 //===----------------------------------------------------------------------===// 5088 // MaskedLoadOp 5089 //===----------------------------------------------------------------------===// 5090 5091 LogicalResult MaskedLoadOp::verify() { 5092 VectorType maskVType = getMaskVectorType(); 5093 VectorType passVType = getPassThruVectorType(); 5094 VectorType resVType = getVectorType(); 5095 MemRefType memType = getMemRefType(); 5096 5097 if (resVType.getElementType() != memType.getElementType()) 5098 return emitOpError("base and result element type should match"); 5099 if (llvm::size(getIndices()) != memType.getRank()) 5100 return emitOpError("requires ") << memType.getRank() << " indices"; 5101 if (resVType.getShape() != maskVType.getShape()) 5102 return emitOpError("expected result shape to match mask shape"); 5103 if (resVType != passVType) 5104 return emitOpError("expected pass_thru of same type as result type"); 5105 return success(); 5106 } 5107 5108 namespace { 5109 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { 5110 public: 5111 using OpRewritePattern::OpRewritePattern; 5112 LogicalResult matchAndRewrite(MaskedLoadOp load, 5113 PatternRewriter &rewriter) const override { 5114 switch (getMaskFormat(load.getMask())) { 5115 case MaskFormat::AllTrue: 5116 rewriter.replaceOpWithNewOp<vector::LoadOp>( 5117 load, load.getType(), load.getBase(), load.getIndices()); 5118 return success(); 5119 case MaskFormat::AllFalse: 5120 rewriter.replaceOp(load, load.getPassThru()); 5121 return success(); 5122 case MaskFormat::Unknown: 5123 return failure(); 5124 } 5125 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad"); 5126 } 5127 }; 5128 } // namespace 5129 5130 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 5131 MLIRContext *context) { 5132 results.add<MaskedLoadFolder>(context); 5133 } 5134 5135 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) { 5136 if (succeeded(memref::foldMemRefCast(*this))) 5137 return getResult(); 5138 return OpFoldResult(); 5139 } 5140 5141 //===----------------------------------------------------------------------===// 5142 // MaskedStoreOp 5143 //===----------------------------------------------------------------------===// 5144 5145 LogicalResult MaskedStoreOp::verify() { 5146 VectorType maskVType = getMaskVectorType(); 5147 VectorType valueVType = getVectorType(); 5148 MemRefType memType = getMemRefType(); 5149 5150 if (valueVType.getElementType() != memType.getElementType()) 5151 return emitOpError("base and valueToStore element type should match"); 5152 if (llvm::size(getIndices()) != memType.getRank()) 5153 return emitOpError("requires ") << memType.getRank() << " indices"; 5154 if (valueVType.getShape() != maskVType.getShape()) 5155 return emitOpError("expected valueToStore shape to 
match mask shape"); 5156 return success(); 5157 } 5158 5159 namespace { 5160 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { 5161 public: 5162 using OpRewritePattern::OpRewritePattern; 5163 LogicalResult matchAndRewrite(MaskedStoreOp store, 5164 PatternRewriter &rewriter) const override { 5165 switch (getMaskFormat(store.getMask())) { 5166 case MaskFormat::AllTrue: 5167 rewriter.replaceOpWithNewOp<vector::StoreOp>( 5168 store, store.getValueToStore(), store.getBase(), store.getIndices()); 5169 return success(); 5170 case MaskFormat::AllFalse: 5171 rewriter.eraseOp(store); 5172 return success(); 5173 case MaskFormat::Unknown: 5174 return failure(); 5175 } 5176 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore"); 5177 } 5178 }; 5179 } // namespace 5180 5181 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 5182 MLIRContext *context) { 5183 results.add<MaskedStoreFolder>(context); 5184 } 5185 5186 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor, 5187 SmallVectorImpl<OpFoldResult> &results) { 5188 return memref::foldMemRefCast(*this); 5189 } 5190 5191 //===----------------------------------------------------------------------===// 5192 // GatherOp 5193 //===----------------------------------------------------------------------===// 5194 5195 LogicalResult GatherOp::verify() { 5196 VectorType indVType = getIndexVectorType(); 5197 VectorType maskVType = getMaskVectorType(); 5198 VectorType resVType = getVectorType(); 5199 ShapedType baseType = getBaseType(); 5200 5201 if (!llvm::isa<MemRefType, RankedTensorType>(baseType)) 5202 return emitOpError("requires base to be a memref or ranked tensor type"); 5203 5204 if (resVType.getElementType() != baseType.getElementType()) 5205 return emitOpError("base and result element type should match"); 5206 if (llvm::size(getIndices()) != baseType.getRank()) 5207 return emitOpError("requires ") << baseType.getRank() << " indices"; 5208 if (resVType.getShape() != indVType.getShape()) 5209 return emitOpError("expected result dim to match indices dim"); 5210 if (resVType.getShape() != maskVType.getShape()) 5211 return emitOpError("expected result dim to match mask dim"); 5212 if (resVType != getPassThruVectorType()) 5213 return emitOpError("expected pass_thru of same type as result type"); 5214 return success(); 5215 } 5216 5217 // MaskableOpInterface methods. 5218 5219 /// Returns the mask type expected by this operation. Mostly used for 5220 /// verification purposes. It requires the operation to be vectorized." 5221 Type GatherOp::getExpectedMaskType() { 5222 auto vecType = this->getIndexVectorType(); 5223 return VectorType::get(vecType.getShape(), 5224 IntegerType::get(vecType.getContext(), /*width=*/1), 5225 vecType.getScalableDims()); 5226 } 5227 5228 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() { 5229 return llvm::to_vector<4>(getVectorType().getShape()); 5230 } 5231 5232 /// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...] 
5233 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { 5234 auto vecType = dyn_cast<VectorType>(indexVec.getType()); 5235 if (!vecType || vecType.getRank() != 1 || vecType.isScalable()) 5236 return failure(); 5237 5238 if (indexVec.getDefiningOp<StepOp>()) 5239 return success(); 5240 5241 DenseIntElementsAttr elements; 5242 if (!matchPattern(indexVec, m_Constant(&elements))) 5243 return failure(); 5244 5245 return success( 5246 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements()))); 5247 } 5248 5249 namespace { 5250 class GatherFolder final : public OpRewritePattern<GatherOp> { 5251 public: 5252 using OpRewritePattern::OpRewritePattern; 5253 LogicalResult matchAndRewrite(GatherOp gather, 5254 PatternRewriter &rewriter) const override { 5255 switch (getMaskFormat(gather.getMask())) { 5256 case MaskFormat::AllTrue: 5257 return failure(); // no unmasked equivalent 5258 case MaskFormat::AllFalse: 5259 rewriter.replaceOp(gather, gather.getPassThru()); 5260 return success(); 5261 case MaskFormat::Unknown: 5262 return failure(); 5263 } 5264 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder"); 5265 } 5266 }; 5267 5268 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous 5269 /// maskedload. Only 1D fixed vectors are supported for now. 5270 class FoldContiguousGather final : public OpRewritePattern<GatherOp> { 5271 public: 5272 using OpRewritePattern::OpRewritePattern; 5273 LogicalResult matchAndRewrite(GatherOp op, 5274 PatternRewriter &rewriter) const override { 5275 if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) 5276 return failure(); 5277 5278 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(), 5279 op.getIndices(), op.getMask(), 5280 op.getPassThru()); 5281 return success(); 5282 } 5283 }; 5284 } // namespace 5285 5286 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, 5287 MLIRContext *context) { 5288 results.add<GatherFolder, FoldContiguousGather>(context); 5289 } 5290 5291 //===----------------------------------------------------------------------===// 5292 // ScatterOp 5293 //===----------------------------------------------------------------------===// 5294 5295 LogicalResult ScatterOp::verify() { 5296 VectorType indVType = getIndexVectorType(); 5297 VectorType maskVType = getMaskVectorType(); 5298 VectorType valueVType = getVectorType(); 5299 MemRefType memType = getMemRefType(); 5300 5301 if (valueVType.getElementType() != memType.getElementType()) 5302 return emitOpError("base and valueToStore element type should match"); 5303 if (llvm::size(getIndices()) != memType.getRank()) 5304 return emitOpError("requires ") << memType.getRank() << " indices"; 5305 if (valueVType.getDimSize(0) != indVType.getDimSize(0)) 5306 return emitOpError("expected valueToStore dim to match indices dim"); 5307 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 5308 return emitOpError("expected valueToStore dim to match mask dim"); 5309 return success(); 5310 } 5311 5312 namespace { 5313 class ScatterFolder final : public OpRewritePattern<ScatterOp> { 5314 public: 5315 using OpRewritePattern::OpRewritePattern; 5316 LogicalResult matchAndRewrite(ScatterOp scatter, 5317 PatternRewriter &rewriter) const override { 5318 switch (getMaskFormat(scatter.getMask())) { 5319 case MaskFormat::AllTrue: 5320 return failure(); // no unmasked equivalent 5321 case MaskFormat::AllFalse: 5322 rewriter.eraseOp(scatter); 5323 return success(); 5324 case MaskFormat::Unknown: 5325 return failure(); 5326 } 5327 
llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder"); 5328 } 5329 }; 5330 5331 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous 5332 /// maskedstore. Only 1D fixed vectors are supported for now. 5333 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> { 5334 public: 5335 using OpRewritePattern::OpRewritePattern; 5336 LogicalResult matchAndRewrite(ScatterOp op, 5337 PatternRewriter &rewriter) const override { 5338 if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) 5339 return failure(); 5340 5341 rewriter.replaceOpWithNewOp<MaskedStoreOp>( 5342 op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore()); 5343 return success(); 5344 } 5345 }; 5346 } // namespace 5347 5348 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results, 5349 MLIRContext *context) { 5350 results.add<ScatterFolder, FoldContiguousScatter>(context); 5351 } 5352 5353 //===----------------------------------------------------------------------===// 5354 // ExpandLoadOp 5355 //===----------------------------------------------------------------------===// 5356 5357 LogicalResult ExpandLoadOp::verify() { 5358 VectorType maskVType = getMaskVectorType(); 5359 VectorType passVType = getPassThruVectorType(); 5360 VectorType resVType = getVectorType(); 5361 MemRefType memType = getMemRefType(); 5362 5363 if (resVType.getElementType() != memType.getElementType()) 5364 return emitOpError("base and result element type should match"); 5365 if (llvm::size(getIndices()) != memType.getRank()) 5366 return emitOpError("requires ") << memType.getRank() << " indices"; 5367 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 5368 return emitOpError("expected result dim to match mask dim"); 5369 if (resVType != passVType) 5370 return emitOpError("expected pass_thru of same type as result type"); 5371 return success(); 5372 } 5373 5374 namespace { 5375 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { 5376 public: 5377 using OpRewritePattern::OpRewritePattern; 5378 LogicalResult matchAndRewrite(ExpandLoadOp expand, 5379 PatternRewriter &rewriter) const override { 5380 switch (getMaskFormat(expand.getMask())) { 5381 case MaskFormat::AllTrue: 5382 rewriter.replaceOpWithNewOp<vector::LoadOp>( 5383 expand, expand.getType(), expand.getBase(), expand.getIndices()); 5384 return success(); 5385 case MaskFormat::AllFalse: 5386 rewriter.replaceOp(expand, expand.getPassThru()); 5387 return success(); 5388 case MaskFormat::Unknown: 5389 return failure(); 5390 } 5391 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder"); 5392 } 5393 }; 5394 } // namespace 5395 5396 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 5397 MLIRContext *context) { 5398 results.add<ExpandLoadFolder>(context); 5399 } 5400 5401 //===----------------------------------------------------------------------===// 5402 // CompressStoreOp 5403 //===----------------------------------------------------------------------===// 5404 5405 LogicalResult CompressStoreOp::verify() { 5406 VectorType maskVType = getMaskVectorType(); 5407 VectorType valueVType = getVectorType(); 5408 MemRefType memType = getMemRefType(); 5409 5410 if (valueVType.getElementType() != memType.getElementType()) 5411 return emitOpError("base and valueToStore element type should match"); 5412 if (llvm::size(getIndices()) != memType.getRank()) 5413 return emitOpError("requires ") << memType.getRank() << " indices"; 5414 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 5415 
return emitOpError("expected valueToStore dim to match mask dim"); 5416 return success(); 5417 } 5418 5419 namespace { 5420 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { 5421 public: 5422 using OpRewritePattern::OpRewritePattern; 5423 LogicalResult matchAndRewrite(CompressStoreOp compress, 5424 PatternRewriter &rewriter) const override { 5425 switch (getMaskFormat(compress.getMask())) { 5426 case MaskFormat::AllTrue: 5427 rewriter.replaceOpWithNewOp<vector::StoreOp>( 5428 compress, compress.getValueToStore(), compress.getBase(), 5429 compress.getIndices()); 5430 return success(); 5431 case MaskFormat::AllFalse: 5432 rewriter.eraseOp(compress); 5433 return success(); 5434 case MaskFormat::Unknown: 5435 return failure(); 5436 } 5437 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder"); 5438 } 5439 }; 5440 } // namespace 5441 5442 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 5443 MLIRContext *context) { 5444 results.add<CompressStoreFolder>(context); 5445 } 5446 5447 //===----------------------------------------------------------------------===// 5448 // ShapeCastOp 5449 //===----------------------------------------------------------------------===// 5450 5451 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 5452 SetIntRangeFn setResultRanges) { 5453 setResultRanges(getResult(), argRanges.front()); 5454 } 5455 5456 /// Returns true if each element of 'a' is equal to the product of a contiguous 5457 /// sequence of the elements of 'b'. Returns false otherwise. 5458 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 5459 unsigned rankA = a.size(); 5460 unsigned rankB = b.size(); 5461 assert(rankA < rankB); 5462 5463 auto isOne = [](int64_t v) { return v == 1; }; 5464 5465 // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape 5466 // casted to a 0-d vector. 5467 if (rankA == 0 && llvm::all_of(b, isOne)) 5468 return true; 5469 5470 unsigned i = 0; 5471 unsigned j = 0; 5472 while (i < rankA && j < rankB) { 5473 int64_t dimA = a[i]; 5474 int64_t dimB = 1; 5475 while (dimB < dimA && j < rankB) 5476 dimB *= b[j++]; 5477 if (dimA != dimB) 5478 break; 5479 ++i; 5480 5481 // Handle the case when trailing dimensions are of size 1. 5482 // Include them into the contiguous sequence. 5483 if (i < rankA && llvm::all_of(a.slice(i), isOne)) 5484 i = rankA; 5485 if (j < rankB && llvm::all_of(b.slice(j), isOne)) 5486 j = rankB; 5487 } 5488 5489 return i == rankA && j == rankB; 5490 } 5491 5492 static LogicalResult verifyVectorShapeCast(Operation *op, 5493 VectorType sourceVectorType, 5494 VectorType resultVectorType) { 5495 // Check that element type is the same. 5496 if (sourceVectorType.getElementType() != resultVectorType.getElementType()) 5497 return op->emitOpError("source/result vectors must have same element type"); 5498 auto sourceShape = sourceVectorType.getShape(); 5499 auto resultShape = resultVectorType.getShape(); 5500 5501 // Check that product of source dim sizes matches product of result dim sizes. 5502 int64_t sourceDimProduct = std::accumulate( 5503 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{}); 5504 int64_t resultDimProduct = std::accumulate( 5505 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{}); 5506 if (sourceDimProduct != resultDimProduct) 5507 return op->emitOpError("source/result number of elements must match"); 5508 5509 // Check that expanding/contracting rank cases. 
5510 unsigned sourceRank = sourceVectorType.getRank();
5511 unsigned resultRank = resultVectorType.getRank();
5512 if (sourceRank < resultRank) {
5513 if (!isValidShapeCast(sourceShape, resultShape))
5514 return op->emitOpError("invalid shape cast");
5515 } else if (sourceRank > resultRank) {
5516 if (!isValidShapeCast(resultShape, sourceShape))
5517 return op->emitOpError("invalid shape cast");
5518 }
5519
5520 // Check that (non-)scalability is preserved.
5521 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5522 int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5523 if (sourceNScalableDims != resultNScalableDims)
5524 return op->emitOpError("different number of scalable dims at source (")
5525 << sourceNScalableDims << ") and result (" << resultNScalableDims
5526 << ")";
5527
5528
5529 return success();
5530 }
5531
5532 LogicalResult ShapeCastOp::verify() {
5533 auto sourceVectorType =
5534 llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5535 auto resultVectorType =
5536 llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5537
5538 // Check if source/result are of vector type.
5539 if (sourceVectorType && resultVectorType)
5540 return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5541
5542 return success();
5543 }
5544
5545 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5546 // No-op shape cast.
5547 if (getSource().getType() == getResult().getType())
5548 return getSource();
5549
5550 // Canceling shape casts.
5551 if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5552 if (getResult().getType() == otherOp.getSource().getType())
5553 return otherOp.getSource();
5554
5555 // Only allows valid transitive folding.
5556 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5557 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5558 if (srcType.getRank() < resultType.getRank()) {
5559 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5560 return {};
5561 } else if (srcType.getRank() > resultType.getRank()) {
5562 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5563 return {};
5564 } else {
5565 return {};
5566 }
5567
5568 setOperand(otherOp.getSource());
5569 return getResult();
5570 }
5571
5572 // Cancelling broadcast and shape cast ops.
5573 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5574 if (bcastOp.getSourceType() == getType())
5575 return bcastOp.getSource();
5576 }
5577
5578 return {};
5579 }
5580
5581 namespace {
5582 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5583 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5584 public:
5585 using OpRewritePattern::OpRewritePattern;
5586
5587 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5588 PatternRewriter &rewriter) const override {
5589 auto constantOp =
5590 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5591 if (!constantOp)
5592 return failure();
5593 // Only handle splat for now.
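// A splat can simply be re-wrapped with the result vector type below, since
// every element keeps the same value.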
5594 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5595 if (!dense)
5596 return failure();
5597 auto newAttr =
5598 DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5599 dense.getSplatValue<Attribute>());
5600 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5601 return success();
5602 }
5603 };
5604
5605 /// Helper function that computes a new vector type based on the input vector
5606 /// type by removing the trailing one dims:
5607 ///
5608 /// vector<4x1x1xi1> --> vector<4xi1>
5609 ///
5610 static VectorType trimTrailingOneDims(VectorType oldType) {
5611 ArrayRef<int64_t> oldShape = oldType.getShape();
5612 ArrayRef<int64_t> newShape = oldShape;
5613
5614 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5615 ArrayRef<bool> newScalableDims = oldScalableDims;
5616
5617 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5618 newShape = newShape.drop_back(1);
5619 newScalableDims = newScalableDims.drop_back(1);
5620 }
5621
5622 // Make sure we have at least 1 dimension.
5623 // TODO: Add support for 0-D vectors.
5624 if (newShape.empty()) {
5625 newShape = oldShape.take_back();
5626 newScalableDims = oldScalableDims.take_back();
5627 }
5628
5629 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5630 }
5631
5632 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5633 ///
5634 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5635 /// dimension. If the input vector comes from `vector.create_mask` for which
5636 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5637 /// to fold shape_cast into create_mask.
5638 ///
5639 /// BEFORE:
5640 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5641 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5642 /// AFTER:
5643 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5644 class ShapeCastCreateMaskFolderTrailingOneDim final
5645 : public OpRewritePattern<ShapeCastOp> {
5646 public:
5647 using OpRewritePattern::OpRewritePattern;
5648
5649 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5650 PatternRewriter &rewriter) const override {
5651 Value shapeOpSrc = shapeOp->getOperand(0);
5652 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5653 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5654 if (!createMaskOp && !constantMaskOp)
5655 return failure();
5656
5657 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5658 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5659
5660 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5661 if (newVecType != shapeOpResTy)
5662 return failure();
5663
5664 auto numDimsToDrop =
5665 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5666
5667 // No unit dims to drop
5668 if (!numDimsToDrop)
5669 return failure();
5670
5671 if (createMaskOp) {
5672 auto maskOperands = createMaskOp.getOperands();
5673 auto numMaskOperands = maskOperands.size();
5674
5675 // Check every mask dim size to see whether it can be dropped
5676 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5677 --i) {
5678 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5679 if (!constant || (constant.value() != 1))
5680 return failure();
5681 }
5682 SmallVector<Value> newMaskOperands =
5683 maskOperands.drop_back(numDimsToDrop);
5684
5685
rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy, 5686 newMaskOperands); 5687 return success(); 5688 } 5689 5690 if (constantMaskOp) { 5691 auto maskDimSizes = constantMaskOp.getMaskDimSizes(); 5692 auto numMaskOperands = maskDimSizes.size(); 5693 5694 // Check every mask dim size to see whether it can be dropped 5695 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop; 5696 --i) { 5697 if (maskDimSizes[i] != 1) 5698 return failure(); 5699 } 5700 5701 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop); 5702 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy, 5703 newMaskOperands); 5704 return success(); 5705 } 5706 5707 return failure(); 5708 } 5709 }; 5710 5711 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast. 5712 /// This only applies when the shape of the broadcast source 5713 /// 1. is a suffix of the shape of the result (i.e. when broadcast without 5714 /// reshape is expressive enough to capture the result in a single op), or 5715 /// 2. has the same element count as the shape cast result. 5716 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { 5717 public: 5718 using OpRewritePattern::OpRewritePattern; 5719 5720 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, 5721 PatternRewriter &rewriter) const override { 5722 auto broadcastOp = 5723 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>(); 5724 if (!broadcastOp) 5725 return failure(); 5726 5727 ArrayRef<int64_t> broadcastSourceShape; 5728 if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) 5729 broadcastSourceShape = srcType.getShape(); 5730 ArrayRef<int64_t> shapeCastTargetShape = 5731 shapeCastOp.getResultVectorType().getShape(); 5732 5733 // If `broadcastSourceShape` is a suffix of the result, we can just replace 5734 // with a broadcast to the final shape. 5735 if (broadcastSourceShape == 5736 shapeCastTargetShape.take_back(broadcastSourceShape.size())) { 5737 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 5738 shapeCastOp, shapeCastOp.getResultVectorType(), 5739 broadcastOp.getSource()); 5740 return success(); 5741 } 5742 5743 // Otherwise, if the final result has the same element count, we can replace 5744 // with a shape cast. 
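// Illustrative example (hypothetical IR):
//   %b = vector.broadcast %v : vector<4xf32> to vector<1x4xf32>
//   %c = vector.shape_cast %b : vector<1x4xf32> to vector<4x1xf32>
// [4] is not a suffix of [4, 1], but the element counts match, so the pair
// folds to a single shape_cast %v : vector<4xf32> to vector<4x1xf32>.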
5745 if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) { 5746 if (srcType.getNumElements() == 5747 shapeCastOp.getResultVectorType().getNumElements()) { 5748 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( 5749 shapeCastOp, shapeCastOp.getResultVectorType(), 5750 broadcastOp.getSource()); 5751 return success(); 5752 } 5753 } 5754 5755 return failure(); 5756 } 5757 }; 5758 5759 } // namespace 5760 5761 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 5762 MLIRContext *context) { 5763 results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim, 5764 ShapeCastBroadcastFolder>(context); 5765 } 5766 5767 //===----------------------------------------------------------------------===// 5768 // VectorBitCastOp 5769 //===----------------------------------------------------------------------===// 5770 5771 LogicalResult BitCastOp::verify() { 5772 auto sourceVectorType = getSourceVectorType(); 5773 auto resultVectorType = getResultVectorType(); 5774 5775 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { 5776 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) 5777 return emitOpError("dimension size mismatch at: ") << i; 5778 } 5779 5780 DataLayout dataLayout = DataLayout::closest(*this); 5781 auto sourceElementBits = 5782 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); 5783 auto resultElementBits = 5784 dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); 5785 5786 if (sourceVectorType.getRank() == 0) { 5787 if (sourceElementBits != resultElementBits) 5788 return emitOpError("source/result bitwidth of the 0-D vector element " 5789 "types must be equal"); 5790 } else if (sourceElementBits * sourceVectorType.getShape().back() != 5791 resultElementBits * resultVectorType.getShape().back()) { 5792 return emitOpError( 5793 "source/result bitwidth of the minor 1-D vectors must be equal"); 5794 } 5795 5796 return success(); 5797 } 5798 5799 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) { 5800 // Nop cast. 5801 if (getSource().getType() == getResult().getType()) 5802 return getSource(); 5803 5804 // Canceling bitcasts. 5805 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) { 5806 if (getResult().getType() == otherOp.getSource().getType()) 5807 return otherOp.getSource(); 5808 5809 setOperand(otherOp.getSource()); 5810 return getResult(); 5811 } 5812 5813 Attribute sourceConstant = adaptor.getSource(); 5814 if (!sourceConstant) 5815 return {}; 5816 5817 Type srcElemType = getSourceVectorType().getElementType(); 5818 Type dstElemType = getResultVectorType().getElementType(); 5819 5820 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) { 5821 if (floatPack.isSplat()) { 5822 auto splat = floatPack.getSplatValue<FloatAttr>(); 5823 5824 // Casting fp16 into fp32. 5825 if (srcElemType.isF16() && dstElemType.isF32()) { 5826 uint32_t bits = static_cast<uint32_t>( 5827 splat.getValue().bitcastToAPInt().getZExtValue()); 5828 // Duplicate the 16-bit pattern. 
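// e.g. a splat of fp16 1.0 (bit pattern 0x3C00) becomes 0x3C003C00, i.e. two
// fp16 lanes packed into one f32-sized payload.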
5829 bits = (bits << 16) | (bits & 0xffff); 5830 APInt intBits(32, bits); 5831 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits); 5832 return DenseElementsAttr::get(getResultVectorType(), floatBits); 5833 } 5834 } 5835 } 5836 5837 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) { 5838 if (intPack.isSplat()) { 5839 auto splat = intPack.getSplatValue<IntegerAttr>(); 5840 5841 if (llvm::isa<IntegerType>(dstElemType)) { 5842 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth(); 5843 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth(); 5844 5845 // Casting to a larger integer bit width. 5846 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) { 5847 APInt intBits = splat.getValue().zext(dstBitWidth); 5848 5849 // Duplicate the lower width element. 5850 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++) 5851 intBits = (intBits << srcBitWidth) | intBits; 5852 return DenseElementsAttr::get(getResultVectorType(), intBits); 5853 } 5854 } 5855 } 5856 } 5857 5858 return {}; 5859 } 5860 5861 //===----------------------------------------------------------------------===// 5862 // TypeCastOp 5863 //===----------------------------------------------------------------------===// 5864 5865 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { 5866 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType()); 5867 SmallVector<int64_t, 8> res(memRefType.getShape()); 5868 if (vectorType) 5869 res.append(vectorType.getShape().begin(), vectorType.getShape().end()); 5870 return res; 5871 } 5872 5873 /// Build the canonical memRefType with a single vector. 5874 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>. 5875 void TypeCastOp::build(OpBuilder &builder, OperationState &result, 5876 Value source) { 5877 result.addOperands(source); 5878 MemRefType memRefType = llvm::cast<MemRefType>(source.getType()); 5879 VectorType vectorType = 5880 VectorType::get(extractShape(memRefType), 5881 getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); 5882 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), 5883 memRefType.getMemorySpace())); 5884 } 5885 5886 LogicalResult TypeCastOp::verify() { 5887 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout(); 5888 if (!canonicalType.getLayout().isIdentity()) 5889 return emitOpError("expects operand to be a memref with identity layout"); 5890 if (!getResultMemRefType().getLayout().isIdentity()) 5891 return emitOpError("expects result to be a memref with identity layout"); 5892 if (getResultMemRefType().getMemorySpace() != 5893 getMemRefType().getMemorySpace()) 5894 return emitOpError("expects result in same memory space"); 5895 5896 auto sourceType = getMemRefType(); 5897 auto resultType = getResultMemRefType(); 5898 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != 5899 getElementTypeOrSelf(getElementTypeOrSelf(resultType))) 5900 return emitOpError( 5901 "expects result and operand with same underlying scalar type: ") 5902 << resultType; 5903 if (extractShape(sourceType) != extractShape(resultType)) 5904 return emitOpError( 5905 "expects concatenated result and operand shapes to be equal: ") 5906 << resultType; 5907 return success(); 5908 } 5909 5910 //===----------------------------------------------------------------------===// 5911 // TransposeOp 5912 //===----------------------------------------------------------------------===// 5913 5914 void vector::TransposeOp::build(OpBuilder &builder, OperationState 
&result, 5915 Value vector, ArrayRef<int64_t> permutation) { 5916 VectorType vt = llvm::cast<VectorType>(vector.getType()); 5917 SmallVector<int64_t, 4> transposedShape(vt.getRank()); 5918 SmallVector<bool, 4> transposedScalableDims(vt.getRank()); 5919 for (unsigned i = 0; i < permutation.size(); ++i) { 5920 transposedShape[i] = vt.getShape()[permutation[i]]; 5921 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]]; 5922 } 5923 5924 result.addOperands(vector); 5925 result.addTypes(VectorType::get(transposedShape, vt.getElementType(), 5926 transposedScalableDims)); 5927 result.addAttribute(TransposeOp::getPermutationAttrName(result.name), 5928 builder.getDenseI64ArrayAttr(permutation)); 5929 } 5930 5931 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { 5932 // Eliminate splat constant transpose ops. 5933 if (auto attr = 5934 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector())) 5935 if (attr.isSplat()) 5936 return attr.reshape(getResultVectorType()); 5937 5938 // Eliminate identity transpose ops. This happens when the dimensions of the 5939 // input vector remain in their original order after the transpose operation. 5940 ArrayRef<int64_t> perm = getPermutation(); 5941 5942 // Check if the permutation of the dimensions contains sequential values: 5943 // {0, 1, 2, ...}. 5944 for (int64_t i = 0, e = perm.size(); i < e; i++) { 5945 if (perm[i] != i) 5946 return {}; 5947 } 5948 5949 return getVector(); 5950 } 5951 5952 LogicalResult vector::TransposeOp::verify() { 5953 VectorType vectorType = getSourceVectorType(); 5954 VectorType resultType = getResultVectorType(); 5955 int64_t rank = resultType.getRank(); 5956 if (vectorType.getRank() != rank) 5957 return emitOpError("vector result rank mismatch: ") << rank; 5958 // Verify transposition array. 5959 ArrayRef<int64_t> perm = getPermutation(); 5960 int64_t size = perm.size(); 5961 if (rank != size) 5962 return emitOpError("transposition length mismatch: ") << size; 5963 SmallVector<bool, 8> seen(rank, false); 5964 for (const auto &ta : llvm::enumerate(perm)) { 5965 if (ta.value() < 0 || ta.value() >= rank) 5966 return emitOpError("transposition index out of range: ") << ta.value(); 5967 if (seen[ta.value()]) 5968 return emitOpError("duplicate position index: ") << ta.value(); 5969 seen[ta.value()] = true; 5970 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value())) 5971 return emitOpError("dimension size mismatch at: ") << ta.value(); 5972 } 5973 return success(); 5974 } 5975 5976 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() { 5977 return llvm::to_vector<4>(getResultVectorType().getShape()); 5978 } 5979 5980 namespace { 5981 5982 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. 5983 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { 5984 public: 5985 using OpRewritePattern::OpRewritePattern; 5986 5987 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 5988 PatternRewriter &rewriter) const override { 5989 // Composes two permutations: result[i] = permutation1[permutation2[i]]. 5990 auto composePermutations = [](ArrayRef<int64_t> permutation1, 5991 ArrayRef<int64_t> permutation2) { 5992 SmallVector<int64_t, 4> result; 5993 for (auto index : permutation2) 5994 result.push_back(permutation1[index]); 5995 return result; 5996 }; 5997 5998 // Return if the input of 'transposeOp' is not defined by another transpose. 
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};

// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
      return success();
    }

    return failure();
  }
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
    return success();
  }
};

/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
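    // E.g. transposing vector.constant_mask [2, 1] : vector<4x3xi1> with
    // permutation [1, 0] yields vector.constant_mask [1, 2] : vector<3x4xi1>.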
6081 auto maskDimSizes = constantMaskOp.getMaskDimSizes(); 6082 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation); 6083 6084 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>( 6085 transpOp, transpOp.getResultVectorType(), newMaskDimSizes); 6086 return success(); 6087 } 6088 }; 6089 6090 } // namespace 6091 6092 void vector::TransposeOp::getCanonicalizationPatterns( 6093 RewritePatternSet &results, MLIRContext *context) { 6094 results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast, 6095 TransposeFolder, FoldTransposeSplat>(context); 6096 } 6097 6098 //===----------------------------------------------------------------------===// 6099 // ConstantMaskOp 6100 //===----------------------------------------------------------------------===// 6101 6102 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result, 6103 VectorType type, ConstantMaskKind kind) { 6104 assert(kind == ConstantMaskKind::AllTrue || 6105 kind == ConstantMaskKind::AllFalse); 6106 build(builder, result, type, 6107 kind == ConstantMaskKind::AllTrue 6108 ? type.getShape() 6109 : SmallVector<int64_t>(type.getRank(), 0)); 6110 } 6111 6112 LogicalResult ConstantMaskOp::verify() { 6113 auto resultType = llvm::cast<VectorType>(getResult().getType()); 6114 // Check the corner case of 0-D vectors first. 6115 if (resultType.getRank() == 0) { 6116 if (getMaskDimSizes().size() != 1) 6117 return emitError("array attr must have length 1 for 0-D vectors"); 6118 auto dim = getMaskDimSizes()[0]; 6119 if (dim != 0 && dim != 1) 6120 return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); 6121 return success(); 6122 } 6123 6124 // Verify that array attr size matches the rank of the vector result. 6125 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank()) 6126 return emitOpError( 6127 "must specify array attr of size equal vector result rank"); 6128 // Verify that each array attr element is in bounds of corresponding vector 6129 // result dimension size. 6130 auto resultShape = resultType.getShape(); 6131 auto resultScalableDims = resultType.getScalableDims(); 6132 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes(); 6133 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) { 6134 if (maskDimSize < 0 || maskDimSize > resultShape[index]) 6135 return emitOpError( 6136 "array attr of size out of bounds of vector result dimension size"); 6137 if (resultScalableDims[index] && maskDimSize != 0 && 6138 maskDimSize != resultShape[index]) 6139 return emitOpError( 6140 "only supports 'none set' or 'all set' scalable dimensions"); 6141 } 6142 // Verify that if one mask dim size is zero, they all should be zero (because 6143 // the mask region is a conjunction of each mask dimension interval). 6144 bool anyZeros = llvm::is_contained(maskDimSizes, 0); 6145 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); 6146 if (anyZeros && !allZeros) 6147 return emitOpError("expected all mask dim sizes to be zeros, " 6148 "as a result of conjunction with zero mask dim"); 6149 return success(); 6150 } 6151 6152 bool ConstantMaskOp::isAllOnesMask() { 6153 auto resultType = getVectorType(); 6154 // Check the corner case of 0-D vectors first. 
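  // (For a 0-D mask the single mask dim size is 1 for all-true and 0 for
  // all-false, as enforced by the verifier above.)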
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}

//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}

namespace {

/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
///
/// Ex 1:
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
/// Becomes:
///   vector.constant_mask [3, 2] : vector<4x3xi1>
///
/// Ex 2:
///   %c_neg_1 = arith.constant -1 : index
///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
/// becomes:
///   vector.constant_mask [0] : vector<[8]xi1>
///
/// Ex 3:
///   %c8 = arith.constant 8 : index
///   %c16 = arith.constant 16 : index
///   %0 = vector.vscale
///   %1 = arith.muli %0, %c16 : index
///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
/// becomes:
///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: Rank zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the `constantDims` (for the ConstantMaskOp).
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value.
        // If the mask dim is non-scalable this can be any value.
        // If the mask dim is scalable only zero (all-false) is supported.
6245 if (maskTypeDimScalableFlags[i] && intSize >= 0) 6246 return failure(); 6247 constantDims.push_back(*intSize); 6248 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) { 6249 // Constant vscale multiple (e.g. 4 x vscale). 6250 // Must be all-true to fold to a ConstantMask. 6251 if (vscaleMultiplier < maskTypeDimSizes[i]) 6252 return failure(); 6253 constantDims.push_back(*vscaleMultiplier); 6254 } else { 6255 return failure(); 6256 } 6257 } 6258 6259 // Clamp values to constant_mask bounds. 6260 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes)) 6261 value = std::clamp<int64_t>(value, 0, maskDimSize); 6262 6263 // If one of dim sizes is zero, set all dims to zero. 6264 if (llvm::is_contained(constantDims, 0)) 6265 constantDims.assign(constantDims.size(), 0); 6266 6267 // Replace 'createMaskOp' with ConstantMaskOp. 6268 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType, 6269 constantDims); 6270 return success(); 6271 } 6272 }; 6273 6274 } // namespace 6275 6276 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results, 6277 MLIRContext *context) { 6278 results.add<CreateMaskFolder>(context); 6279 } 6280 6281 //===----------------------------------------------------------------------===// 6282 // MaskOp 6283 //===----------------------------------------------------------------------===// 6284 6285 void MaskOp::build( 6286 OpBuilder &builder, OperationState &result, Value mask, 6287 Operation *maskableOp, 6288 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) { 6289 assert(maskRegionBuilder && 6290 "builder callback for 'maskRegion' must be present"); 6291 6292 result.addOperands(mask); 6293 OpBuilder::InsertionGuard guard(builder); 6294 Region *maskRegion = result.addRegion(); 6295 builder.createBlock(maskRegion); 6296 maskRegionBuilder(builder, maskableOp); 6297 } 6298 6299 void MaskOp::build( 6300 OpBuilder &builder, OperationState &result, TypeRange resultTypes, 6301 Value mask, Operation *maskableOp, 6302 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) { 6303 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp, 6304 maskRegionBuilder); 6305 } 6306 6307 void MaskOp::build( 6308 OpBuilder &builder, OperationState &result, TypeRange resultTypes, 6309 Value mask, Value passthru, Operation *maskableOp, 6310 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) { 6311 build(builder, result, mask, maskableOp, maskRegionBuilder); 6312 if (passthru) 6313 result.addOperands(passthru); 6314 result.addTypes(resultTypes); 6315 } 6316 6317 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) { 6318 // Create the op region. 6319 result.regions.reserve(1); 6320 Region &maskRegion = *result.addRegion(); 6321 6322 auto &builder = parser.getBuilder(); 6323 6324 // Parse all the operands. 6325 OpAsmParser::UnresolvedOperand mask; 6326 if (parser.parseOperand(mask)) 6327 return failure(); 6328 6329 // Optional passthru operand. 6330 OpAsmParser::UnresolvedOperand passthru; 6331 ParseResult parsePassthru = parser.parseOptionalComma(); 6332 if (parsePassthru.succeeded() && parser.parseOperand(passthru)) 6333 return failure(); 6334 6335 // Parse op region. 6336 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{})) 6337 return failure(); 6338 6339 MaskOp::ensureTerminator(maskRegion, builder, result.location); 6340 6341 // Parse the optional attribute list. 
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded())
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();

  return success();
}

void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}

void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // Keep the default yield terminator if the number of masked operations is
  // not the expected one. This case will trigger a verification failure.
  Block &block = region.front();
  if (block.getOperations().size() != 2)
    return;

  // Replace default yield terminator with a new one that returns the results
  // from the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  Operation *oldYieldOp = &block.back();
  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");

  // Empty vector.mask op.
  if (maskedOp == oldYieldOp)
    return;

  opBuilder.setInsertionPoint(oldYieldOp);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
  oldYieldOp->dropAllReferences();
  oldYieldOp->erase();
}

LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
6439 if (maskableOp->getNumResults() != getNumResults()) 6440 return emitOpError("expects number of results to match maskable operation " 6441 "number of results"); 6442 6443 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes())) 6444 return emitOpError( 6445 "expects result type to match maskable operation result type"); 6446 6447 if (llvm::count_if(maskableOp->getResultTypes(), 6448 [](Type t) { return llvm::isa<VectorType>(t); }) > 1) 6449 return emitOpError("multiple vector results not supported"); 6450 6451 // Mask checks. 6452 Type expectedMaskType = maskableOp.getExpectedMaskType(); 6453 if (getMask().getType() != expectedMaskType) 6454 return emitOpError("expects a ") 6455 << expectedMaskType << " mask for the maskable operation"; 6456 6457 // Passthru checks. 6458 Value passthru = getPassthru(); 6459 if (passthru) { 6460 if (!maskableOp.supportsPassthru()) 6461 return emitOpError( 6462 "doesn't expect a passthru argument for this maskable operation"); 6463 6464 if (maskableOp->getNumResults() != 1) 6465 return emitOpError("expects result when passthru argument is provided"); 6466 6467 if (passthru.getType() != maskableOp->getResultTypes()[0]) 6468 return emitOpError("expects passthru type to match result type"); 6469 } 6470 6471 return success(); 6472 } 6473 6474 /// Folds vector.mask ops with an all-true mask. 6475 LogicalResult MaskOp::fold(FoldAdaptor adaptor, 6476 SmallVectorImpl<OpFoldResult> &results) { 6477 MaskFormat maskFormat = getMaskFormat(getMask()); 6478 if (isEmpty()) 6479 return failure(); 6480 6481 if (maskFormat != MaskFormat::AllTrue) 6482 return failure(); 6483 6484 // Move maskable operation outside of the `vector.mask` region. 6485 Operation *maskableOp = getMaskableOp(); 6486 maskableOp->dropAllUses(); 6487 maskableOp->moveBefore(getOperation()); 6488 6489 llvm::append_range(results, maskableOp->getResults()); 6490 return success(); 6491 } 6492 6493 // Elides empty vector.mask operations with or without return values. Propagates 6494 // the yielded values by the vector.yield terminator, if any, or erases the op, 6495 // otherwise. 6496 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> { 6497 using OpRewritePattern::OpRewritePattern; 6498 6499 LogicalResult matchAndRewrite(MaskOp maskOp, 6500 PatternRewriter &rewriter) const override { 6501 auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation()); 6502 if (maskingOp.getMaskableOp()) 6503 return failure(); 6504 6505 if (!maskOp.isEmpty()) 6506 return failure(); 6507 6508 Block *block = maskOp.getMaskBlock(); 6509 auto terminator = cast<vector::YieldOp>(block->front()); 6510 if (terminator.getNumOperands() == 0) 6511 rewriter.eraseOp(maskOp); 6512 else 6513 rewriter.replaceOp(maskOp, terminator.getOperands()); 6514 6515 return success(); 6516 } 6517 }; 6518 6519 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results, 6520 MLIRContext *context) { 6521 results.add<ElideEmptyMaskOp>(context); 6522 } 6523 6524 // MaskingOpInterface definitions. 6525 6526 /// Returns the operation masked by this 'vector.mask'. 6527 Operation *MaskOp::getMaskableOp() { 6528 Block *block = getMaskBlock(); 6529 if (block->getOperations().size() < 2) 6530 return nullptr; 6531 6532 return &block->front(); 6533 } 6534 6535 /// Returns true if 'vector.mask' has a passthru value. 
6536 bool MaskOp::hasPassthru() { return getPassthru() != Value(); } 6537 6538 //===----------------------------------------------------------------------===// 6539 // ScanOp 6540 //===----------------------------------------------------------------------===// 6541 6542 LogicalResult ScanOp::verify() { 6543 VectorType srcType = getSourceType(); 6544 VectorType initialType = getInitialValueType(); 6545 // Check reduction dimension < rank. 6546 int64_t srcRank = srcType.getRank(); 6547 int64_t reductionDim = getReductionDim(); 6548 if (reductionDim >= srcRank) 6549 return emitOpError("reduction dimension ") 6550 << reductionDim << " has to be less than " << srcRank; 6551 6552 // Check that rank(initial_value) = rank(src) - 1. 6553 int64_t initialValueRank = initialType.getRank(); 6554 if (initialValueRank != srcRank - 1) 6555 return emitOpError("initial value rank ") 6556 << initialValueRank << " has to be equal to " << srcRank - 1; 6557 6558 // Check shapes of initial value and src. 6559 ArrayRef<int64_t> srcShape = srcType.getShape(); 6560 ArrayRef<int64_t> initialValueShapes = initialType.getShape(); 6561 SmallVector<int64_t> expectedShape; 6562 for (int i = 0; i < srcRank; i++) { 6563 if (i != reductionDim) 6564 expectedShape.push_back(srcShape[i]); 6565 } 6566 if (!llvm::equal(initialValueShapes, expectedShape)) { 6567 return emitOpError("incompatible input/initial value shapes"); 6568 } 6569 6570 // Verify supported reduction kind. 6571 Type eltType = getDestType().getElementType(); 6572 if (!isSupportedCombiningKind(getKind(), eltType)) 6573 return emitOpError("unsupported reduction type ") 6574 << eltType << " for kind '" << stringifyCombiningKind(getKind()) 6575 << "'"; 6576 6577 return success(); 6578 } 6579 6580 void mlir::vector::populateVectorToVectorCanonicalizationPatterns( 6581 RewritePatternSet &patterns, PatternBenefit benefit) { 6582 patterns 6583 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder, 6584 ScatterFolder, ExpandLoadFolder, CompressStoreFolder, 6585 StridedSliceConstantMaskFolder, TransposeFolder>( 6586 patterns.getContext(), benefit); 6587 } 6588 6589 //===----------------------------------------------------------------------===// 6590 // SplatOp 6591 //===----------------------------------------------------------------------===// 6592 6593 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { 6594 auto constOperand = adaptor.getInput(); 6595 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand)) 6596 return {}; 6597 6598 // SplatElementsAttr::get treats single value for second arg as being a splat. 
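  // E.g. splatting an f32 constant 1.0 to vector<4xf32> folds to the splat
  // attribute dense<1.0> : vector<4xf32>.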
6599 return SplatElementsAttr::get(getType(), {constOperand}); 6600 } 6601 6602 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 6603 SetIntRangeFn setResultRanges) { 6604 setResultRanges(getResult(), argRanges.front()); 6605 } 6606 6607 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, 6608 CombiningKind kind, Value v1, Value acc, 6609 arith::FastMathFlagsAttr fastmath, 6610 Value mask) { 6611 Type t1 = getElementTypeOrSelf(v1.getType()); 6612 Type tAcc = getElementTypeOrSelf(acc.getType()); 6613 Value result; 6614 6615 switch (kind) { 6616 case CombiningKind::ADD: 6617 if (t1.isIntOrIndex() && tAcc.isIntOrIndex()) 6618 result = b.createOrFold<arith::AddIOp>(loc, v1, acc); 6619 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc)) 6620 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath); 6621 else 6622 llvm_unreachable("invalid value types for ADD reduction"); 6623 break; 6624 case CombiningKind::AND: 6625 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); 6626 result = b.createOrFold<arith::AndIOp>(loc, v1, acc); 6627 break; 6628 case CombiningKind::MAXNUMF: 6629 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) && 6630 "expected float values"); 6631 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath); 6632 break; 6633 case CombiningKind::MAXIMUMF: 6634 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) && 6635 "expected float values"); 6636 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath); 6637 break; 6638 case CombiningKind::MINNUMF: 6639 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) && 6640 "expected float values"); 6641 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath); 6642 break; 6643 case CombiningKind::MINIMUMF: 6644 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) && 6645 "expected float values"); 6646 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath); 6647 break; 6648 case CombiningKind::MAXSI: 6649 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); 6650 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc); 6651 break; 6652 case CombiningKind::MINSI: 6653 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); 6654 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc); 6655 break; 6656 case CombiningKind::MAXUI: 6657 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); 6658 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc); 6659 break; 6660 case CombiningKind::MINUI: 6661 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); 6662 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc); 6663 break; 6664 case CombiningKind::MUL: 6665 if (t1.isIntOrIndex() && tAcc.isIntOrIndex()) 6666 result = b.createOrFold<arith::MulIOp>(loc, v1, acc); 6667 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc)) 6668 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath); 6669 else 6670 llvm_unreachable("invalid value types for MUL reduction"); 6671 break; 6672 case CombiningKind::OR: 6673 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); 6674 result = b.createOrFold<arith::OrIOp>(loc, v1, acc); 6675 break; 6676 case CombiningKind::XOR: 6677 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); 6678 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc); 6679 break; 6680 }; 6681 6682 assert(result && "unknown CombiningKind"); 6683 return 
  selectPassthru(b, mask, result, acc);
}

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//

/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Create a block and move the op to that block.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns
/// the maskable operation itself.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return builder.create<MaskOp>(maskableOp->getLoc(),
                                  maskableOp->getResultTypes(), mask, passthru,
                                  maskableOp, createMaskOpRegion);
  return builder.create<MaskOp>(maskableOp->getLoc(),
                                maskableOp->getResultTypes(), mask, maskableOp,
                                createMaskOpRegion);
}

/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is
/// used to propagate the pass-thru value of vector.mask or for cases where
/// only the pass-thru value propagation is needed. VP intrinsics do not
/// support pass-thru values and every masked-out lane is set to poison. LLVM
/// backends are usually able to match op + select patterns and fold them into
/// native target instructions.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
                                         mask, newValue, passthru);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"