1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites as 1->N patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 14 15 #include <cassert> 16 #include <cstdint> 17 #include <functional> 18 #include <optional> 19 #include <type_traits> 20 21 #include "mlir/Dialect/Affine/IR/AffineOps.h" 22 #include "mlir/Dialect/Arith/IR/Arith.h" 23 #include "mlir/Dialect/Arith/Utils/Utils.h" 24 #include "mlir/Dialect/Linalg/IR/Linalg.h" 25 #include "mlir/Dialect/MemRef/IR/MemRef.h" 26 #include "mlir/Dialect/SCF/IR/SCF.h" 27 #include "mlir/Dialect/Tensor/IR/Tensor.h" 28 #include "mlir/Dialect/Utils/IndexingUtils.h" 29 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 30 #include "mlir/Dialect/Vector/IR/VectorOps.h" 31 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 32 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 33 #include "mlir/IR/BuiltinAttributeInterfaces.h" 34 #include "mlir/IR/BuiltinTypes.h" 35 #include "mlir/IR/ImplicitLocOpBuilder.h" 36 #include "mlir/IR/Location.h" 37 #include "mlir/IR/Matchers.h" 38 #include "mlir/IR/PatternMatch.h" 39 #include "mlir/IR/TypeUtilities.h" 40 #include "mlir/Interfaces/VectorInterfaces.h" 41 42 #include "llvm/ADT/DenseSet.h" 43 #include "llvm/ADT/MapVector.h" 44 #include "llvm/ADT/STLExtras.h" 45 #include "llvm/Support/CommandLine.h" 46 #include "llvm/Support/Debug.h" 47 #include "llvm/Support/FormatVariadic.h" 48 #include "llvm/Support/raw_ostream.h" 49 50 #define DEBUG_TYPE "vector-to-vector" 51 52 using namespace mlir; 53 using namespace mlir::vector; 54 55 template <typename IntType> 56 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) { 57 return llvm::to_vector<4>(llvm::map_range( 58 arrayAttr.getAsRange<IntegerAttr>(), 59 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); 60 } 61 62 // Helper to find an index in an affine map. 63 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { 64 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 65 int64_t idx = map.getDimPosition(i); 66 if (idx == index) 67 return i; 68 } 69 return std::nullopt; 70 } 71 72 namespace { 73 74 /// ShapeCastOpFolder folds cancelling ShapeCastOps away. 75 // 76 // Example: 77 // 78 // The following MLIR with cancelling ShapeCastOps: 79 // 80 // %0 = source : vector<5x4x2xf32> 81 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32> 82 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32> 83 // %3 = user %2 : vector<5x4x2xf32> 84 // 85 // Should canonicalize to the following: 86 // 87 // %0 = source : vector<5x4x2xf32> 88 // %1 = user %0 : vector<5x4x2xf32> 89 // 90 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> { 91 using OpRewritePattern::OpRewritePattern; 92 93 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, 94 PatternRewriter &rewriter) const override { 95 // Check if 'shapeCastOp' has vector source/result type. 
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
/// %1 = vector.multi_reduction add, %0 [2]
///   : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///      iterator_types = ["parallel", "parallel", "reduction"],
///      kind = add} %arg0, %arg1, %cst_f0
///      : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};

/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
177 /// Ex: 178 /// ``` 179 /// %0 = vector.transpose %arg0, [2, 0, 1] 180 /// : vector<32x16x8xf32> to vector<8x32x16xf32> 181 /// %1 = vector.contract {indexing_maps = [ 182 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 183 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 184 /// affine_map<(d0, d1, d2) -> (d0, d1)>], 185 /// iterator_types = ["parallel", "parallel", "reduction"], 186 /// kind = add} %0, %arg1, %cst_f0 187 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> 188 /// ``` 189 /// Gets converted to: 190 /// ``` 191 /// %1 = vector.contract {indexing_maps = [ 192 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>, 193 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 194 /// affine_map<(d0, d1, d2) -> (d0, d1)>], 195 /// iterator_types = ["parallel", "parallel", "reduction"], 196 /// kind = add} %arg0, %arg1, %cst_f0 197 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> 198 /// ``` 199 struct CombineContractABTranspose final 200 : public OpRewritePattern<vector::ContractionOp> { 201 using OpRewritePattern::OpRewritePattern; 202 203 LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 204 PatternRewriter &rewriter) const override { 205 SmallVector<AffineMap> maps = 206 llvm::to_vector<4>(contractOp.getIndexingMapsArray()); 207 Value lhs = contractOp.getLhs(); 208 Value rhs = contractOp.getRhs(); 209 size_t index = 0; 210 bool changed = false; 211 for (Value *operand : {&lhs, &rhs}) { 212 AffineMap &map = maps[index++]; 213 auto transposeOp = operand->getDefiningOp<vector::TransposeOp>(); 214 if (!transposeOp) 215 continue; 216 AffineMap permutationMap = AffineMap::getPermutationMap( 217 transposeOp.getPermutation(), contractOp.getContext()); 218 map = inversePermutation(permutationMap).compose(map); 219 *operand = transposeOp.getVector(); 220 changed = true; 221 } 222 if (!changed) 223 return failure(); 224 rewriter.replaceOpWithNewOp<vector::ContractionOp>( 225 contractOp, lhs, rhs, contractOp.getAcc(), 226 rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes()); 227 return success(); 228 } 229 }; 230 231 /// Merges accumulator and result transposes into contract. 
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4xf32>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();

    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
    // To index into A in contract, we need revert(f)(g(C)) -> A.
    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);

    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
    // To index into E in contract, we need h(g(C)) -> E.
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    auto combinedResMap = resTMap.compose(contractMap);

    // The accumulator and the result share the same indexing map, so the two
    // transposes can only be merged if combinedResMap equals
    // inversePermutation(accTMap).compose(contractMap), i.e. if resTMap is the
    // inverse of accTMap.
    if (inversePermutation(accTMap) != resTMap)
      return failure();
    maps.back() = combinedResMap;

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///      iterator_types = ["parallel", "parallel", "reduction"],
///      kind = add} %0, %arg1, %cst_f0
///      : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///      iterator_types = ["parallel", "parallel", "reduction"],
///      kind = add} %arg0, %arg1, %cst_f0
///      : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as its operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                         originalDims, contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }

    if (!changed)
      return failure();

    // Determine which dims are unused, now that the maps have been composed
    // with the broadcast maps.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    // Compress unused dims.
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    // Compute the combined iterators.
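    // E.g., if d0 was used only by an operand whose broadcast was folded away
    // above, d0 is now unused: the maps have already been compressed
    // accordingly, and its iterator type is dropped in the loop below.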
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Check that compressing unused dims isn't removing all reduction
    // dimension pairs. For example, if the vector.contract had only one
    // reduction iterator and that was a unit-dimension created by a broadcast,
    // then we should bail here, otherwise we would create a contract without
    // a reduction dimension pair.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
        hasReductionIteratorApplyingOnBothSides = true;
        break;
      }
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    // If the compressed maps have a dimension that is not used by either LHS
    // or RHS then the ContractionOp verifier would fail.
    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in the CombineContractBroadcast pattern
/// when casting ops sit between these operations.
/// Ex:
/// ```
/// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
/// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in the
/// CombineContractABTranspose pattern when elementwise ops are between these
/// operations.
Ex: 474 /// ``` 475 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> 476 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> 477 /// %r = arith.addf %at, %bt : vector<2x4xf32> 478 /// ``` 479 /// Gets converted to: 480 /// ``` 481 /// %0 = arith.addf %a, %b : vector<4x2xf32> 482 /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32> 483 /// ``` 484 struct ReorderElementwiseOpsOnTranspose final 485 : public OpTraitRewritePattern<OpTrait::Elementwise> { 486 using OpTraitRewritePattern::OpTraitRewritePattern; 487 LogicalResult matchAndRewrite(Operation *op, 488 PatternRewriter &rewriter) const override { 489 if (op->getNumResults() != 1 || op->getNumRegions() != 0) 490 return failure(); 491 492 // Make sure all operands are transpose/constant ops and collect their 493 // transposition maps. 494 SmallVector<ArrayRef<int64_t>> transposeMaps; 495 transposeMaps.reserve(op->getNumOperands()); 496 // Record the initial type before transposition. We'll use its shape later. 497 // Any type will do here as we will check all transpose maps are the same. 498 VectorType srcType; 499 for (Value operand : op->getOperands()) { 500 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); 501 if (transposeOp) { 502 transposeMaps.push_back(transposeOp.getPermutation()); 503 srcType = transposeOp.getSourceVectorType(); 504 } else if (!matchPattern(operand, m_Constant())) { 505 return failure(); 506 } 507 } 508 if (transposeMaps.empty()) 509 return failure(); 510 // This is an elementwise op, so all transposed operands should have the 511 // same type. We need to additionally check that all transposes uses the 512 // same map. 513 if (!llvm::all_equal(transposeMaps)) 514 return rewriter.notifyMatchFailure(op, "different transpose map"); 515 516 SmallVector<Value> srcValues; 517 srcValues.reserve(op->getNumOperands()); 518 519 // If there are constant operands, we need to insert inverse transposes for 520 // them. Calculate the inverse order first. 521 auto order = transposeMaps.front(); 522 SmallVector<int64_t> invOrder(order.size()); 523 for (int i = 0, e = order.size(); i < e; ++i) 524 invOrder[order[i]] = i; 525 526 for (Value operand : op->getOperands()) { 527 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); 528 if (transposeOp) { 529 srcValues.push_back(transposeOp.getVector()); 530 } else { 531 // This is a constant. Create a reverse transpose op for it. 532 auto vectorType = 533 srcType.clone(cast<VectorType>(operand.getType()).getElementType()); 534 srcValues.push_back(rewriter.create<vector::TransposeOp>( 535 operand.getLoc(), vectorType, operand, invOrder)); 536 } 537 } 538 539 auto vectorType = srcType.clone( 540 cast<VectorType>(op->getResultTypes()[0]).getElementType()); 541 Operation *elementwiseOp = 542 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, 543 vectorType, op->getAttrs()); 544 rewriter.replaceOpWithNewOp<vector::TransposeOp>( 545 op, op->getResultTypes()[0], elementwiseOp->getResult(0), 546 transposeMaps.front()); 547 return success(); 548 } 549 }; 550 551 // Returns the values in `arrayAttr` as an integer vector. 552 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { 553 return llvm::to_vector<4>( 554 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(), 555 [](IntegerAttr attr) { return attr.getInt(); })); 556 } 557 558 // Shuffles vector.bitcast op after vector.extract op. 
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : f32 from vector<4xf32>
//   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Get the first element of the mixed position as an integer.
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %src {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given that we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert op.
712 // 713 // This transforms IR like: 714 // %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4> 715 // %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8> 716 // Into: 717 // %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8> 718 // %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8> 719 // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> 720 // 721 struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { 722 using OpRewritePattern::OpRewritePattern; 723 724 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 725 PatternRewriter &rewriter) const override { 726 VectorType castSrcType = bitcastOp.getSourceVectorType(); 727 VectorType castDstType = bitcastOp.getResultVectorType(); 728 729 // 0-D and scalable vectors are not supported yet. 730 if (castSrcType.getRank() == 0 || castSrcType.isScalable() || 731 castDstType.isScalable()) 732 return failure(); 733 734 int64_t castSrcLastDim = castSrcType.getShape().back(); 735 int64_t castDstLastDim = castDstType.getShape().back(); 736 bool isNumElemsShrink = castSrcLastDim >= castDstLastDim; 737 int64_t ratio; 738 if (isNumElemsShrink) { 739 assert(castSrcLastDim % castDstLastDim == 0); 740 ratio = castSrcLastDim / castDstLastDim; 741 } else { 742 assert(castDstLastDim % castSrcLastDim == 0); 743 ratio = castDstLastDim / castSrcLastDim; 744 } 745 746 auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>(); 747 if (!insertOp) 748 return failure(); 749 750 // Only vector sources are supported for now. 751 auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType()); 752 if (!insertSrcType) 753 return failure(); 754 755 // Bitcast the source. 756 SmallVector<int64_t> srcDims(insertSrcType.getShape()); 757 srcDims.back() = 758 isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio; 759 VectorType newCastSrcType = 760 VectorType::get(srcDims, castDstType.getElementType()); 761 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 762 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); 763 764 SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape()); 765 dstDims.back() = 766 isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio; 767 VectorType newCastDstType = 768 VectorType::get(dstDims, castDstType.getElementType()); 769 770 // Bitcast the destination. 771 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 772 bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); 773 774 // Generate new insert. 775 rewriter.replaceOpWithNewOp<vector::InsertOp>( 776 bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition()); 777 return success(); 778 } 779 }; 780 781 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from an InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert op to have the same rank for the source and
    // destination vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Require that the shape of the insert op source is castable to dstType.
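    // E.g., for an i4 source element type bitcast to i8, two i4 elements pack
    // into one i8 element (numElements == 2 below), so the element count of
    // the insert source must be a multiple of 2.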
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());

    return success();
  }
};

// Breaks down vector.bitcast op.
//
// This transforms IR like:
//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
//   %cst = vector.splat %c0_f32 : vector<4xf32>
//   %1 = vector.extract_strided_slice %0 {
//          offsets = [0], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
//   %4 = vector.insert_strided_slice %2, %cst {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//   %5 = vector.extract_strided_slice %0 {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
//   %7 = vector.insert_strided_slice %6, %4 {
//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {

    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on top of
    // vector.{extract|insert}_strided_slice, which do not support
    // extracting/inserting "scalable sub-vectors". Bail out.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support the rank-1 case for now.
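    // E.g., a rank-1 vector<8xf16> -> vector<4xf32> bitcast is broken down as
    // shown in the example above; higher-rank bitcasts are rejected here.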
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
/// %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
/// %r = arith.addi %arg1, %arg2 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
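/// Note that the rewrite only applies when every operand is broadcast (or
/// splat) from the same source type; otherwise the reordering would not be
/// safe (see the checks in the pattern below).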
978 struct ReorderElementwiseOpsOnBroadcast final 979 : public OpTraitRewritePattern<OpTrait::Elementwise> { 980 using OpTraitRewritePattern::OpTraitRewritePattern; 981 LogicalResult matchAndRewrite(Operation *op, 982 PatternRewriter &rewriter) const override { 983 if (op->getNumResults() != 1) 984 return failure(); 985 if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) 986 return failure(); 987 if (!OpTrait::hasElementwiseMappableTraits(op)) 988 return rewriter.notifyMatchFailure( 989 op, "Op doesn't have ElementwiseMappableTraits"); 990 if (op->getNumOperands() == 0) 991 return failure(); 992 if (op->getResults()[0].getType() != op->getOperand(0).getType()) 993 return rewriter.notifyMatchFailure(op, 994 "result and operand type mismatch"); 995 if (isa<vector::FMAOp>(op)) { 996 return rewriter.notifyMatchFailure( 997 op, 998 "Op only accepts vector types - not supported as broadcast source " 999 "might be a scalar"); 1000 } 1001 1002 // Get the type of the lhs operand 1003 auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); 1004 if (!lhsBcastOrSplat || 1005 !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) 1006 return failure(); 1007 auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); 1008 1009 // Make sure that all operands are broadcast from identical types: 1010 // * scalar (`vector.broadcast` + `vector.splat`), or 1011 // * vector (`vector.broadcast`). 1012 // Otherwise the re-ordering wouldn't be safe. 1013 if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { 1014 auto bcast = val.getDefiningOp<vector::BroadcastOp>(); 1015 if (bcast) 1016 return (bcast.getOperand().getType() == lhsBcastOrSplatType); 1017 auto splat = val.getDefiningOp<vector::SplatOp>(); 1018 if (splat) 1019 return (splat.getOperand().getType() == lhsBcastOrSplatType); 1020 return false; 1021 })) { 1022 return failure(); 1023 } 1024 1025 // Collect the source values before broadcasting 1026 SmallVector<Value> srcValues; 1027 srcValues.reserve(op->getNumOperands()); 1028 for (Value operand : op->getOperands()) { 1029 srcValues.push_back(operand.getDefiningOp()->getOperand(0)); 1030 } 1031 1032 // Create the "elementwise" Op 1033 Operation *elementwiseOp = 1034 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, 1035 lhsBcastOrSplatType, op->getAttrs()); 1036 1037 // Replace the original Op with the elementwise Op 1038 auto vectorType = op->getResultTypes()[0]; 1039 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1040 op, vectorType, elementwiseOp->getResults()); 1041 1042 return success(); 1043 } 1044 }; 1045 1046 // Helper that returns a vector comparison that constructs a mask: 1047 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 1048 // 1049 // If `dim == 0` then the result will be a 0-D vector. 1050 // 1051 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 1052 // much more compact, IR for this operation, but LLVM eventually 1053 // generates more elaborate instructions for this intrinsic since it 1054 // is very conservative on the boundary conditions. 1055 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 1056 bool force32BitVectorIndices, int64_t dim, 1057 Value b, Value *off = nullptr) { 1058 auto loc = op->getLoc(); 1059 // If we can assume all indices fit in 32-bit, we perform the vector 1060 // comparison in 32-bit to get a higher degree of SIMD parallelism. 1061 // Otherwise we perform the vector comparison using 64-bit indices. 
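  // E.g., for dim = 4 with 32-bit indices and an offset `off`, the generated
  // IR is, conceptually:
  //   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
  //   %ov  = vector.splat %off : vector<4xi32>
  //   %sum = arith.addi %ov, %idx : vector<4xi32>
  //   %bnd = vector.splat %b : vector<4xi32>
  //   %m   = arith.cmpi slt, %sum, %bnd : vector<4xi32>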
1062 Type idxType = 1063 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1064 DenseIntElementsAttr indicesAttr; 1065 if (dim == 0 && force32BitVectorIndices) { 1066 indicesAttr = DenseIntElementsAttr::get( 1067 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 1068 } else if (dim == 0) { 1069 indicesAttr = DenseIntElementsAttr::get( 1070 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 1071 } else if (force32BitVectorIndices) { 1072 indicesAttr = rewriter.getI32VectorAttr( 1073 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 1074 } else { 1075 indicesAttr = rewriter.getI64VectorAttr( 1076 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 1077 } 1078 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 1079 // Add in an offset if requested. 1080 if (off) { 1081 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 1082 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 1083 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 1084 } 1085 // Construct the vector comparison. 1086 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 1087 Value bounds = 1088 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 1089 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 1090 bounds); 1091 } 1092 1093 template <typename ConcreteOp> 1094 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 1095 public: 1096 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt, 1097 PatternBenefit benefit = 1) 1098 : mlir::OpRewritePattern<ConcreteOp>(context, benefit), 1099 force32BitVectorIndices(enableIndexOpt) {} 1100 1101 LogicalResult matchAndRewrite(ConcreteOp xferOp, 1102 PatternRewriter &rewriter) const override { 1103 if (!xferOp.hasOutOfBoundsDim()) 1104 return failure(); 1105 1106 if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty()) 1107 return failure(); 1108 1109 Location loc = xferOp->getLoc(); 1110 VectorType vtp = xferOp.getVectorType(); 1111 1112 // Create the in-bounds mask with all elements between [0 .. dim - offset) 1113 // set and [dim - offset .. vector_length) unset. 1114 // 1115 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1116 // dimensions here. 1117 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; 1118 Value off = xferOp.getIndices()[lastIndex]; 1119 Value dim = 1120 vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); 1121 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 1122 Value mask = rewriter.create<vector::CreateMaskOp>( 1123 loc, 1124 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 1125 vtp.getScalableDims()), 1126 b); 1127 if (xferOp.getMask()) { 1128 // Intersect the in-bounds with the mask specified as an op parameter. 1129 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask()); 1130 } 1131 1132 rewriter.modifyOpInPlace(xferOp, [&]() { 1133 xferOp.getMaskMutable().assign(mask); 1134 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); 1135 }); 1136 1137 return success(); 1138 } 1139 1140 private: 1141 const bool force32BitVectorIndices; 1142 }; 1143 1144 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 
1145 class VectorCreateMaskOpConversion 1146 : public OpRewritePattern<vector::CreateMaskOp> { 1147 public: 1148 explicit VectorCreateMaskOpConversion(MLIRContext *context, 1149 bool enableIndexOpt, 1150 PatternBenefit benefit = 1) 1151 : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit), 1152 force32BitVectorIndices(enableIndexOpt) {} 1153 1154 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1155 PatternRewriter &rewriter) const override { 1156 auto dstType = op.getType(); 1157 if (cast<VectorType>(dstType).isScalable()) 1158 return failure(); 1159 int64_t rank = dstType.getRank(); 1160 if (rank > 1) 1161 return failure(); 1162 rewriter.replaceOp( 1163 op, buildVectorComparison(rewriter, op, force32BitVectorIndices, 1164 rank == 0 ? 0 : dstType.getDimSize(0), 1165 op.getOperand(0))); 1166 return success(); 1167 } 1168 1169 private: 1170 const bool force32BitVectorIndices; 1171 }; 1172 1173 /// Returns true if all the `i1` elements of `constantOp` are set to `value`. 1174 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { 1175 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue()); 1176 // TODO: Support non-dense constant. 1177 if (!denseAttr) 1178 return false; 1179 1180 assert(denseAttr.getElementType().isInteger(1) && "Unexpected type"); 1181 return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value; 1182 } 1183 1184 /// Folds a select operation between an all-true and all-false vector. For now, 1185 /// only single element vectors (i.e., vector<1xi1>) are supported. That is: 1186 /// 1187 /// %true = arith.constant dense<true> : vector<1xi1> 1188 /// %false = arith.constant dense<false> : vector<1xi1> 1189 /// %result = arith.select %cond, %true, %false : i1, vector<1xi1> 1190 /// => 1191 /// %result = vector.broadcast %cond : i1 to vector<1xi1> 1192 /// 1193 /// InstCombine seems to handle vectors with multiple elements but not the 1194 /// single element ones. 1195 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { 1196 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1197 1198 LogicalResult matchAndRewrite(arith::SelectOp selectOp, 1199 PatternRewriter &rewriter) const override { 1200 auto vecType = dyn_cast<VectorType>(selectOp.getType()); 1201 if (!vecType || !vecType.getElementType().isInteger(1)) 1202 return failure(); 1203 1204 // Only scalar conditions can be folded. 1205 Value cond = selectOp.getCondition(); 1206 if (isa<VectorType>(cond.getType())) 1207 return failure(); 1208 1209 // TODO: Support n-D and scalable vectors. 1210 if (vecType.getRank() != 1 || vecType.isScalable()) 1211 return failure(); 1212 1213 // TODO: Support vectors with multiple elements. 1214 if (vecType.getShape()[0] != 1) 1215 return failure(); 1216 1217 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>(); 1218 if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true)) 1219 return failure(); 1220 1221 auto falseConst = 1222 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>(); 1223 if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false)) 1224 return failure(); 1225 1226 // Replace select with its condition broadcasted to single element vector. 
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
    return success();
  }
};

/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
///
/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x1xf32>
///       (the two innermost dims can be dropped by memref.subview ops)
///
/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
///       [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
///       (only the innermost unit dim of `srcType` can be dropped)
///
/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x[1]xf32>
///       (the innermost dim in `vectorType` is not a unit dim; it is a
///       "scalable unit" dim)
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // According to vector.transfer_read/write semantics, the vector can be a
  // slice. Thus, we have to offset the check index with `rankDiff` in
  // `srcStrides` and source dim sizes.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // Check that the inner dim size is 1 for both the memref type and the
    // vector slice. It can be folded only if they are 1 and the stride is 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}

/// Drop innermost contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
1289 if (readOp.getMask()) 1290 return failure(); 1291 1292 auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType()); 1293 if (!srcType) 1294 return failure(); 1295 1296 if (!readOp.getPermutationMap().isMinorIdentity()) 1297 return failure(); 1298 1299 auto targetType = readOp.getVectorType(); 1300 if (targetType.getRank() <= 1) 1301 return failure(); 1302 1303 FailureOr<size_t> maybeDimsToDrop = 1304 getTransferFoldableInnerUnitDims(srcType, targetType); 1305 if (failed(maybeDimsToDrop)) 1306 return failure(); 1307 1308 size_t dimsToDrop = maybeDimsToDrop.value(); 1309 if (dimsToDrop == 0) 1310 return failure(); 1311 1312 auto inBounds = readOp.getInBoundsValues(); 1313 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop); 1314 if (llvm::is_contained(droppedInBounds, false)) 1315 return failure(); 1316 1317 auto resultTargetVecType = 1318 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 1319 targetType.getElementType(), 1320 targetType.getScalableDims().drop_back(dimsToDrop)); 1321 1322 auto loc = readOp.getLoc(); 1323 SmallVector<OpFoldResult> sizes = 1324 memref::getMixedSizes(rewriter, loc, readOp.getSource()); 1325 SmallVector<OpFoldResult> offsets(srcType.getRank(), 1326 rewriter.getIndexAttr(0)); 1327 SmallVector<OpFoldResult> strides(srcType.getRank(), 1328 rewriter.getIndexAttr(1)); 1329 auto resultMemrefType = 1330 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( 1331 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, 1332 strides)); 1333 ArrayAttr inBoundsAttr = rewriter.getArrayAttr( 1334 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); 1335 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 1336 loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides); 1337 auto permMap = getTransferMinorIdentityMap( 1338 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); 1339 Value result = rewriter.create<vector::TransferReadOp>( 1340 loc, resultTargetVecType, rankedReducedView, 1341 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 1342 readOp.getPadding(), 1343 // TODO: support mask. 1344 /*mask=*/Value(), inBoundsAttr); 1345 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 1346 result); 1347 return success(); 1348 } 1349 }; 1350 1351 /// Drop inner most contiguous unit dimensions from transfer_write operand. 1352 /// E.g., 1353 /// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] 1354 /// {in_bounds = [true, true, true, true, true]} 1355 /// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32> 1356 /// 1357 /// will be replaced with 1358 /// 1359 /// %subview = memref.subview %arg0 1360 /// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1] 1361 /// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32> 1362 /// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32> 1363 /// to vector<1x16x16xf32> 1364 /// vector.transfer_write %0, %subview[%c0, %arg2, %c0] 1365 /// {in_bounds = [true, true, true]} 1366 /// : vector<1x16x16xf32>, memref<1x512x16xf32> 1367 /// 1368 /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). 1369 class DropInnerMostUnitDimsTransferWrite 1370 : public OpRewritePattern<vector::TransferWriteOp> { 1371 using OpRewritePattern::OpRewritePattern; 1372 1373 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 1374 PatternRewriter &rewriter) const override { 1375 // TODO: support 0-d corner case. 
1376 if (writeOp.getTransferRank() == 0) 1377 return failure(); 1378 1379 // TODO: support mask. 1380 if (writeOp.getMask()) 1381 return failure(); 1382 1383 auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType()); 1384 if (!srcType) 1385 return failure(); 1386 1387 if (!writeOp.getPermutationMap().isMinorIdentity()) 1388 return failure(); 1389 1390 auto targetType = writeOp.getVectorType(); 1391 if (targetType.getRank() <= 1) 1392 return failure(); 1393 1394 FailureOr<size_t> maybeDimsToDrop = 1395 getTransferFoldableInnerUnitDims(srcType, targetType); 1396 if (failed(maybeDimsToDrop)) 1397 return failure(); 1398 1399 size_t dimsToDrop = maybeDimsToDrop.value(); 1400 if (dimsToDrop == 0) 1401 return failure(); 1402 1403 auto inBounds = writeOp.getInBoundsValues(); 1404 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop); 1405 if (llvm::is_contained(droppedInBounds, false)) 1406 return failure(); 1407 1408 auto resultTargetVecType = 1409 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 1410 targetType.getElementType(), 1411 targetType.getScalableDims().drop_back(dimsToDrop)); 1412 1413 Location loc = writeOp.getLoc(); 1414 SmallVector<OpFoldResult> sizes = 1415 memref::getMixedSizes(rewriter, loc, writeOp.getSource()); 1416 SmallVector<OpFoldResult> offsets(srcType.getRank(), 1417 rewriter.getIndexAttr(0)); 1418 SmallVector<OpFoldResult> strides(srcType.getRank(), 1419 rewriter.getIndexAttr(1)); 1420 auto resultMemrefType = 1421 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( 1422 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, 1423 strides)); 1424 ArrayAttr inBoundsAttr = rewriter.getArrayAttr( 1425 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); 1426 1427 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 1428 loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides); 1429 auto permMap = getTransferMinorIdentityMap( 1430 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); 1431 1432 auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>( 1433 loc, resultTargetVecType, writeOp.getVector()); 1434 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 1435 writeOp, shapeCast, rankedReducedView, 1436 writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 1437 // TODO: support mask. 1438 /*mask=*/Value(), inBoundsAttr); 1439 return success(); 1440 } 1441 }; 1442 1443 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul 1444 /// semantics to a contraction suitable for MMT (matrix matrix multiplication 1445 /// with the RHS transposed) lowering. 1446 struct CanonicalizeContractMatmulToMMT final 1447 : OpRewritePattern<vector::ContractionOp> { 1448 using OpRewritePattern::OpRewritePattern; 1449 1450 using FilterConstraintType = 1451 std::function<LogicalResult(vector::ContractionOp op)>; 1452 1453 CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit, 1454 FilterConstraintType constraint) 1455 : OpRewritePattern<vector::ContractionOp>(context, benefit), 1456 filter(std::move(constraint)) {} 1457 1458 LogicalResult matchAndRewrite(vector::ContractionOp op, 1459 PatternRewriter &rewriter) const override { 1460 if (failed(filter(op))) 1461 return failure(); 1462 1463 Location loc = op.getLoc(); 1464 Value lhs = op.getLhs(); 1465 Value rhs = op.getRhs(); 1466 Value res = op.getAcc(); 1467 1468 // Set up the parallel/reduction structure in right form. 
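    // The code below matches a gemm-shaped contraction (two parallel
    // iterators m, n and one reduction iterator k) and, inserting transposes
    // where needed, rewrites it into the canonical MMT form with the maps
    // (m, k), (n, k), (m, n).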
1469 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1470 auto infer = [&](MapList m) {
1471 return AffineMap::inferFromExprList(m, op.getContext());
1472 };
1473 AffineExpr m;
1474 AffineExpr n;
1475 AffineExpr k;
1476 bindDims(rewriter.getContext(), m, n, k);
1477 static constexpr std::array<int64_t, 2> perm = {1, 0};
1478 auto iteratorTypes = op.getIteratorTypes().getValue();
1479 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1480 if (iteratorTypes.size() != 3 ||
1481 !vector::isParallelIterator(iteratorTypes[0]) ||
1482 !vector::isParallelIterator(iteratorTypes[1]) ||
1483 !vector::isReductionIterator(iteratorTypes[2]))
1484 return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1485
1486 // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1487 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1488 if (maps == canonicalForm)
1489 return rewriter.notifyMatchFailure(op, "already in the canonical form");
1490
1491 // Create a vector transpose making sure to emit zero/sign-extend at the
1492 // end.
1493 auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1494 if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1495 Value trans =
1496 rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1497 VectorType newType =
1498 cast<VectorType>(trans.getType())
1499 .clone(cast<VectorType>(mat.getType()).getElementType());
1500 return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1501 }
1502 if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1503 Value trans =
1504 rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1505 VectorType newType =
1506 VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1507 cast<VectorType>(mat.getType()).getElementType());
1508 return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1509 }
1510 return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1511 };
1512
1513 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1514 rhs = createTranspose(rhs);
1515 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1516 lhs = createTranspose(lhs);
1517 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1518 rhs = createTranspose(rhs);
1519 lhs = createTranspose(lhs);
1520 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1521 std::swap(rhs, lhs);
1522 rhs = createTranspose(rhs);
1523 lhs = createTranspose(lhs);
1524 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1525 std::swap(rhs, lhs);
1526 rhs = createTranspose(rhs);
1527 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1528 std::swap(lhs, rhs);
1529 lhs = createTranspose(lhs);
1530 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1531 std::swap(lhs, rhs);
1532 } else {
1533 return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1534 }
1535 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1536 op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1537 op.getIteratorTypes());
1538 return success();
1539 };
1540
1541 private:
1542 FilterConstraintType filter;
1543 };
1544
1545 /// Pattern to fold arithmetic extensions on floating point data types into
1546 /// vector contraction operations. linalg.matmul introduces arithmetic
1547 /// extensions on its operands. See the MLIR snippet below for more details.
1548 /// ```mlir 1549 /// "linalg.matmul"(%lhs, %rhs, %acc) ({ 1550 /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32): 1551 /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32 1552 /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32 1553 /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32 1554 /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32 1555 /// "linalg.yield"(%acc) : (f32) -> () 1556 /// }) 1557 /// ``` 1558 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor 1559 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`. 1560 /// This pattern folds the arithmetic extensions into the vector contraction and 1561 /// enables the usage of native mixed precision Tensor Core instructions. 1562 template <typename ExtOp> 1563 struct FoldArithExtIntoContractionOp 1564 : public OpRewritePattern<vector::ContractionOp> { 1565 using OpRewritePattern::OpRewritePattern; 1566 1567 LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 1568 PatternRewriter &rewriter) const override { 1569 1570 auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>(); 1571 auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>(); 1572 1573 if (!lhsDefOp || !rhsDefOp) { 1574 return rewriter.notifyMatchFailure(contractOp, 1575 "no defining op on contract operands"); 1576 } 1577 1578 rewriter.replaceOpWithNewOp<vector::ContractionOp>( 1579 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0), 1580 contractOp.getAcc(), contractOp.getIndexingMapsAttr(), 1581 contractOp.getIteratorTypesAttr()); 1582 1583 return success(); 1584 } 1585 }; 1586 1587 /// Pattern to fold chained reduction to a series of vector additions and a 1588 /// final reduction. This form should require fewer subgroup operations. 1589 /// 1590 /// ```mlir 1591 /// %a = vector.reduction <add> %x, %acc 1592 /// %b = vector.reduction <add> %y, %a 1593 /// ==> 1594 /// %a = arith.addf %x, %y 1595 /// %b = vector.reduction <add> %a, %acc 1596 /// ``` 1597 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { 1598 using OpRewritePattern::OpRewritePattern; 1599 1600 LogicalResult matchAndRewrite(vector::ReductionOp op, 1601 PatternRewriter &rewriter) const override { 1602 // TODO: Handle other combining kinds. 1603 if (op.getKind() != vector::CombiningKind::ADD) 1604 return failure(); 1605 1606 // Accumulator is optional. 1607 Value acc = op.getAcc(); 1608 if (!acc) 1609 return failure(); 1610 1611 if (!acc.getType().isIntOrFloat()) 1612 return failure(); 1613 1614 auto parentReduction = acc.getDefiningOp<vector::ReductionOp>(); 1615 if (!parentReduction) 1616 return failure(); 1617 1618 Location loc = op.getLoc(); 1619 Value vAdd; 1620 if (isa<IntegerType>(acc.getType())) { 1621 vAdd = rewriter.createOrFold<arith::AddIOp>( 1622 loc, parentReduction.getVector(), op.getVector()); 1623 } else { 1624 vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(), 1625 op.getVector()); 1626 } 1627 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd, 1628 parentReduction.getAcc()); 1629 return success(); 1630 } 1631 }; 1632 1633 // Helper function dropping unit non-scalable dimension from a VectorType 1634 // keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit 1635 // dimensions are not dropped. Folding such dimensions would require "shifting" 1636 // the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> -> 1637 // vector<[4]xf32>). This could be implemented in the future. 
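// Illustrative examples of the folding performed below (types are made up):
//   vector<1x4x1xf32> -> vector<4xf32>
//   vector<1x[4]xf32> -> vector<[4]xf32>
//   vector<[1]x4xf32> -> vector<[1]x4xf32> (scalable unit dim is kept)
//   vector<1x1xf32>   -> vector<1xf32>     (never 0-D)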
1638 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) { 1639 auto inVecShape = inVecTy.getShape(); 1640 SmallVector<int64_t> newShape; 1641 SmallVector<bool> newScalableDims; 1642 for (auto [dim, isScalable] : 1643 llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) { 1644 if (dim == 1 && !isScalable) 1645 continue; 1646 1647 newShape.push_back(dim); 1648 newScalableDims.push_back(isScalable); 1649 } 1650 // All dims have been dropped, return vector<1xeType>. 1651 if (newShape.empty()) { 1652 newShape.push_back(1); 1653 newScalableDims.push_back(false); 1654 } 1655 1656 return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims); 1657 } 1658 1659 /// For vectors with at least one unit dim, replaces: 1660 /// elementwise(a, b) 1661 /// with: 1662 /// sc_a = shape_cast(a) 1663 /// sc_b = shape_cast(b) 1664 /// res = elementwise(sc_a, sc_b) 1665 /// return shape_cast(res) 1666 /// The newly inserted shape_cast Ops fold (before elementwise Op) and then 1667 /// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are 1668 /// required to be rank > 1. 1669 /// 1670 /// Ex: 1671 /// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32> 1672 /// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32> 1673 /// 1674 /// gets converted to: 1675 /// 1676 /// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32> 1677 /// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32> 1678 /// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32> 1679 /// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32> 1680 /// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32> 1681 /// 1682 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and 1683 /// `%cast`. 1684 struct DropUnitDimFromElementwiseOps final 1685 : public OpTraitRewritePattern<OpTrait::Elementwise> { 1686 using OpTraitRewritePattern::OpTraitRewritePattern; 1687 LogicalResult matchAndRewrite(Operation *op, 1688 PatternRewriter &rewriter) const override { 1689 if (op->getNumResults() != 1 || op->getNumRegions() != 0) 1690 return failure(); 1691 1692 auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType()); 1693 if (!resultVectorType) 1694 return failure(); 1695 1696 // Check the operand pre-conditions. For `Elementwise` ops all operands are 1697 // guaranteed to have identical shapes (with some exceptions such as 1698 // `arith.select`) and it suffices to only check one of them. 1699 auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType()); 1700 if (!sourceVectorType) 1701 return failure(); 1702 if (sourceVectorType.getRank() < 2) 1703 return failure(); 1704 1705 SmallVector<Value> newOperands; 1706 auto loc = op->getLoc(); 1707 for (auto operand : op->getOperands()) { 1708 auto opVectorType = cast<VectorType>(operand.getType()); 1709 auto newVType = dropNonScalableUnitDimFromType(opVectorType); 1710 if (newVType == opVectorType) 1711 return rewriter.notifyMatchFailure(op, "No unit dimension to remove."); 1712 1713 auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand); 1714 newOperands.push_back(opSC); 1715 } 1716 1717 VectorType newResultVectorType = 1718 dropNonScalableUnitDimFromType(resultVectorType); 1719 // Create an updated elementwise Op without unit dim. 
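    // Note: the replacement is created generically from the original op's
    // OperationName and attributes, so this single pattern covers any
    // `Elementwise` op (e.g. arith.addf, arith.mulf) with the unit dims
    // removed from its operand and result types.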
1720 Operation *elementwiseOp = 1721 rewriter.create(loc, op->getName().getIdentifier(), newOperands, 1722 newResultVectorType, op->getAttrs()); 1723 1724 // Restore the unit dim by applying vector.shape_cast to the result. 1725 rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType, 1726 elementwiseOp->getResult(0)); 1727 1728 return success(); 1729 } 1730 }; 1731 1732 /// A pattern to drop unit dims from vector.transpose. 1733 /// 1734 /// Example: 1735 /// 1736 /// BEFORE: 1737 /// ```mlir 1738 /// %transpose = vector.transpose %vector, [3, 0, 1, 2] 1739 /// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> 1740 /// ``` 1741 /// 1742 /// AFTER: 1743 /// ```mlir 1744 /// %dropDims = vector.shape_cast %vector 1745 /// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> 1746 /// %transpose = vector.transpose %0, [1, 0] 1747 /// : vector<4x[4]xf32> to vector<[4]x4xf32> 1748 /// %restoreDims = vector.shape_cast %transpose 1749 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> 1750 /// ``` 1751 struct DropUnitDimsFromTransposeOp final 1752 : OpRewritePattern<vector::TransposeOp> { 1753 using OpRewritePattern::OpRewritePattern; 1754 1755 LogicalResult matchAndRewrite(vector::TransposeOp op, 1756 PatternRewriter &rewriter) const override { 1757 VectorType sourceType = op.getSourceVectorType(); 1758 VectorType sourceTypeWithoutUnitDims = 1759 dropNonScalableUnitDimFromType(sourceType); 1760 1761 if (sourceType == sourceTypeWithoutUnitDims) 1762 return failure(); 1763 1764 // Construct a map from dimIdx -> number of dims dropped before dimIdx. 1765 auto sourceDims = llvm::to_vector(vector::getDims(sourceType)); 1766 SmallVector<int64_t> droppedDimsBefore(sourceType.getRank()); 1767 int64_t droppedDims = 0; 1768 for (auto [i, dim] : llvm::enumerate(sourceDims)) { 1769 droppedDimsBefore[i] = droppedDims; 1770 if (dim == std::make_tuple(1, false)) 1771 ++droppedDims; 1772 } 1773 1774 // Drop unit dims from transpose permutation. 1775 ArrayRef<int64_t> perm = op.getPermutation(); 1776 SmallVector<int64_t> newPerm; 1777 for (int64_t idx : perm) { 1778 if (sourceDims[idx] == std::make_tuple(1, false)) 1779 continue; 1780 newPerm.push_back(idx - droppedDimsBefore[idx]); 1781 } 1782 1783 // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT> 1784 // type when the dimensions are unit dimensions. In this case, the newPerm 1785 // should be [0]. 1786 if (newPerm.empty()) { 1787 newPerm.push_back(0); 1788 } 1789 1790 Location loc = op.getLoc(); 1791 // Drop the unit dims via shape_cast. 1792 auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>( 1793 loc, sourceTypeWithoutUnitDims, op.getVector()); 1794 // Create the new transpose. 1795 auto transposeWithoutUnitDims = 1796 rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm); 1797 // Restore the unit dims via shape cast. 1798 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( 1799 op, op.getResultVectorType(), transposeWithoutUnitDims); 1800 1801 return success(); 1802 } 1803 }; 1804 1805 /// A pattern to drop unit dims from the iter_args of an scf.for. 1806 /// 1807 /// Example: 1808 /// 1809 /// BEFORE: 1810 /// ```mlir 1811 /// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> { 1812 /// ... 1813 /// scf.yield % 1814 /// } 1815 /// ``` 1816 /// 1817 /// AFTER: 1818 /// ```mlir 1819 /// %drop = vector.shape_cast %init 1820 /// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32> 1821 /// %new_loop = scf.for ... 
iter_args(%iter = %drop) -> vector<[4]x4xf32> {
1822 ///   %new_iter = vector.shape_cast %iter
1823 ///     : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1824 ///   ...
1825 /// }
1826 /// %res = vector.shape_cast %new_loop
1827 ///   : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1828 /// ```
1829 struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
1830 using OpRewritePattern::OpRewritePattern;
1831
1832 LogicalResult matchAndRewrite(scf::ForOp forOp,
1833 PatternRewriter &rewriter) const override {
1834 // Find the first iter_arg with droppable unit dims. Later applications of
1835 // this pattern will handle the remaining iter_args.
1836 for (OpOperand &operand : forOp.getInitArgsMutable()) {
1837 auto vectorType = dyn_cast<VectorType>(operand.get().getType());
1838 if (!vectorType)
1839 continue;
1840
1841 VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
1842 if (vectorType == newVectorType)
1843 continue;
1844
1845 // Create a new ForOp with that iter operand replaced.
1846 auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
1847 return b.create<vector::ShapeCastOp>(loc, type, source);
1848 };
1849
1850 Value replacement =
1851 castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
1852 rewriter.replaceOp(forOp,
1853 replaceAndCastForOpIterArg(rewriter, forOp, operand,
1854 replacement, castFn));
1855 return success();
1856 }
1857 return failure();
1858 }
1859 };
1860
1861 /// Pattern to eliminate redundant zero-constants added to reduction operands.
1862 /// It's enough for there to be one initial zero value, so we can eliminate the
1863 /// extra ones that feed into `vector.reduction <add>`. These get created by the
1864 /// `ChainedReduction` pattern.
1865 ///
1866 /// ```mlir
1867 /// %a = arith.addf %x, %zero
1868 /// %b = arith.addf %a, %y
1869 /// %c = vector.reduction <add> %b, %acc
1870 /// ==>
1871 /// %b = arith.addf %x, %y
1872 /// %c = vector.reduction <add> %b, %acc
1873 /// ```
1874 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1875 using OpRewritePattern::OpRewritePattern;
1876
1877 LogicalResult matchAndRewrite(vector::ReductionOp op,
1878 PatternRewriter &rewriter) const override {
1879 // TODO: Handle other reduction kinds and their identity values.
1880 if (op.getKind() != vector::CombiningKind::ADD)
1881 return failure();
1882
1883 Type elemType = op.getSourceVectorType().getElementType();
1884 // The integer case should be handled by `arith.addi` folders; only check
1885 // for floats here.
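    // (For integers, e.g. `arith.addi %v, %zero` already folds to `%v`, so
    // there is nothing left for this pattern to do.)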
1886 if (!isa<FloatType>(elemType)) 1887 return failure(); 1888 1889 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>(); 1890 if (!vAdd) 1891 return failure(); 1892 auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>(); 1893 if (!addLhs) 1894 return failure(); 1895 1896 if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat())) 1897 return failure(); 1898 1899 auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(), 1900 vAdd.getRhs()); 1901 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd, 1902 op.getAcc()); 1903 return success(); 1904 } 1905 }; 1906 1907 /// Example: 1908 /// ``` 1909 /// %a = vector.reduction <add> %x : vector<2xf32> into f32 1910 /// ``` 1911 /// is transformed into: 1912 /// ``` 1913 /// %y = vector.extract %x[0] : f32 from vector<2xf32> 1914 /// %z = vector.extract %x[1] : f32 from vector<2xf32> 1915 /// %a = arith.addf %y, %z : f32 1916 /// ``` 1917 struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> { 1918 BreakDownVectorReduction(MLIRContext *context, 1919 unsigned maxNumElementsToExtract, 1920 PatternBenefit benefit) 1921 : OpRewritePattern(context, benefit), 1922 maxNumElementsToExtract(maxNumElementsToExtract) {} 1923 1924 LogicalResult matchAndRewrite(vector::ReductionOp op, 1925 PatternRewriter &rewriter) const override { 1926 VectorType type = op.getSourceVectorType(); 1927 if (type.isScalable() || op.isMasked()) 1928 return failure(); 1929 assert(type.getRank() == 1 && "Expected a 1-d vector"); 1930 1931 int64_t numElems = type.getNumElements(); 1932 if (numElems > maxNumElementsToExtract) { 1933 return rewriter.notifyMatchFailure( 1934 op, llvm::formatv("has too many vector elements ({0}) to break down " 1935 "(max allowed: {1})", 1936 numElems, maxNumElementsToExtract)); 1937 } 1938 1939 Location loc = op.getLoc(); 1940 SmallVector<Value> extracted(numElems, nullptr); 1941 for (auto [idx, extractedElem] : llvm::enumerate(extracted)) 1942 extractedElem = rewriter.create<vector::ExtractOp>( 1943 loc, op.getVector(), static_cast<int64_t>(idx)); 1944 1945 Value res = extracted.front(); 1946 for (auto extractedElem : llvm::drop_begin(extracted)) 1947 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, 1948 extractedElem, op.getFastmathAttr()); 1949 if (Value acc = op.getAcc()) 1950 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc, 1951 op.getFastmathAttr()); 1952 1953 rewriter.replaceOp(op, res); 1954 return success(); 1955 } 1956 1957 private: 1958 unsigned maxNumElementsToExtract = 0; 1959 }; 1960 1961 /// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, 1962 /// B)`. 1963 /// Example: 1964 /// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> 1965 /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to 1966 /// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to 1967 /// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> 1968 /// 1969 /// Becomes : 1970 /// 1971 /// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> 1972 /// 1973 /// Supports only 1D-to-2D broadcasts. The following cases are not supported. 
1974 /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1975 /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1976 /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1977 template <typename MulOpType>
1978 struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
1979 using OpRewritePattern<MulOpType>::OpRewritePattern;
1980 // Returns true if a vector.broadcast matches the requirements of the
1981 // outerproduct pattern: a 1D-to-2D broadcast without broadcasted unit dims.
1982 bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
1983 // Fail if the broadcast stretches unit dimensions, to avoid generating
1984 // shape_casts/broadcasts that do not belong in this pattern.
1985 if (!broadcastOp.computeBroadcastedUnitDims().empty())
1986 return false;
1987 // Avoid broadcasts like f32 or vector<f32> -> ResType.
1988 auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1989 return srcType && srcType.getRank() != 2;
1990 }
1991
1992 LogicalResult matchAndRewrite(MulOpType mulOp,
1993 PatternRewriter &rewriter) const override {
1994 auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
1995 if (!resType)
1996 return failure();
1997 if (resType.getRank() != 2)
1998 return failure();
1999 // If operandA can be written as tr(broadcast(A)) and operandB as
2000 // broadcast(B) where broadcasts are 1D-to-2D, create and return
2001 // vector.outerproduct(A, B). Returns failure() otherwise.
2002 auto matchOuterProduct =
2003 [&](Value operandA,
2004 Value operandB) -> FailureOr<vector::OuterProductOp> {
2005 auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2006 if (!transposedLhs)
2007 return failure();
2008 // Fail unless this is a true 2-D matrix transpose.
2009 ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2010 if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2011 return failure();
2012
2013 auto broadcastedLhs =
2014 transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2015 if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2016 return failure();
2017
2018 auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2019 if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2020 return failure();
2021
2022 return rewriter.create<vector::OuterProductOp>(
2023 mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2024 broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2025 };
2026
2027 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2028 auto maybeOuterP = matchOuterProduct(lhs, rhs);
2029 // Handle commutativity: the transposed operand becomes the outerproduct LHS.
2030 if (failed(maybeOuterP)) 2031 maybeOuterP = matchOuterProduct(rhs, lhs); 2032 if (failed(maybeOuterP)) 2033 return failure(); 2034 rewriter.replaceOp(mulOp, maybeOuterP->getResult()); 2035 return success(); 2036 } 2037 }; 2038 2039 } // namespace 2040 2041 void mlir::vector::populateFoldArithExtensionPatterns( 2042 RewritePatternSet &patterns) { 2043 patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>, 2044 FoldArithExtIntoContractionOp<arith::ExtSIOp>>( 2045 patterns.getContext()); 2046 } 2047 2048 void mlir::vector::populateVectorMaskMaterializationPatterns( 2049 RewritePatternSet &patterns, bool force32BitVectorIndices, 2050 PatternBenefit benefit) { 2051 patterns.add<VectorCreateMaskOpConversion, 2052 MaterializeTransferMask<vector::TransferReadOp>, 2053 MaterializeTransferMask<vector::TransferWriteOp>>( 2054 patterns.getContext(), force32BitVectorIndices, benefit); 2055 patterns.add<FoldI1Select>(patterns.getContext(), benefit); 2056 } 2057 2058 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns, 2059 PatternBenefit benefit) { 2060 patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit); 2061 } 2062 2063 void mlir::vector::populateDropUnitDimWithShapeCastPatterns( 2064 RewritePatternSet &patterns, PatternBenefit benefit) { 2065 // TODO: Consider either: 2066 // * including DropInnerMostUnitDimsTransferRead and 2067 // DropInnerMostUnitDimsTransferWrite, or 2068 // * better naming to distinguish this and 2069 // populateVectorTransferCollapseInnerMostContiguousDimsPatterns. 2070 patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp, 2071 DropUnitDimsFromTransposeOp, ShapeCastOpFolder>( 2072 patterns.getContext(), benefit); 2073 } 2074 2075 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2076 RewritePatternSet &patterns, PatternBenefit benefit) { 2077 patterns.add<BubbleDownVectorBitCastForExtract, 2078 BubbleDownBitCastForStridedSliceExtract, 2079 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>( 2080 patterns.getContext(), benefit); 2081 } 2082 2083 void mlir::vector::populateBreakDownVectorBitCastOpPatterns( 2084 RewritePatternSet &patterns, 2085 std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) { 2086 patterns.add<BreakDownVectorBitCast>(patterns.getContext(), 2087 std::move(controlFn), benefit); 2088 } 2089 2090 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT( 2091 RewritePatternSet &patterns, 2092 std::function<LogicalResult(vector::ContractionOp)> constraint, 2093 PatternBenefit benefit) { 2094 patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit, 2095 std::move(constraint)); 2096 } 2097 2098 void mlir::vector::populateVectorReductionToContractPatterns( 2099 RewritePatternSet &patterns, PatternBenefit benefit) { 2100 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2101 CombineContractABTranspose, CombineContractResultTranspose>( 2102 patterns.getContext(), benefit); 2103 } 2104 2105 void mlir::vector:: 2106 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2107 RewritePatternSet &patterns, PatternBenefit benefit) { 2108 patterns.add<DropInnerMostUnitDimsTransferRead, 2109 DropInnerMostUnitDimsTransferWrite>(patterns.getContext(), 2110 benefit); 2111 } 2112 2113 void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns, 2114 PatternBenefit benefit) { 2115 patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast, 2116 
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(), 2117 benefit); 2118 } 2119 2120 void mlir::vector::populateChainedVectorReductionFoldingPatterns( 2121 RewritePatternSet &patterns, PatternBenefit benefit) { 2122 patterns.add<ChainedReduction>(patterns.getContext(), benefit); 2123 patterns.add<ReduceRedundantZero>(patterns.getContext(), 2124 PatternBenefit(benefit.getBenefit() + 1)); 2125 } 2126 2127 void mlir::vector::populateBreakDownVectorReductionPatterns( 2128 RewritePatternSet &patterns, unsigned maxNumElementsToExtract, 2129 PatternBenefit benefit) { 2130 patterns.add<BreakDownVectorReduction>(patterns.getContext(), 2131 maxNumElementsToExtract, benefit); 2132 } 2133 2134 void mlir::vector::populateElementwiseToVectorOpsPatterns( 2135 RewritePatternSet &patterns) { 2136 patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>, 2137 FoldArithToVectorOuterProduct<arith::MulIOp>>( 2138 patterns.getContext()); 2139 } 2140 2141 //===----------------------------------------------------------------------===// 2142 // TableGen'd enum attribute definitions 2143 //===----------------------------------------------------------------------===// 2144 2145 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc" 2146
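
//===----------------------------------------------------------------------===//
// Usage sketch (illustrative only)
//===----------------------------------------------------------------------===//
//
// A minimal sketch of how a pass might register some of the patterns above,
// assuming access to a RewritePatternSet and the greedy rewrite driver
// (applyPatternsAndFoldGreedily from GreedyPatternRewriteDriver.h). The name
// `rootOp` is a placeholder.
//
//   RewritePatternSet patterns(rootOp->getContext());
//   vector::populateChainedVectorReductionFoldingPatterns(patterns,
//                                                         /*benefit=*/1);
//   vector::populateBreakDownVectorReductionPatterns(
//       patterns, /*maxNumElementsToExtract=*/2, /*benefit=*/1);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));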