1 //===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Passes.h" 10 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/Dialect/Linalg/Utils/Utils.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 19 #include <optional> 20 21 namespace mlir { 22 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL 23 #include "mlir/Dialect/Linalg/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 /// Return constant range span or nullopt, otherwise. 30 static std::optional<int64_t> getConstantRange(const Range &range) { 31 std::optional<int64_t> stride = getConstantIntValue(range.stride); 32 if (!stride || *stride != 1) 33 return std::nullopt; 34 std::optional<int64_t> offset = getConstantIntValue(range.offset); 35 if (!offset) 36 return std::nullopt; 37 std::optional<int64_t> size = getConstantIntValue(range.size); 38 if (!size) 39 return std::nullopt; 40 return (*size - *offset); 41 } 42 43 /// Return true if all dimensions are fully divisible by the respective tiles. 44 static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, 45 ArrayRef<OpFoldResult> tiles, 46 ArrayRef<int64_t> dims) { 47 if (dims.size() != tiles.size() || tiles.empty()) 48 return false; 49 50 FailureOr<ContractionDimensions> contractDims = 51 inferContractionDims(linalgOp); 52 if (failed(contractDims)) 53 return false; 54 unsigned batchDimsOffset = contractDims->batch.size(); 55 56 // Skip the batch dimension if present. 57 // Offset all dimensions accordingly. 58 SmallVector<int64_t, 3> offsetDims(dims); 59 for (size_t i = 0; i < offsetDims.size(); i++) 60 offsetDims[i] += batchDimsOffset; 61 62 auto tileOp = cast<TilingInterface>(linalgOp.getOperation()); 63 OpBuilder builder(tileOp); 64 OpBuilder::InsertionGuard guard(builder); 65 SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder); 66 67 for (auto dim : llvm::enumerate(offsetDims)) { 68 if (dim.value() >= static_cast<int64_t>(iterationDomain.size())) 69 return false; 70 71 std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]); 72 std::optional<int64_t> rangeOnDim = 73 getConstantRange(iterationDomain[dim.value()]); 74 75 // If the tile factor or the range are non-constant, the tile size is 76 // considered to be invalid. 77 if (!tileSize || !rangeOnDim) 78 return false; 79 80 // The dimension must be fully divisible by the tile. 81 if (*rangeOnDim % *tileSize != 0) 82 return false; 83 } 84 85 return true; 86 } 87 88 /// Return failure or packed matmul with one of its operands transposed. 89 static FailureOr<PackTransposeResult> 90 transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, 91 tensor::PackOp packOp, AffineMap operandMap, 92 ArrayRef<unsigned> blocksStartDimPos, 93 bool transposeOuterBlocks, bool transposeInnerBlocks) { 94 assert(operandMap.getNumDims() >= 4 && 95 "expected at least 4D prepacked matmul"); 96 assert(blocksStartDimPos.size() >= 2 && 97 "expected starting outer and inner block positions"); 98 99 // Bias toward innermost dimensions. 100 unsigned outerBlockPos = operandMap.getNumResults() - 4; 101 unsigned innerBlockPos = operandMap.getNumResults() - 2; 102 103 // Transpose control options define the desired block and element layout. 104 // Block transposition (outer dimensions) or element transposition (inner 105 // dimensions) may not be necessary depending on the original matmul data 106 // layout. 107 bool isOuterTransposed = 108 operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2]; 109 bool isInnerTransposed = 110 operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back(); 111 112 // Transpose only the dimensions that need that to conform to the provided 113 // transpotion settings. 114 SmallVector<int64_t> innerPerm = {0, 1}; 115 if (isInnerTransposed != transposeInnerBlocks) 116 innerPerm = {1, 0}; 117 SmallVector<int64_t> outerPerm = {0, 1}; 118 if (isOuterTransposed != transposeOuterBlocks) 119 outerPerm = {1, 0}; 120 121 // Leave the outer dimensions, like batch, unchanged by offsetting all 122 // outer dimensions permutations. 123 SmallVector<int64_t> offsetPerms; 124 for (auto i : llvm::seq(0u, outerBlockPos)) 125 offsetPerms.push_back(i); 126 for (auto perm : outerPerm) 127 offsetPerms.push_back(perm + outerBlockPos); 128 outerPerm = offsetPerms; 129 130 FailureOr<PackTransposeResult> packTransposedMatmul = 131 packTranspose(rewriter, packOp, linalgOp, 132 /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm); 133 134 return packTransposedMatmul; 135 } 136 137 /// Pack a matmul operation into blocked 4D layout. 138 FailureOr<PackResult> 139 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, 140 const ControlBlockPackMatmulFn &controlPackMatmul) { 141 if (linalgOp.hasPureBufferSemantics()) 142 return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics"); 143 144 std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp); 145 if (!options) 146 return rewriter.notifyMatchFailure(linalgOp, "invalid packing options"); 147 148 if (options->blockFactors.size() != 3) 149 return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors"); 150 151 SmallVector<OpFoldResult> mnkTiles = 152 getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors)); 153 154 // If padding is disabled, make sure that dimensions can be packed cleanly. 155 if (!options->allowPadding && 156 !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) { 157 return rewriter.notifyMatchFailure(linalgOp, 158 "expect packing full tiles only"); 159 } 160 161 OpBuilder::InsertionGuard guard(rewriter); 162 // The op is replaced, we need to set the insertion point after it. 163 rewriter.setInsertionPointAfter(linalgOp); 164 165 // Pack the matmul operation into blocked layout with two levels of 166 // subdivision: 167 // - major 2D blocks - outer dimensions, consist of minor blocks 168 // - minor 2D blocks - inner dimensions, consist of scalar elements 169 FailureOr<PackResult> packedMatmul = packMatmulGreedily( 170 rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf, 171 options->mnkOrder); 172 if (failed(packedMatmul)) 173 return failure(); 174 175 assert(packedMatmul->packOps.size() == 3 && 176 "invalid number of pack ops after matmul packing"); 177 assert(packedMatmul->unPackOps.size() == 1 && 178 "invalid number of unpack ops after matmul packing"); 179 180 FailureOr<ContractionDimensions> contractDims = 181 inferContractionDims(packedMatmul->packedLinalgOp); 182 if (failed(contractDims)) 183 return failure(); 184 185 auto genericOp = 186 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation()); 187 SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray(); 188 189 // Transpose LHS matrix according to the options. 190 FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul( 191 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0], 192 contractDims->m, options->lhsTransposeOuterBlocks, 193 options->lhsTransposeInnerBlocks); 194 if (failed(packedLhs)) 195 return failure(); 196 197 // Update results. 198 packedMatmul->packOps[0] = packedLhs->transposedPackOp; 199 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp; 200 201 // Transpose RHS matrix according to the options. 202 FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul( 203 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1], 204 contractDims->k, options->rhsTransposeOuterBlocks, 205 options->rhsTransposeInnerBlocks); 206 if (failed(packedRhs)) 207 return failure(); 208 209 // Update results. 210 packedMatmul->packOps[1] = packedRhs->transposedPackOp; 211 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp; 212 213 return packedMatmul; 214 } 215 216 namespace { 217 template <typename OpTy> 218 struct BlockPackMatmul : public OpRewritePattern<OpTy> { 219 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun, 220 PatternBenefit benefit = 1) 221 : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {} 222 223 LogicalResult matchAndRewrite(OpTy linalgOp, 224 PatternRewriter &rewriter) const override { 225 FailureOr<PackResult> packedMatmul = 226 blockPackMatmul(rewriter, linalgOp, controlFn); 227 if (failed(packedMatmul)) 228 return failure(); 229 return success(); 230 } 231 232 private: 233 ControlBlockPackMatmulFn controlFn; 234 }; 235 236 template <> 237 struct BlockPackMatmul<linalg::GenericOp> 238 : public OpRewritePattern<linalg::GenericOp> { 239 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun, 240 PatternBenefit benefit = 1) 241 : OpRewritePattern<linalg::GenericOp>(context, benefit), 242 controlFn(std::move(fun)) {} 243 244 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, 245 PatternRewriter &rewriter) const override { 246 // Match suitable generics. 247 if (!linalg::isaContractionOpInterface(linalgOp)) { 248 return rewriter.notifyMatchFailure(linalgOp, "not a contraction"); 249 } 250 251 using MapList = ArrayRef<ArrayRef<AffineExpr>>; 252 auto infer = [&](MapList m) { 253 return AffineMap::inferFromExprList(m, linalgOp.getContext()); 254 }; 255 256 AffineExpr i, j, k; 257 bindDims(linalgOp->getContext(), i, j, k); 258 SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray(); 259 260 // For now, only match simple matmuls. 261 if (!(maps == infer({{i, k}, {k, j}, {i, j}}) || 262 maps == infer({{k, i}, {k, j}, {i, j}}) || 263 maps == infer({{i, k}, {j, k}, {i, j}}))) { 264 return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul"); 265 } 266 267 FailureOr<PackResult> packedMatmul = 268 blockPackMatmul(rewriter, linalgOp, controlFn); 269 if (failed(packedMatmul)) 270 return failure(); 271 return success(); 272 } 273 274 private: 275 ControlBlockPackMatmulFn controlFn; 276 }; 277 278 /// Convert linalg matmul ops to block layout and back. 279 struct LinalgBlockPackMatmul 280 : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> { 281 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase; 282 283 void runOnOperation() override { 284 Operation *op = getOperation(); 285 RewritePatternSet patterns(&getContext()); 286 287 ControlBlockPackMatmulFn controlFn = 288 [&](linalg::LinalgOp op) -> BlockPackMatmulOptions { 289 BlockPackMatmulOptions options; 290 options.blockFactors = SmallVector<int64_t>{*blockFactors}; 291 options.allowPadding = allowPadding; 292 options.mnkPaddedSizesNextMultipleOf = 293 SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf}; 294 if (!mnkOrder.empty()) 295 options.mnkOrder = SmallVector<int64_t>{*mnkOrder}; 296 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks; 297 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks; 298 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks; 299 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks; 300 return options; 301 }; 302 303 linalg::populateBlockPackMatmulPatterns(patterns, controlFn); 304 if (failed(applyPatternsGreedily(op, std::move(patterns)))) 305 return signalPassFailure(); 306 } 307 }; 308 } // namespace 309 310 void linalg::populateBlockPackMatmulPatterns( 311 RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { 312 patterns.add<BlockPackMatmul<linalg::GenericOp>, 313 BlockPackMatmul<linalg::MatmulOp>, 314 BlockPackMatmul<linalg::BatchMatmulOp>, 315 BlockPackMatmul<linalg::MatmulTransposeAOp>, 316 BlockPackMatmul<linalg::BatchMatmulTransposeAOp>, 317 BlockPackMatmul<linalg::MatmulTransposeBOp>, 318 BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>( 319 patterns.getContext(), controlFn); 320 } 321