1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 10 #include "mlir/Dialect/Vector/IR/VectorOps.h" 11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 12 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/ImplicitLocOpBuilder.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 #define DEBUG_TYPE "vector-drop-unit-dim" 18 19 using namespace mlir; 20 using namespace mlir::vector; 21 22 // Trims leading one dimensions from `oldType` and returns the result type. 23 // Returns `vector<1xT>` if `oldType` only has one element. 24 static VectorType trimLeadingOneDims(VectorType oldType) { 25 ArrayRef<int64_t> oldShape = oldType.getShape(); 26 ArrayRef<int64_t> newShape = 27 oldShape.drop_while([](int64_t dim) { return dim == 1; }); 28 // Make sure we have at least 1 dimension per vector type requirements. 29 if (newShape.empty()) 30 newShape = oldShape.take_back(); 31 return VectorType::get(newShape, oldType.getElementType()); 32 } 33 34 /// Return a smallVector of size `rank` containing all zeros. 35 static SmallVector<int64_t> splatZero(int64_t rank) { 36 return SmallVector<int64_t>(rank, 0); 37 } 38 namespace { 39 40 // Casts away leading one dimensions in vector.extract_strided_slice's vector 41 // input by inserting vector.broadcast. 
42 struct CastAwayExtractStridedSliceLeadingOneDim 43 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 44 using OpRewritePattern::OpRewritePattern; 45 46 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 47 PatternRewriter &rewriter) const override { 48 // vector.extract_strided_slice requires the input and output vector to have 49 // the same rank. Here we drop leading one dimensions from the input vector 50 // type to make sure we don't cause mismatch. 51 VectorType oldSrcType = extractOp.getSourceVectorType(); 52 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 53 54 if (newSrcType.getRank() == oldSrcType.getRank()) 55 return failure(); 56 57 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 58 59 VectorType oldDstType = extractOp.getType(); 60 VectorType newDstType = 61 VectorType::get(oldDstType.getShape().drop_front(dropCount), 62 oldDstType.getElementType()); 63 64 Location loc = extractOp.getLoc(); 65 66 Value newSrcVector = rewriter.create<vector::ExtractOp>( 67 loc, extractOp.getVector(), splatZero(dropCount)); 68 69 // The offsets/sizes/strides attribute can have a less number of elements 70 // than the input vector's rank: it is meant for the leading dimensions. 71 auto newOffsets = rewriter.getArrayAttr( 72 extractOp.getOffsets().getValue().drop_front(dropCount)); 73 auto newSizes = rewriter.getArrayAttr( 74 extractOp.getSizes().getValue().drop_front(dropCount)); 75 auto newStrides = rewriter.getArrayAttr( 76 extractOp.getStrides().getValue().drop_front(dropCount)); 77 78 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 79 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 80 81 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 82 newExtractOp); 83 84 return success(); 85 } 86 }; 87 88 // Casts away leading one dimensions in vector.insert_strided_slice's vector 89 // inputs by inserting vector.broadcast. 
90 struct CastAwayInsertStridedSliceLeadingOneDim 91 : public OpRewritePattern<vector::InsertStridedSliceOp> { 92 using OpRewritePattern::OpRewritePattern; 93 94 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 95 PatternRewriter &rewriter) const override { 96 VectorType oldSrcType = insertOp.getSourceVectorType(); 97 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 98 VectorType oldDstType = insertOp.getDestVectorType(); 99 VectorType newDstType = trimLeadingOneDims(oldDstType); 100 101 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 102 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 103 if (srcDropCount == 0 && dstDropCount == 0) 104 return failure(); 105 106 // Trim leading one dimensions from both operands. 107 Location loc = insertOp.getLoc(); 108 109 Value newSrcVector = rewriter.create<vector::ExtractOp>( 110 loc, insertOp.getSource(), splatZero(srcDropCount)); 111 Value newDstVector = rewriter.create<vector::ExtractOp>( 112 loc, insertOp.getDest(), splatZero(dstDropCount)); 113 114 auto newOffsets = rewriter.getArrayAttr( 115 insertOp.getOffsets().getValue().take_back(newDstType.getRank())); 116 auto newStrides = rewriter.getArrayAttr( 117 insertOp.getStrides().getValue().take_back(newSrcType.getRank())); 118 119 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 120 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 121 122 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 123 newInsertOp); 124 125 return success(); 126 } 127 }; 128 129 // Casts away leading one dimensions in vector.insert's vector inputs by 130 // inserting vector.broadcast. 
131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { 132 using OpRewritePattern::OpRewritePattern; 133 134 LogicalResult matchAndRewrite(vector::InsertOp insertOp, 135 PatternRewriter &rewriter) const override { 136 Type oldSrcType = insertOp.getSourceType(); 137 Type newSrcType = oldSrcType; 138 int64_t oldSrcRank = 0, newSrcRank = 0; 139 if (auto type = oldSrcType.dyn_cast<VectorType>()) { 140 newSrcType = trimLeadingOneDims(type); 141 oldSrcRank = type.getRank(); 142 newSrcRank = newSrcType.cast<VectorType>().getRank(); 143 } 144 145 VectorType oldDstType = insertOp.getDestVectorType(); 146 VectorType newDstType = trimLeadingOneDims(oldDstType); 147 148 int64_t srcDropCount = oldSrcRank - newSrcRank; 149 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 150 if (srcDropCount == 0 && dstDropCount == 0) 151 return failure(); 152 153 // Trim leading one dimensions from both operands. 154 Location loc = insertOp.getLoc(); 155 156 Value newSrcVector = insertOp.getSource(); 157 if (oldSrcRank != 0) { 158 newSrcVector = rewriter.create<vector::ExtractOp>( 159 loc, insertOp.getSource(), splatZero(srcDropCount)); 160 } 161 Value newDstVector = rewriter.create<vector::ExtractOp>( 162 loc, insertOp.getDest(), splatZero(dstDropCount)); 163 164 // New position rank needs to be computed in two steps: (1) if destination 165 // type has leading unit dims, we also trim the position array accordingly, 166 // then (2) if source type also has leading unit dims, we need to append 167 // zeroes to the position array accordingly. 
168 unsigned oldPosRank = insertOp.getPosition().getValue().size(); 169 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount); 170 SmallVector<Attribute> newPositions = llvm::to_vector( 171 insertOp.getPosition().getValue().take_back(newPosRank)); 172 if (srcDropCount >= dstDropCount) { 173 auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type()); 174 newPositions.resize(newPosRank + srcDropCount, zeroAttr); 175 } 176 177 auto newInsertOp = rewriter.create<vector::InsertOp>( 178 loc, newDstType, newSrcVector, newDstVector, 179 rewriter.getArrayAttr(newPositions)); 180 181 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 182 newInsertOp); 183 184 return success(); 185 } 186 }; 187 188 // Turns vector.transfer_read on vector with leading 1 dimensions into 189 // vector.shape_cast followed by vector.transfer_read on vector without leading 190 // 1 dimensions. 191 struct CastAwayTransferReadLeadingOneDim 192 : public OpRewritePattern<vector::TransferReadOp> { 193 using OpRewritePattern::OpRewritePattern; 194 195 LogicalResult matchAndRewrite(vector::TransferReadOp read, 196 PatternRewriter &rewriter) const override { 197 // TODO: support 0-d corner case. 
198 if (read.getTransferRank() == 0) 199 return failure(); 200 201 if (read.getMask()) 202 return failure(); 203 204 auto shapedType = read.getSource().getType().cast<ShapedType>(); 205 if (shapedType.getElementType() != read.getVectorType().getElementType()) 206 return failure(); 207 208 VectorType oldType = read.getVectorType(); 209 VectorType newType = trimLeadingOneDims(oldType); 210 211 if (newType == oldType) 212 return failure(); 213 214 AffineMap oldMap = read.getPermutationMap(); 215 ArrayRef<AffineExpr> newResults = 216 oldMap.getResults().take_back(newType.getRank()); 217 AffineMap newMap = 218 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 219 rewriter.getContext()); 220 221 ArrayAttr inBoundsAttr; 222 if (read.getInBounds()) 223 inBoundsAttr = rewriter.getArrayAttr( 224 read.getInBoundsAttr().getValue().take_back(newType.getRank())); 225 226 auto newRead = rewriter.create<vector::TransferReadOp>( 227 read.getLoc(), newType, read.getSource(), read.getIndices(), 228 AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(), 229 inBoundsAttr); 230 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 231 232 return success(); 233 } 234 }; 235 236 // Turns vector.transfer_write on vector with leading 1 dimensions into 237 // vector.shape_cast followed by vector.transfer_write on vector without leading 238 // 1 dimensions. 239 struct CastAwayTransferWriteLeadingOneDim 240 : public OpRewritePattern<vector::TransferWriteOp> { 241 using OpRewritePattern::OpRewritePattern; 242 243 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 244 PatternRewriter &rewriter) const override { 245 // TODO: support 0-d corner case. 
246 if (write.getTransferRank() == 0) 247 return failure(); 248 249 if (write.getMask()) 250 return failure(); 251 252 auto shapedType = write.getSource().getType().dyn_cast<ShapedType>(); 253 if (shapedType.getElementType() != write.getVectorType().getElementType()) 254 return failure(); 255 256 VectorType oldType = write.getVectorType(); 257 VectorType newType = trimLeadingOneDims(oldType); 258 if (newType == oldType) 259 return failure(); 260 int64_t dropDim = oldType.getRank() - newType.getRank(); 261 262 AffineMap oldMap = write.getPermutationMap(); 263 ArrayRef<AffineExpr> newResults = 264 oldMap.getResults().take_back(newType.getRank()); 265 AffineMap newMap = 266 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 267 rewriter.getContext()); 268 269 ArrayAttr inBoundsAttr; 270 if (write.getInBounds()) 271 inBoundsAttr = rewriter.getArrayAttr( 272 write.getInBoundsAttr().getValue().take_back(newType.getRank())); 273 274 auto newVector = rewriter.create<vector::ExtractOp>( 275 write.getLoc(), write.getVector(), splatZero(dropDim)); 276 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 277 write, newVector, write.getSource(), write.getIndices(), 278 AffineMapAttr::get(newMap), inBoundsAttr); 279 280 return success(); 281 } 282 }; 283 284 /// Turns vector.contract on vector with leading 1 dimensions into 285 /// vector.extract followed by vector.contract on vector without leading 286 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required 287 /// prior to extract. 
struct CastAwayContractionLeadingOneDim
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // Only vector accumulators of rank >= 2 with a leading unit dim qualify.
    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
    if (oldAccType == nullptr)
      return failure();
    if (oldAccType.getRank() < 2)
      return failure();
    if (oldAccType.getShape()[0] != 1)
      return failure();
    // Currently we support only dropping one dim, but the pattern can be
    // applied greedily to drop more.
    int64_t dropDim = 1;

    auto oldIndexingMaps = contractOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;

    auto oldIteratorTypes = contractOp.getIteratorTypes();
    SmallVector<Attribute> newIteratorTypes;

    // The iteration-space dimension to drop is the one feeding the
    // accumulator's leading (unit) result dim; indexing map 2 is the acc map.
    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
      // Only parallel-type iterators can be dropped.
      return failure();

    // Rebuild the iterator list without the dropped dimension.
    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
      int64_t currDim = it.index();
      if (currDim == dimToDrop)
        continue;
      newIteratorTypes.push_back(it.value());
    }

    // Operand order matches the indexing-map order: lhs, rhs, acc.
    SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                   contractOp.getAcc()};
    SmallVector<Value> newOperands;

    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
      // Check if the dim to be dropped exists as a leading dim in the operand;
      // if it does then we use vector.extract to drop it.
      bool validExtract = false;
      SmallVector<AffineExpr> results;
      auto map = it.value();
      int64_t orginalZeroDim = it.value().getDimPosition(0);
      if (orginalZeroDim != dimToDrop) {
        // There are two reasons to be in this path: 1. We need to
        // transpose the operand to make the dim to be dropped
        // leading. 2. The dim to be dropped does not exist and in
        // that case we dont want to add a unit transpose but we must
        // check all the indices to make sure this is the case.
        bool tranposeNeeded = false;
        SmallVector<int64_t> perm;
        SmallVector<AffineExpr> transposeResults;

        // Build a permutation that moves the dropped dim to the front while
        // keeping the relative order of all other dims.
        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
          int64_t currDim = map.getDimPosition(i);
          if (currDim == dimToDrop) {
            tranposeNeeded = true;
            perm.insert(perm.begin(), i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.insert(transposeResults.begin(), targetExpr);
          } else {
            perm.push_back(i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.push_back(targetExpr);
          }
        }
        // Do the transpose now if needed so that we can drop the
        // correct dim using extract later.
        if (tranposeNeeded) {
          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                               contractOp.getContext());
          operands[it.index()] = rewriter.create<vector::TransposeOp>(
              contractOp.getLoc(), operands[it.index()], perm);
        }
      }
      // We have taken care to have the dim to be dropped be
      // the leading dim. If it's still not leading that means it
      // does not exist in this operand and hence we do not need
      // an extract.
      if (map.getDimPosition(0) == dimToDrop)
        validExtract = true;

      // Rebuild the map results without the dropped dim, shifting dim
      // positions above it down by one to keep the dim numbering dense.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop)
          // This is the dim we are dropping.
          continue;
        auto targetExpr = rewriter.getAffineDimExpr(
            currDim < dimToDrop ? currDim : currDim - 1);
        results.push_back(targetExpr);
      }
      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                               contractOp.getContext()));
      // Extract if it's a valid extraction, otherwise use the operand
      // without extraction.
      newOperands.push_back(validExtract
                                ? rewriter.create<vector::ExtractOp>(
                                      contractOp.getLoc(), operands[it.index()],
                                      splatZero(dropDim))
                                : operands[it.index()]);
    }
    auto newContractOp = rewriter.create<vector::ContractionOp>(
        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
    // Broadcast back to the original result type so users are unaffected.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        contractOp, contractOp->getResultTypes()[0], newContractOp);
    return success();
  }
};

// Casts away leading unit dims from the single vector result (and vector
// operands) of any elementwise op by inserting vector.extract / broadcast.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Matches any op, so restrict to single-result elementwise-mappable ops.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    // Trim vector operands with an all-zero extract; scalars pass through.
    for (Value operand : op->getOperands()) {
      if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim))); 
      } else {
        newOperands.push_back(operand);
      }
    }
    // Re-create the same op generically with the trimmed result type.
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

} // namespace

// Registers all cast-away-leading-unit-dim patterns above, plus the
// shape_cast folding patterns that clean up the casts they introduce.
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}