1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 10 #include "mlir/Dialect/Vector/IR/VectorOps.h" 11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 12 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 13 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 #define DEBUG_TYPE "vector-drop-unit-dim" 18 19 using namespace mlir; 20 using namespace mlir::vector; 21 22 // Trims leading one dimensions from `oldType` and returns the result type. 23 // Returns `vector<1xT>` if `oldType` only has one element. 24 static VectorType trimLeadingOneDims(VectorType oldType) { 25 ArrayRef<int64_t> oldShape = oldType.getShape(); 26 ArrayRef<int64_t> newShape = 27 oldShape.drop_while([](int64_t dim) { return dim == 1; }); 28 // Make sure we have at least 1 dimension per vector type requirements. 29 if (newShape.empty()) 30 newShape = oldShape.take_back(); 31 return VectorType::get(newShape, oldType.getElementType()); 32 } 33 34 /// Return a smallVector of size `rank` containing all zeros. 35 static SmallVector<int64_t> splatZero(int64_t rank) { 36 return SmallVector<int64_t>(rank, 0); 37 } 38 namespace { 39 40 // Casts away leading one dimensions in vector.extract_strided_slice's vector 41 // input by inserting vector.broadcast. 42 struct CastAwayExtractStridedSliceLeadingOneDim 43 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 44 using OpRewritePattern::OpRewritePattern; 45 46 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 47 PatternRewriter &rewriter) const override { 48 // vector.extract_strided_slice requires the input and output vector to have 49 // the same rank. Here we drop leading one dimensions from the input vector 50 // type to make sure we don't cause mismatch. 51 VectorType oldSrcType = extractOp.getSourceVectorType(); 52 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 53 54 if (newSrcType.getRank() == oldSrcType.getRank()) 55 return failure(); 56 57 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 58 59 VectorType oldDstType = extractOp.getType(); 60 VectorType newDstType = 61 VectorType::get(oldDstType.getShape().drop_front(dropCount), 62 oldDstType.getElementType()); 63 64 Location loc = extractOp.getLoc(); 65 66 Value newSrcVector = rewriter.create<vector::ExtractOp>( 67 loc, extractOp.getVector(), splatZero(dropCount)); 68 69 // The offsets/sizes/strides attribute can have a less number of elements 70 // than the input vector's rank: it is meant for the leading dimensions. 71 auto newOffsets = rewriter.getArrayAttr( 72 extractOp.getOffsets().getValue().drop_front(dropCount)); 73 auto newSizes = rewriter.getArrayAttr( 74 extractOp.getSizes().getValue().drop_front(dropCount)); 75 auto newStrides = rewriter.getArrayAttr( 76 extractOp.getStrides().getValue().drop_front(dropCount)); 77 78 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 79 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 80 81 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 82 newExtractOp); 83 84 return success(); 85 } 86 }; 87 88 // Casts away leading one dimensions in vector.insert_strided_slice's vector 89 // inputs by inserting vector.broadcast. 90 struct CastAwayInsertStridedSliceLeadingOneDim 91 : public OpRewritePattern<vector::InsertStridedSliceOp> { 92 using OpRewritePattern::OpRewritePattern; 93 94 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 95 PatternRewriter &rewriter) const override { 96 VectorType oldSrcType = insertOp.getSourceVectorType(); 97 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 98 VectorType oldDstType = insertOp.getDestVectorType(); 99 VectorType newDstType = trimLeadingOneDims(oldDstType); 100 101 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 102 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 103 if (srcDropCount == 0 && dstDropCount == 0) 104 return failure(); 105 106 // Trim leading one dimensions from both operands. 107 Location loc = insertOp.getLoc(); 108 109 Value newSrcVector = rewriter.create<vector::ExtractOp>( 110 loc, insertOp.getSource(), splatZero(srcDropCount)); 111 Value newDstVector = rewriter.create<vector::ExtractOp>( 112 loc, insertOp.getDest(), splatZero(dstDropCount)); 113 114 auto newOffsets = rewriter.getArrayAttr( 115 insertOp.getOffsets().getValue().take_back(newDstType.getRank())); 116 auto newStrides = rewriter.getArrayAttr( 117 insertOp.getStrides().getValue().take_back(newSrcType.getRank())); 118 119 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 120 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 121 122 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 123 newInsertOp); 124 125 return success(); 126 } 127 }; 128 129 // Casts away leading one dimensions in vector.insert's vector inputs by 130 // inserting vector.broadcast. 131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { 132 using OpRewritePattern::OpRewritePattern; 133 134 LogicalResult matchAndRewrite(vector::InsertOp insertOp, 135 PatternRewriter &rewriter) const override { 136 Type oldSrcType = insertOp.getSourceType(); 137 Type newSrcType = oldSrcType; 138 int64_t oldSrcRank = 0, newSrcRank = 0; 139 if (auto type = dyn_cast<VectorType>(oldSrcType)) { 140 newSrcType = trimLeadingOneDims(type); 141 oldSrcRank = type.getRank(); 142 newSrcRank = cast<VectorType>(newSrcType).getRank(); 143 } 144 145 VectorType oldDstType = insertOp.getDestVectorType(); 146 VectorType newDstType = trimLeadingOneDims(oldDstType); 147 148 int64_t srcDropCount = oldSrcRank - newSrcRank; 149 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 150 if (srcDropCount == 0 && dstDropCount == 0) 151 return failure(); 152 153 // Trim leading one dimensions from both operands. 154 Location loc = insertOp.getLoc(); 155 156 Value newSrcVector = insertOp.getSource(); 157 if (oldSrcRank != 0) { 158 newSrcVector = rewriter.create<vector::ExtractOp>( 159 loc, insertOp.getSource(), splatZero(srcDropCount)); 160 } 161 Value newDstVector = rewriter.create<vector::ExtractOp>( 162 loc, insertOp.getDest(), splatZero(dstDropCount)); 163 164 // New position rank needs to be computed in two steps: (1) if destination 165 // type has leading unit dims, we also trim the position array accordingly, 166 // then (2) if source type also has leading unit dims, we need to append 167 // zeroes to the position array accordingly. 168 unsigned oldPosRank = insertOp.getPosition().size(); 169 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount); 170 SmallVector<int64_t> newPositions = 171 llvm::to_vector(insertOp.getPosition().take_back(newPosRank)); 172 newPositions.resize(newDstType.getRank() - newSrcRank, 0); 173 174 auto newInsertOp = rewriter.create<vector::InsertOp>( 175 loc, newDstType, newSrcVector, newDstVector, newPositions); 176 177 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 178 newInsertOp); 179 180 return success(); 181 } 182 }; 183 184 // Turns vector.transfer_read on vector with leading 1 dimensions into 185 // vector.shape_cast followed by vector.transfer_read on vector without leading 186 // 1 dimensions. 187 struct CastAwayTransferReadLeadingOneDim 188 : public OpRewritePattern<vector::TransferReadOp> { 189 using OpRewritePattern::OpRewritePattern; 190 191 LogicalResult matchAndRewrite(vector::TransferReadOp read, 192 PatternRewriter &rewriter) const override { 193 // TODO: support 0-d corner case. 194 if (read.getTransferRank() == 0) 195 return failure(); 196 197 if (read.getMask()) 198 return failure(); 199 200 auto shapedType = cast<ShapedType>(read.getSource().getType()); 201 if (shapedType.getElementType() != read.getVectorType().getElementType()) 202 return failure(); 203 204 VectorType oldType = read.getVectorType(); 205 VectorType newType = trimLeadingOneDims(oldType); 206 207 if (newType == oldType) 208 return failure(); 209 210 AffineMap oldMap = read.getPermutationMap(); 211 ArrayRef<AffineExpr> newResults = 212 oldMap.getResults().take_back(newType.getRank()); 213 AffineMap newMap = 214 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 215 rewriter.getContext()); 216 217 ArrayAttr inBoundsAttr; 218 if (read.getInBounds()) 219 inBoundsAttr = rewriter.getArrayAttr( 220 read.getInBoundsAttr().getValue().take_back(newType.getRank())); 221 222 auto newRead = rewriter.create<vector::TransferReadOp>( 223 read.getLoc(), newType, read.getSource(), read.getIndices(), 224 AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(), 225 inBoundsAttr); 226 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 227 228 return success(); 229 } 230 }; 231 232 // Turns vector.transfer_write on vector with leading 1 dimensions into 233 // vector.shape_cast followed by vector.transfer_write on vector without leading 234 // 1 dimensions. 235 struct CastAwayTransferWriteLeadingOneDim 236 : public OpRewritePattern<vector::TransferWriteOp> { 237 using OpRewritePattern::OpRewritePattern; 238 239 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 240 PatternRewriter &rewriter) const override { 241 // TODO: support 0-d corner case. 242 if (write.getTransferRank() == 0) 243 return failure(); 244 245 if (write.getMask()) 246 return failure(); 247 248 auto shapedType = dyn_cast<ShapedType>(write.getSource().getType()); 249 if (shapedType.getElementType() != write.getVectorType().getElementType()) 250 return failure(); 251 252 VectorType oldType = write.getVectorType(); 253 VectorType newType = trimLeadingOneDims(oldType); 254 if (newType == oldType) 255 return failure(); 256 int64_t dropDim = oldType.getRank() - newType.getRank(); 257 258 AffineMap oldMap = write.getPermutationMap(); 259 ArrayRef<AffineExpr> newResults = 260 oldMap.getResults().take_back(newType.getRank()); 261 AffineMap newMap = 262 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 263 rewriter.getContext()); 264 265 ArrayAttr inBoundsAttr; 266 if (write.getInBounds()) 267 inBoundsAttr = rewriter.getArrayAttr( 268 write.getInBoundsAttr().getValue().take_back(newType.getRank())); 269 270 auto newVector = rewriter.create<vector::ExtractOp>( 271 write.getLoc(), write.getVector(), splatZero(dropDim)); 272 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 273 write, newVector, write.getSource(), write.getIndices(), 274 AffineMapAttr::get(newMap), inBoundsAttr); 275 276 return success(); 277 } 278 }; 279 280 } // namespace 281 282 LogicalResult 283 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, 284 RewriterBase &rewriter) { 285 VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType()); 286 if (oldAccType == nullptr) 287 return failure(); 288 if (oldAccType.getRank() < 2) 289 return failure(); 290 if (oldAccType.getShape()[0] != 1) 291 return failure(); 292 // currently we support only dropping one dim but the pattern can be applied 293 // greedily to drop more. 294 int64_t dropDim = 1; 295 296 auto oldIndexingMaps = contractOp.getIndexingMapsArray(); 297 SmallVector<AffineMap> newIndexingMaps; 298 299 auto oldIteratorTypes = contractOp.getIteratorTypes(); 300 SmallVector<Attribute> newIteratorTypes; 301 302 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); 303 304 if (!isParallelIterator(oldIteratorTypes[dimToDrop])) 305 // only parallel type iterators can be dropped. 306 return failure(); 307 308 for (const auto &it : llvm::enumerate(oldIteratorTypes)) { 309 int64_t currDim = it.index(); 310 if (currDim == dimToDrop) 311 continue; 312 newIteratorTypes.push_back(it.value()); 313 } 314 315 SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(), 316 contractOp.getAcc()}; 317 SmallVector<Value> newOperands; 318 319 for (const auto &it : llvm::enumerate(oldIndexingMaps)) { 320 // Check if the dim to be dropped exists as a leading dim in the operand 321 // if it does then we use vector.extract to drop it. 322 bool validExtract = false; 323 SmallVector<AffineExpr> results; 324 auto map = it.value(); 325 int64_t orginalZeroDim = it.value().getDimPosition(0); 326 if (orginalZeroDim != dimToDrop) { 327 // There are two reasons to be in this path, 1. We need to 328 // tranpose the operand to make the dim to be dropped 329 // leading. 2. The dim to be dropped does not exist and in 330 // that case we dont want to add a unit tranpose but we must 331 // check all the indices to make sure this is the case. 332 bool tranposeNeeded = false; 333 SmallVector<int64_t> perm; 334 SmallVector<AffineExpr> transposeResults; 335 336 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 337 int64_t currDim = map.getDimPosition(i); 338 if (currDim == dimToDrop) { 339 tranposeNeeded = true; 340 perm.insert(perm.begin(), i); 341 auto targetExpr = rewriter.getAffineDimExpr(currDim); 342 transposeResults.insert(transposeResults.begin(), targetExpr); 343 } else { 344 perm.push_back(i); 345 auto targetExpr = rewriter.getAffineDimExpr(currDim); 346 transposeResults.push_back(targetExpr); 347 } 348 } 349 // Do the tranpose now if needed so that we can drop the 350 // correct dim using extract later. 351 if (tranposeNeeded) { 352 map = AffineMap::get(map.getNumDims(), 0, transposeResults, 353 contractOp.getContext()); 354 operands[it.index()] = rewriter.create<vector::TransposeOp>( 355 contractOp.getLoc(), operands[it.index()], perm); 356 } 357 } 358 // We have taken care to have the dim to be dropped be 359 // the leading dim. If its still not leading that means it 360 // does not exist in this operand and hence we do not need 361 // an extract. 362 if (map.getDimPosition(0) == dimToDrop) 363 validExtract = true; 364 365 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 366 int64_t currDim = map.getDimPosition(i); 367 if (currDim == dimToDrop) 368 // This is the dim we are dropping. 369 continue; 370 auto targetExpr = rewriter.getAffineDimExpr( 371 currDim < dimToDrop ? currDim : currDim - 1); 372 results.push_back(targetExpr); 373 } 374 newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, 375 contractOp.getContext())); 376 // Extract if its a valid extraction, otherwise use the operand 377 // without extraction. 378 newOperands.push_back( 379 validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(), 380 operands[it.index()], 381 splatZero(dropDim)) 382 : operands[it.index()]); 383 } 384 auto newContractOp = rewriter.create<vector::ContractionOp>( 385 contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], 386 rewriter.getAffineMapArrayAttr(newIndexingMaps), 387 rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); 388 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 389 contractOp, contractOp->getResultTypes()[0], newContractOp); 390 return success(); 391 } 392 393 namespace { 394 395 /// Turns vector.contract on vector with leading 1 dimensions into 396 /// vector.extract followed by vector.contract on vector without leading 397 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required 398 /// prior to extract. 399 struct CastAwayContractionLeadingOneDim 400 : public OpRewritePattern<vector::ContractionOp> { 401 using OpRewritePattern::OpRewritePattern; 402 403 LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 404 PatternRewriter &rewriter) const override { 405 return castAwayContractionLeadingOneDim(contractOp, rewriter); 406 } 407 }; 408 409 class CastAwayElementwiseLeadingOneDim : public RewritePattern { 410 public: 411 CastAwayElementwiseLeadingOneDim(MLIRContext *context, 412 PatternBenefit benefit = 1) 413 : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} 414 415 LogicalResult matchAndRewrite(Operation *op, 416 PatternRewriter &rewriter) const override { 417 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) 418 return failure(); 419 auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]); 420 if (!vecType) 421 return failure(); 422 VectorType newVecType = trimLeadingOneDims(vecType); 423 if (newVecType == vecType) 424 return failure(); 425 int64_t dropDim = vecType.getRank() - newVecType.getRank(); 426 SmallVector<Value, 4> newOperands; 427 for (Value operand : op->getOperands()) { 428 if (auto opVecType = dyn_cast<VectorType>(operand.getType())) { 429 newOperands.push_back(rewriter.create<vector::ExtractOp>( 430 op->getLoc(), operand, splatZero(dropDim))); 431 } else { 432 newOperands.push_back(operand); 433 } 434 } 435 Operation *newOp = 436 rewriter.create(op->getLoc(), op->getName().getIdentifier(), 437 newOperands, newVecType, op->getAttrs()); 438 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType, 439 newOp->getResult(0)); 440 return success(); 441 } 442 }; 443 444 } // namespace 445 446 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( 447 RewritePatternSet &patterns, PatternBenefit benefit) { 448 patterns 449 .add<CastAwayExtractStridedSliceLeadingOneDim, 450 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim, 451 CastAwayTransferReadLeadingOneDim, 452 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, 453 CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit); 454 populateShapeCastFoldingPatterns(patterns, benefit); 455 } 456