//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};

/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
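/// The transpose is removed by composing the inverse of its permutation into
/// the corresponding indexing map of the contraction.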
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merges accumulator and result transposes into contract.
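/// The pattern only applies when the contraction result has a single use and
/// the accumulator and result transposes use inverse permutations.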
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4xf32>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();

    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
    // To index into A in contract, we need revert(f)(g(C)) -> A.
    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);

    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
    // To index into E in contract, we need h(g(C)) -> E.
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    auto combinedResMap = resTMap.compose(contractMap);

    // The accumulator and result share the same indexing map. So they should
    // be the same to be able to merge. This means combinedResMap is the same
    // as inversePermutation(accTMap).compose(contractMap), i.e., resTMap must
    // equal inversePermutation(accTMap).
    if (inversePermutation(accTMap) != resTMap)
      return failure();
    maps.back() = combinedResMap;

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
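/// Only broadcasts of leading (outer) dimensions are folded; inner-dimension
/// broadcasts and broadcasts onto non-unit reduction dimensions are left
/// untouched.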
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // contractionOp can only take vector as operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                         originalDims, contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }

    if (!changed)
      return failure();

    // Determine which dims are unused, now that the maps have been composed
    // with the broadcast maps.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    // Compress unused dims.
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    // Compute the combined iterators.
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Check that compressing unused dims isn't removing all reduction
    // dimension pairs. For example, if the vector.contract had only one
    // reduction iterator and that was a unit-dimension created by a broadcast,
    // then we should bail here, otherwise we would create a contract without
    // a reduction dimension pair.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
        hasReductionIteratorApplyingOnBothSides = true;
        break;
      }
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    // If the compressed maps have a dimension that is not used by either LHS
    // or RHS then the ContractionOp verifier would fail.
    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which enables the CombineContractBroadcast pattern
/// when casting ops sit between these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which enables the
/// CombineContractABTranspose pattern when elementwise ops sit between these
/// operations. Ex:
/// ```
///   %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.addf %a, %b : vector<4x2xf32>
///   %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::all_equal(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for
    // them. Calculate the inverse order first.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant. Create a reverse transpose op for it.
        auto vectorType =
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : f32 from vector<4xf32>
//   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
      return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
    };

    uint64_t index = getFirstIntValue(extractOp.getMixedPosition());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %arg0 {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert op.
//
// This transforms IR like:
//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
// Into:
//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    // Bitcast the destination.
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from an InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require insert op to have the same rank for the source and destination
    // vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Require that the shape of the insert op source is castable to the
    // destination element type.
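    // E.g., for an f16 -> f32 bitcast (sourceWidth = 16, destinationWidth =
    // 32), numElements = 2, so the inserted source must contain a multiple of
    // 2 elements.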
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());

    return success();
  }
};

// Breaks down vector.bitcast op
//
// This transforms IR like:
//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
//   %cst = vector.splat %c0_f32 : vector<4xf32>
//   %1 = vector.extract_strided_slice %0 {
//          offsets = [0], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
//   %4 = vector.insert_strided_slice %2, %cst {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//   %5 = vector.extract_strided_slice %0 {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
//   %7 = vector.insert_strided_slice %6, %cst {
//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {

    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Only support rank 1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
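    // E.g., the vector<8xf16> -> vector<4xf32> bitcast in the example above
    // shrinks 8 elements to 4, so shrinkRatio = 2 below.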
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape{castDstLastDim};
    SmallVector<int64_t> strides{1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
///   %a = vector.broadcast %arg1 : index to vector<1x4xindex>
///   %b = vector.broadcast %arg2 : index to vector<1x4xindex>
///   %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
///   %r = arith.addi %arg1, %arg2 : index
///   %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return failure();
    if (op->getNumOperands() == 0 ||
        op->getResults()[0].getType() != op->getOperand(0).getType()) {
      return failure();
    }
    // Avoid operations that only accept vector types, since broadcast
    // source might be scalar types.
    if (isa<vector::FMAOp>(op)) {
      return failure();
    }

    // Get the type of the lhs operand.
    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();

    // Make sure that all operands are broadcast from identical types:
    //  * scalar (`vector.broadcast` + `vector.splat`), or
    //  * vector (`vector.broadcast`).
    // Otherwise the re-ordering wouldn't be safe.
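    // E.g., mixing a broadcast from f32 with a broadcast from vector<4xf32>
    // must be rejected, even if both produce the same result vector type.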
    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
          if (bcast)
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          auto splat = val.getDefiningOp<vector::SplatOp>();
          if (splat)
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        })) {
      return failure();
    }

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
    }

    // Create the "elementwise" Op.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());

    // Replace the original Op with the elementwise Op.
    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());

    return success();
  }
};

// Helper that returns a vector comparison that constructs a mask:
//   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// If `dim == 0` then the result will be a 0-D vector.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
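  // E.g., for dim = 4 and offset o, this computes
  // [o, o+1, o+2, o+3] < [b, b, b, b].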
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}

template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds with the mask specified as an op parameter.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};

/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constant.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}

/// Folds a select operation between an all-true and all-false vector. For now,
/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
///
///   %true = arith.constant dense<true> : vector<1xi1>
///   %false = arith.constant dense<false> : vector<1xi1>
///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
///   =>
///   %result = vector.broadcast %cond : i1 to vector<1xi1>
///
/// InstCombine seems to handle vectors with multiple elements but not the
/// single element ones.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace select with its condition broadcasted to single element vector.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
    return success();
  }
};

/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims can
/// be dropped by memref.subview ops.
/// Example 2: it returns "1" if `srcType` is the same memref type but with
/// [8192, 16, 8, 1] strides.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // According to vector.transfer_read/write semantics, the vector can be a
  // slice. Thus, we have to offset the check index with `rankDiff` in
  // `srcStrides` and source dim sizes.
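  // E.g., for memref<512x16x1x1xf32> (strides [16, 1, 1, 1]) and
  // vector<16x16x1x1xf32>, the walk below folds the two trailing unit dims
  // and stops at the first non-unit dim.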
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // Check that the inner dim size is 1 for both the memref type and the
    // vector slice. It can be folded only if they are 1 and the stride is 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}

/// Drop innermost contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    // Make sure that the indices to be dropped are equal to 0.
    // TODO: Deal with cases when the indices are not 0.
    if (!llvm::all_of(readOp.getIndices().take_back(dimsToDrop), isZeroIndex))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr =
        readOp.getInBounds()
            ? rewriter.getArrayAttr(
                  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};

/// Drop innermost contiguous unit dimensions from transfer_write operand.
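/// This mirrors DropInnerMostUnitDimsTransferRead above, except that the
/// shape_cast is applied to the vector being written.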
/// E.g.,
///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
///      {in_bounds = [true, true, true, true, true]}
///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
///
/// will be replaced with
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
///      to vector<1x16x16xf32>
///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
///      {in_bounds = [true, true, true]}
///      : vector<1x16x16xf32>, memref<1x512x16xf32>
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr =
        writeOp.getInBounds()
            ? rewriter.getArrayAttr(
                  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();

    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
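///
/// The canonical MMT form uses the indexing maps (m, k), (n, k) -> (m, n); for
/// example, a plain row-major matmul with maps (m, k), (k, n) -> (m, n) is
/// canonicalized by transposing the RHS.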
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

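    // Spelled out with the dim binding above (m -> d0, n -> d1, k -> d2), the
    // canonical indexing maps are:
    //   A: (d0, d1, d2) -> (d0, d2)
    //   B: (d0, d1, d2) -> (d1, d2)
    //   C: (d0, d1, d2) -> (d0, d1)
    // i.e. the RHS is indexed as if transposed ("MMT").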
    // Create a vector transpose, making sure to emit the zero/sign-extend at
    // the end.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};

/// Pattern to fold arithmetic extensions on floating point data types into
/// vector contraction operations. linalg.matmul introduces arithmetic
/// extensions on its operands; please see the MLIR snippet below for details.
/// ```mlir
///   "linalg.matmul"(%lhs, %rhs, %acc) ({
///      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
///        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
///        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
///        %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
///        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
///        "linalg.yield"(%acc) : (f32) -> ()
///     })
/// ```
/// These extensions block the use of native mixed-precision NVIDIA Ampere
/// Tensor Core instructions, i.e., `mma.sync.*.f32.f16.f16.f32` and
/// `mma.sync.*.f32.bf16.bf16.f32`. This pattern folds the arithmetic
/// extensions into the vector contraction and enables the use of native
/// mixed-precision Tensor Core instructions.
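///
/// As a hedged sketch (the exact types and value names below are illustrative
/// only), after vectorization this rewrites
///
/// ```mlir
///   %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
///   %rhs_f32 = arith.extf %rhs : vector<8x4xf16> to vector<8x4xf32>
///   %res = vector.contract {...} %lhs_f32, %rhs_f32, %acc
///     : vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32>
/// ```
///
/// (indexing maps and iterator types elided) into a contraction that consumes
/// the narrow operands directly:
///
/// ```mlir
///   %res = vector.contract {...} %lhs, %rhs, %acc
///     : vector<4x8xf16>, vector<8x4xf16> into vector<4x4xf32>
/// ```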
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {

    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());

    return success();
  }
};

/// Pattern to fold a chained reduction into a series of vector additions and a
/// final reduction. This form should require fewer subgroup operations.
///
/// ```mlir
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ==>
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // Accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();

    if (!acc.getType().isIntOrFloat())
      return failure();

    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};

// Scalable unit dimensions are not supported. Folding such dimensions would
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
// future.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}

/// For vectors with at least one unit dim, replaces:
///   elementwise(a, b)
/// with:
///   sc_a = shape_cast(a)
///   sc_b = shape_cast(b)
///   res = elementwise(sc_a, sc_b)
///   return shape_cast(res)
/// The newly inserted shape_cast Ops fold (before the elementwise Op) and then
/// restore (after the elementwise Op) the unit dim.
/// Vectors `a` and `b` are required to have rank > 1.
///
/// Ex:
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
///
/// gets converted to:
///
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // Check the operand pre-conditions. For `Elementwise` ops all operands are
    // guaranteed to have identical shapes (with some exceptions such as
    // `arith.select`) and it suffices to only check one of them.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType || sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Create an updated elementwise Op without the unit dim.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());

    // Restore the unit dim by applying vector.shape_cast to the result.
    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                             elementwiseOp->getResult(0));

    return success();
  }
};

/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
/// `ChainedReduction` pattern.
///
/// ```mlir
/// %a = arith.addf %x, %zero
/// %b = arith.addf %a, %y
/// %c = vector.reduction <add> %b, %acc
/// ==>
/// %b = arith.addf %x, %y
/// %c = vector.reduction <add> %b, %acc
/// ```
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other reduction kinds and their identity values.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith.addi` folders, only check
    // for floats here.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs)
      return failure();

    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
                                                 vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};

/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
}

void
mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose,
               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorBroadcastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
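
// Note: the populate* entry points above are intended to be driven by a
// greedy pattern rewrite. As a minimal usage sketch (assuming a pass with
// access to a func::FuncOp `func`; the pass boilerplate is not part of this
// file):
//
//   RewritePatternSet patterns(func.getContext());
//   vector::populateChainedVectorReductionFoldingPatterns(patterns,
//                                                         /*benefit=*/1);
//   vector::populateBreakDownVectorReductionPatterns(
//       patterns, /*maxNumElementsToExtract=*/2, /*benefit=*/1);
//   if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
//     signalPassFailure();
//
// `applyPatternsAndFoldGreedily` is declared in
// "mlir/Transforms/GreedyPatternRewriteDriver.h"; `signalPassFailure` is only
// available inside a Pass. Both appear here for illustration only.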