//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.contract' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-contract-lowering"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

// Helper to find the position at which dimension `index` appears among the
// results of `map`. Returns std::nullopt when the dimension is not a result.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
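
// For example (illustrative): `adjustMap` applied with index 1 to the map
//   (d0, d1, d2) -> (d2, d1, d0)
// drops the `d1` result and renumbers the remaining dimensions, yielding
//   (d0, d1) -> (d1, d0).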

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
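
// For example (illustrative): given `val : vector<2x3x4xf32>`, `index = 1` and
// `pos = 2`, `reshapeLoad` unrolls the leading dimension, extracts element 2
// of the middle dimension from each slice, and reassembles a vector<2x4xf32>.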

/// Helper to create an arithmetic operation associated with a kind of
/// contraction.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}
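
// For a masked floating-point ADD contraction with a vector accumulator,
// `createContractArithOp` above thus emits a fused multiply-add plus a
// passthrough select instead of a separate mul/add pair; roughly
// (illustrative, types assumed):
//   %fma = vector.fma %x, %y, %acc : vector<8xf32>
//   %res = arith.select %mask, %fma, %acc : vector<8xi1>, vector<8xf32>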

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul` and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = arith.constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }
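
  // For example (illustrative): for vector<[4]x8xf32>, `getReductionSize`
  // along dim 0 returns std::nullopt (scalable dims cannot be unrolled),
  // while along dim 1 it returns 8.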

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }
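
  // For example (illustrative): for a classical {m, k} x {k, n} -> {m, n}
  // matmat on vector<2x4xf32> x vector<4x3xf32>, `matmat` above transposes
  // the lhs to vector<4x2xf32> and unrolls the reduction dimension (size 4)
  // into four `vector.outerproduct` ops chained through the accumulator.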

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  // Vector mask setup.
  OpBuilder::InsertionGuard guard(rewriter);
  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
  Operation *rootOp;
  if (maskableOp.isMasked()) {
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
  } else {
    rootOp = op;
  }

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(rootOp, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(rootOp, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(rootOp, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap lhs and rhs, then transpose the new lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor).
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, so we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}
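
// For example (illustrative, assuming a 2x2x2 f32 matmul in the canonical
// {m, k} x {n, k} -> {m, n} form): the Dot lowering above computes each
// output element as
//   %a0 = vector.extract %lhs[0] : vector<2xf32> from vector<2x2xf32>
//   %b0 = vector.extract %rhs[0] : vector<2xf32> from vector<2x2xf32>
//   %m00 = arith.mulf %a0, %b0 : vector<2xf32>
//   %r00 = vector.reduction <add>, %m00 : vector<2xf32> into f32
// and inserts it into a zero-initialized result vector.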

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
    if (maskableOp.isMasked())
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must be of size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      std::optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        lhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      std::optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        rhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
    std::optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    rewriter.replaceOp(contractOp, {*result});
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};
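
// For example (illustrative): ContractOpToElementwise rewrites a contraction
// whose only reduction dimension has size 1, say
//   lhs : vector<4x1xf32> {m, k}, rhs : vector<1x4xf32> {k, n} -> vector<4x4xf32>,
// into broadcasts and transposes that align both operands to vector<4x4xf32>,
// followed by a single elementwise multiply-accumulate (a vector.fma for a
// floating-point ADD contraction).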

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();
  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
    return success();

  // Vector mask setup.
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *rootOp = op;
  Value mask;
  if (op.isMasked()) {
    rewriter.setInsertionPoint(op.getMaskingOp());
    rootOp = op.getMaskingOp();
    mask = op.getMaskingOp().getMask();
  }

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  return failure();
}

// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
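
// For example (illustrative): unrolling the batch dimension of
//   vector.contract : vector<2x3x4xf32>, vector<2x4x5xf32> into vector<2x3x5xf32>
// produces two rank-reduced contractions on vector<3x4xf32> x vector<4x5xf32>
// operands, whose vector<3x5xf32> results are reassembled via reshapeStore.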

// Lower one reduction dimension.
FailureOr<Value> ContractionOpLowering::lowerReduction(
    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (isa<VectorType>(resType))
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = isa<IntegerType>(resType);
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = *lookupLhs;
  int64_t rhsIndex = *lookupRhs;
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "When LHS has rank 1, expected also RHS to have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;

    Value acc = op.getAcc();
    Operation *reductionOp =
        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
            : rewriter.create<vector::ReductionOp>(loc, kind, m);
    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    Value newMask;
    if (mask)
      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *newContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, result, lowAffine, lowIter);
    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
  }
  return result;
}
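
// For example (illustrative): contracting two vector<3x4xf32> operands over
// the outer dimension (size 3) emits three contractions over vector<4xf32>
// slices, chained through the accumulator; each of those hits the rank-1 base
// case, which lowers to `arith.mulf` + `vector.reduction <add>`.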

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
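
// For example (illustrative, assuming f32 and an ADD kind):
//   %x = vector.outerproduct %lhs, %rhs, %acc : vector<2xf32>, vector<3xf32>
// expands row by row into
//   %l0 = vector.extract %lhs[0] : f32 from vector<2xf32>
//   %b0 = vector.broadcast %l0 : f32 to vector<3xf32>
//   %a0 = vector.extract %acc[0] : vector<3xf32> from vector<2x3xf32>
//   %f0 = vector.fma %b0, %rhs, %a0 : vector<3xf32>
// with each row inserted into a zero-initialized vector<2x3xf32> result.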

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = op.getType();
  if (auto vecType = dyn_cast<VectorType>(dstElementType))
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
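
// For example (illustrative, assuming f32 and a row-major
// vector<2x4xf32> x vector<4x3xf32> -> vector<2x3xf32> contraction), the
// matmul lowering above produces roughly:
//   %0 = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
//   %1 = vector.shape_cast %b : vector<4x3xf32> to vector<12xf32>
//   %2 = vector.matmul %0, %1 {lhs_rows = 2, lhs_columns = 4, rhs_columns = 3}
//   %3 = vector.shape_cast %2 : vector<6xf32> to vector<2x3xf32>
//   %d = arith.addf %c, %3 : vector<2x3xf32>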
} // namespace

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit, bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      options, patterns.getContext(), benefit);
}

void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}
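
// Typical driver-side usage is a sketch along these lines (assuming the
// standard greedy rewrite driver; not part of this file):
//   RewritePatternSet patterns(ctx);
//   vector::VectorTransformsOptions options;
//   options.vectorContractLowering =
//       vector::VectorContractLowering::OuterProduct;
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));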