1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites and utilities to lower the 10 // 'vector.contract' operation. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Arith/Utils/Utils.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/IR/SCF.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 26 #include "mlir/IR/BuiltinAttributeInterfaces.h" 27 #include "mlir/IR/BuiltinTypes.h" 28 #include "mlir/IR/ImplicitLocOpBuilder.h" 29 #include "mlir/IR/Location.h" 30 #include "mlir/IR/Matchers.h" 31 #include "mlir/IR/PatternMatch.h" 32 #include "mlir/IR/TypeUtilities.h" 33 #include "mlir/Interfaces/VectorInterfaces.h" 34 35 #define DEBUG_TYPE "vector-contract-lowering" 36 37 using namespace mlir; 38 using namespace mlir::vector; 39 40 //===----------------------------------------------------------------------===// 41 // Helper functions 42 //===----------------------------------------------------------------------===// 43 // Helper to find an index in an affine map. 
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  // Dimension `index` does not appear in the map's results.
  return std::nullopt;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// Recursively peels off leading dimensions of `val` (via extract/insert)
// until dimension `index` is reached, where position `pos` is extracted,
// effectively removing that dimension. `index == -1` means "leave `val`
// unmodified". TODO(review): original TODO was empty — confirm intent.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  // Start from a zero vector and insert each recursively-reshaped slice.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// Mirror of reshapeLoad: recursively unrolls leading dimensions and inserts
// `val` into `result` at position `pos` of dimension `index`.
// `index == -1` means "leave `val` unmodified".
// TODO(review): original TODO was empty — confirm intent.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}

/// Helper to create arithmetic operation associated with a kind of contraction.
/// Multiplies `x` and `y` (integer or float per `isInt`) and combines the
/// product with `acc` using `kind`. Returns std::nullopt when `kind` is not
/// valid for the element type (e.g. a float-only kind on integers). When
/// `mask` is non-null it is threaded through to the reduction so masked-out
/// lanes preserve the previous `acc` values.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  // No accumulator: the bare product is the result.
  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
/// NOTE(review): near-duplicate of getResultIndex above (unsigned vs int64_t);
/// consider consolidating.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
/// operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates a MulIOp if `isInt` is true otherwise creates an arith::MulFOp
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions selects the matmul lowering
/// and the vector.contract op is a row-major matrix multiply.
/// NOTE(review): the original comment said "set to OuterProduct", which looks
/// copy-pasted from the next pattern — this pattern emits vector.matmul;
/// confirm against the matchAndRewrite implementation.
226 class ContractionOpToMatmulOpLowering 227 : public vector::MaskableOpRewritePattern<vector::ContractionOp> { 228 public: 229 using MaskableOpRewritePattern::MaskableOpRewritePattern; 230 231 using FilterConstraintType = 232 std::function<LogicalResult(vector::ContractionOp op)>; 233 234 static LogicalResult defaultFilter(vector::ContractionOp op) { 235 return success(); 236 } 237 238 ContractionOpToMatmulOpLowering( 239 vector::VectorTransformsOptions vectorTransformOptions, 240 MLIRContext *context, PatternBenefit benefit = 1, 241 FilterConstraintType constraint = defaultFilter) 242 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), 243 vectorTransformOptions(vectorTransformOptions), 244 filter(std::move(constraint)) {} 245 246 FailureOr<Value> 247 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, 248 PatternRewriter &rewriter) const override; 249 250 private: 251 /// Options to control the vector patterns. 252 vector::VectorTransformsOptions vectorTransformOptions; 253 FilterConstraintType filter; 254 }; 255 256 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 257 /// semantics to a reduction_size-unrolled sequence: 258 /// ``` 259 /// %at = vector.transpose %a, [1, 0] 260 /// %bRow0 = vector.extract %b[0] 261 /// %atRow0 = vector.extract %at[0] 262 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c 263 /// ... 264 /// %bRowK = vector.extract %b[K] 265 /// %atRowK = vector.extract %at[K] 266 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 267 /// ``` 268 /// 269 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and 270 /// the vector.contract op is a row-major matrix multiply. 
271 class ContractionOpToOuterProductOpLowering 272 : public MaskableOpRewritePattern<vector::ContractionOp> { 273 public: 274 using MaskableOpRewritePattern::MaskableOpRewritePattern; 275 276 using FilterConstraintType = 277 std::function<LogicalResult(vector::ContractionOp op)>; 278 279 static LogicalResult defaultFilter(vector::ContractionOp op) { 280 return success(); 281 } 282 283 ContractionOpToOuterProductOpLowering( 284 vector::VectorTransformsOptions vectorTransformOptions, 285 MLIRContext *context, PatternBenefit benefit = 1, 286 FilterConstraintType constraint = defaultFilter) 287 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), 288 vectorTransformOptions(vectorTransformOptions), 289 filter(std::move(constraint)) {} 290 291 FailureOr<Value> 292 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, 293 PatternRewriter &rewriter) const override; 294 295 private: 296 /// Options to control the vector patterns. 297 vector::VectorTransformsOptions vectorTransformOptions; 298 FilterConstraintType filter; 299 }; 300 301 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 302 /// semantics to an output-size-unrolled sequence: 303 /// ``` 304 /// %out = arith.constant ... : vector<MxNxelt_type> 305 /// %bt = vector.transpose %b, [1, 0] 306 /// %aRow0 = vector.extract %a[0] 307 /// %btRow0 = vector.extract %bt[0] 308 /// %c00 = vector.reduce %atRow0, %bRow0 309 /// %out00 = vector.insert %c00, %out[0, 0] 310 /// ... 311 /// %aRowLast = vector.extract %at[M-1] 312 /// %btRowLast = vector.extract %b[N-1] 313 /// %cLastLast = vector.reduce %atRowLast, %bRowLast 314 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] 315 /// ``` 316 /// 317 /// This only kicks in when VectorTransformsOptions is set to Dot and 318 /// the vector.contract op is a row-major matmul or matvec. 
319 class ContractionOpToDotLowering 320 : public MaskableOpRewritePattern<vector::ContractionOp> { 321 public: 322 using MaskableOpRewritePattern::MaskableOpRewritePattern; 323 324 using FilterConstraintType = 325 std::function<LogicalResult(vector::ContractionOp op)>; 326 327 static LogicalResult defaultFilter(vector::ContractionOp op) { 328 return success(); 329 } 330 331 ContractionOpToDotLowering( 332 vector::VectorTransformsOptions vectorTransformOptions, 333 MLIRContext *context, PatternBenefit benefit = 1, 334 const FilterConstraintType &constraint = defaultFilter) 335 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), 336 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} 337 338 FailureOr<Value> 339 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, 340 PatternRewriter &rewriter) const override; 341 342 private: 343 /// Options to control the vector patterns. 344 vector::VectorTransformsOptions vectorTransformOptions; 345 FilterConstraintType filter; 346 }; 347 348 /// Progressive lowering of ContractionOp. 349 /// 350 /// One: 351 /// %x = vector.contract with at least one free/batch dimension 352 /// is replaced by: 353 /// %a = vector.contract with one less free/batch dimension 354 /// %b = vector.contract with one less free/batch dimension 355 /// .. 356 /// %x = combine %a %b .. 357 /// until a pure contraction is reached (no free/batch dimensions), 358 /// which is replaced by a dot-product. 359 /// 360 /// This only kicks in when either VectorTransformsOptions is set 361 /// to Dot or when other contraction patterns fail. 
362 class ContractionOpLowering 363 : public MaskableOpRewritePattern<vector::ContractionOp> { 364 public: 365 using MaskableOpRewritePattern::MaskableOpRewritePattern; 366 using FilterConstraintType = 367 std::function<LogicalResult(vector::ContractionOp op)>; 368 369 static LogicalResult defaultFilter(vector::ContractionOp op) { 370 return success(); 371 } 372 373 ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, 374 MLIRContext *context, PatternBenefit benefit = 1, 375 FilterConstraintType constraint = defaultFilter) 376 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), 377 vectorTransformOptions(vectorTransformOptions), 378 filter(std::move(constraint)) {} 379 380 FailureOr<Value> 381 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, 382 PatternRewriter &rewriter) const override; 383 384 private: 385 /// Options to control the vector patterns. 386 vector::VectorTransformsOptions vectorTransformOptions; 387 FilterConstraintType filter; 388 // Lower one parallel dimension. 389 FailureOr<Value> lowerParallel(PatternRewriter &rewriter, 390 vector::ContractionOp op, int64_t lhsIndex, 391 int64_t rhsIndex, Value mask) const; 392 // Lower one reduction dimension. 393 FailureOr<Value> lowerReduction(PatternRewriter &rewriter, 394 vector::ContractionOp op, Value mask) const; 395 }; 396 397 /// Generate a vector implementation for matmat, matvec and tmatvec. 398 /// This unrolls outer-products along the reduction dimension. 
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Capture the mask when the contraction is wrapped in a vector.mask.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` by `perm` (default 2-D swap); passes a null value through
  /// unchanged so callers can apply it unconditionally to optional masks.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Extend `v` (scalar or vector) to `dstElementType` when needed, using
  /// extf for float destinations and extsi otherwise.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Emit `reductionSize` vector.outerproduct ops accumulating into `res`.
  /// Operands must already have the reduction dimension outermost; each
  /// iteration peels one slice of lhs/rhs (and of the mask, when present).
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking: bail out if the op is masked but the
    // caller did not supply a transposed mask for this layout.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Dispatches on the indexing-map layout, transposing operands (and the
  /// mask, to {2, 0, 1}) as needed so the reduction dim is outermost.
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when then pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  // Try the three supported shapes in decreasing specificity; the generator
  // only creates IR when a layout matches.
  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    return matmatRes;
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    return matvecRes;
  }

  FailureOr<Value> tmatvecRes = e.tmatvec();
  return tmatvecRes;
}

FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  // Each output element becomes an elementwise mul + reduction over the
  // (now innermost) reduction dimension.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  // Fold the optional accumulator in at the end.
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  return res;
}

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        // NOTE(review): `constraint` is accepted but ignored — `filter` is
        // initialized with `defaultFilter`. Likely should be
        // `filter(constraint)`; confirm and fix.
        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
803 if (maskOp) 804 return failure(); 805 806 if (failed(filter(contractOp))) 807 return failure(); 808 809 if (vectorTransformOptions.vectorContractLowering != 810 vector::VectorContractLowering::ParallelArith) 811 return failure(); 812 813 ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape(); 814 ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape(); 815 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0]; 816 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1]; 817 SmallVector<int64_t> lhsReductionDims = 818 getReductionIndex(lhsMap, contractOp.getIteratorTypes()); 819 SmallVector<int64_t> rhsReductionDims = 820 getReductionIndex(rhsMap, contractOp.getIteratorTypes()); 821 // All the reduction dimensions must be a size 1. 822 for (int64_t dim : lhsReductionDims) { 823 if (lhsShape[dim] != 1) 824 return failure(); 825 } 826 for (int64_t dim : rhsReductionDims) { 827 if (rhsShape[dim] != 1) 828 return failure(); 829 } 830 AffineMap accMap = contractOp.getIndexingMapsArray()[2]; 831 unsigned numParallelDims = accMap.getNumResults(); 832 unsigned numLhsDimToBroadcast = 833 numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size()); 834 unsigned numRhsDimToBroadcast = 835 numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size()); 836 SmallVector<int64_t> lhsDims; 837 SmallVector<int64_t> lhsTranspose; 838 SmallVector<int64_t> rhsDims; 839 SmallVector<int64_t> rhsTranspose; 840 for (int64_t dim : lhsReductionDims) 841 lhsTranspose.push_back(numLhsDimToBroadcast + dim); 842 for (int64_t dim : rhsReductionDims) 843 rhsTranspose.push_back(numRhsDimToBroadcast + dim); 844 // Loop through the parallel dimensions to calculate the dimensions to 845 // broadcast and to permute in order to extract only parallel dimensions. 
846 for (unsigned i = 0; i < numParallelDims; i++) { 847 std::optional<unsigned> lhsDim = 848 getDimPosition(lhsMap, accMap.getDimPosition(i)); 849 if (lhsDim) { 850 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim); 851 } else { 852 // If the parallel dimension doesn't exist we will have to broadcast it. 853 lhsDims.push_back( 854 cast<VectorType>(contractOp.getResultType()).getDimSize(i)); 855 lhsTranspose.push_back(lhsDims.size() - 1); 856 } 857 std::optional<unsigned> rhsDim = 858 getDimPosition(rhsMap, accMap.getDimPosition(i)); 859 if (rhsDim) { 860 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim); 861 } else { 862 // If the parallel dimension doesn't exist we will have to broadcast it. 863 rhsDims.push_back( 864 cast<VectorType>(contractOp.getResultType()).getDimSize(i)); 865 rhsTranspose.push_back(rhsDims.size() - 1); 866 } 867 } 868 Value newLhs = contractOp.getLhs(); 869 Value newRhs = contractOp.getRhs(); 870 Location loc = contractOp.getLoc(); 871 if (!lhsDims.empty()) { 872 lhsDims.append(lhsShape.begin(), lhsShape.end()); 873 auto expandedType = 874 VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); 875 newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs); 876 } 877 if (!rhsDims.empty()) { 878 rhsDims.append(rhsShape.begin(), rhsShape.end()); 879 auto expandedType = 880 VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); 881 newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs); 882 } 883 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); 884 newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose); 885 newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose); 886 SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0); 887 SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0); 888 newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets); 889 newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, 
rhsOffsets); 890 std::optional<Value> result = 891 createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), 892 contractOp.getKind(), rewriter, isInt); 893 if (result) 894 return *result; 895 896 return failure(); 897 } 898 899 private: 900 /// Options to control the vector patterns. 901 vector::VectorTransformsOptions vectorTransformOptions; 902 FilterConstraintType filter; 903 }; 904 905 /// Progressive lowering of ContractionOp. 906 /// One: 907 /// %x = vector.contract with at least one free/batch dimension 908 /// is replaced by: 909 /// %a = vector.contract with one less free/batch dimension 910 /// %b = vector.contract with one less free/batch dimension 911 /// .. 912 /// %x = combine %a %b .. 913 /// until a pure contraction is reached (no free/batch dimensions), 914 /// which is replaced by a dot-product. 915 /// 916 /// This only kicks in when either VectorTransformsOptions is set 917 /// to DOT or when other contraction patterns fail. 918 // 919 // TODO: break down into transpose/reshape/cast ops 920 // when they become available to avoid code dup 921 // TODO: investigate lowering order impact on performance 922 FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp( 923 vector::ContractionOp op, MaskingOpInterface maskOp, 924 PatternRewriter &rewriter) const { 925 if (failed(filter(op))) 926 return failure(); 927 928 // TODO: support mixed mode contract lowering. 929 if (op.getLhsType().getElementType() != 930 getElementTypeOrSelf(op.getAccType()) || 931 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 932 return failure(); 933 934 // TODO: the code below assumes the default contraction, make sure it supports 935 // other kinds before enabling this lowering. 936 if (op.getKind() != vector::CombiningKind::ADD) { 937 return rewriter.notifyMatchFailure( 938 op, "contractions other than 'add' not supported"); 939 } 940 941 // TODO: implement benefits, cost models. 
942 MLIRContext *ctx = op.getContext(); 943 944 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 945 FailureOr<Value> newVal1 = 946 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter); 947 if (!failed(newVal1)) 948 return newVal1; 949 950 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 951 FailureOr<Value> newVal2 = 952 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter); 953 if (!failed(newVal2)) 954 return newVal2; 955 956 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 957 FailureOr<Value> newVal3 = 958 pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter); 959 if (!failed(newVal3)) 960 return newVal3; 961 962 ContractOpToElementwise pat4(vectorTransformOptions, ctx); 963 FailureOr<Value> newVal4 = 964 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter); 965 if (!failed(newVal4)) 966 return newVal4; 967 968 // Vector mask setup. 969 970 Value mask; 971 if (maskOp) 972 mask = maskOp.getMask(); 973 // Find first batch dimension in LHS/RHS, and lower when found. 974 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 975 if (!batchDimMap.empty()) { 976 int64_t lhsIndex = batchDimMap[0].first; 977 int64_t rhsIndex = batchDimMap[0].second; 978 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask); 979 if (failed(newOp)) 980 return failure(); 981 return newOp; 982 } 983 984 // Collect contracting dimensions. 985 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 986 op.getContractingDimMap(); 987 DenseSet<int64_t> lhsContractingDimSet; 988 DenseSet<int64_t> rhsContractingDimSet; 989 for (auto &dimPair : contractingDimMap) { 990 lhsContractingDimSet.insert(dimPair.first); 991 rhsContractingDimSet.insert(dimPair.second); 992 } 993 994 // Find first free dimension in LHS, and lower when found. 
995 VectorType lhsType = op.getLhsType(); 996 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 997 if (lhsContractingDimSet.count(lhsIndex) == 0) { 998 auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask); 999 if (failed(newOp)) 1000 return failure(); 1001 return newOp; 1002 } 1003 } 1004 1005 // Find first free dimension in RHS, and lower when found. 1006 VectorType rhsType = op.getRhsType(); 1007 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1008 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1009 auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask); 1010 if (failed(newOp)) 1011 return failure(); 1012 return newOp; 1013 } 1014 } 1015 1016 // Lower the first remaining reduction dimension. 1017 if (!contractingDimMap.empty()) { 1018 auto newOp = lowerReduction(rewriter, op, mask); 1019 if (failed(newOp)) 1020 return failure(); 1021 return newOp; 1022 } 1023 1024 return failure(); 1025 } 1026 1027 // Lower one parallel dimension. 1028 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions. 1029 // TODO: consider reusing existing contract unrolling 1030 FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, 1031 vector::ContractionOp op, 1032 int64_t lhsIndex, 1033 int64_t rhsIndex, 1034 Value mask) const { 1035 VectorType lhsType = op.getLhsType(); 1036 VectorType rhsType = op.getRhsType(); 1037 VectorType resType = cast<VectorType>(op.getResultType()); 1038 // Find the iterator type index and result index. 
1039 SmallVector<AffineMap> iMap = op.getIndexingMapsArray(); 1040 int64_t iterIndex = -1; 1041 int64_t dimSize = -1; 1042 if (lhsIndex >= 0) { 1043 iterIndex = iMap[0].getDimPosition(lhsIndex); 1044 if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex)) 1045 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1046 diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex 1047 << " to map to the same dimension"; 1048 }); 1049 if (lhsType.getScalableDims()[lhsIndex]) 1050 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1051 diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex 1052 << ") is not supported yet"; 1053 }); 1054 dimSize = lhsType.getDimSize(lhsIndex); 1055 } else if (rhsIndex >= 0) { 1056 iterIndex = iMap[1].getDimPosition(rhsIndex); 1057 if (rhsType.getScalableDims()[rhsIndex]) 1058 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1059 diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex 1060 << ") is not supported yet"; 1061 }); 1062 dimSize = rhsType.getDimSize(rhsIndex); 1063 } 1064 if (iterIndex < 0) 1065 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1066 diag << "expected either lhsIndex=" << lhsIndex 1067 << " or rhsIndex=" << rhsIndex << " to be nonnegative"; 1068 }); 1069 // value_or(-1) means that we tolerate a dimension not appearing 1070 // in the result map. That can't happen for actual parallel iterators, but 1071 // the caller ContractionOpLowering::matchAndRewrite is currently calling 1072 // lowerParallel also for the case of unit-size reduction dims appearing only 1073 // on one of LHS or RHS, not both. At the moment, such cases are created by 1074 // CastAwayContractionLeadingOneDim, so we need to either support that or 1075 // modify that pattern. 
1076 int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1); 1077 if (resIndex == -1 && dimSize != 1) 1078 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1079 diag << "expected the dimension for iterIndex=" << iterIndex 1080 << " to either appear in the result map, or to be a unit dimension"; 1081 }); 1082 1083 // Construct new iterator types and affine map array attribute. 1084 std::array<AffineMap, 3> lowIndexingMaps = { 1085 adjustMap(iMap[0], iterIndex, rewriter), 1086 adjustMap(iMap[1], iterIndex, rewriter), 1087 adjustMap(iMap[2], iterIndex, rewriter)}; 1088 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1089 auto lowIter = 1090 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1091 // Unroll into a series of lower dimensional vector.contract ops. 1092 Location loc = op.getLoc(); 1093 Value result = rewriter.create<arith::ConstantOp>( 1094 loc, resType, rewriter.getZeroAttr(resType)); 1095 1096 for (int64_t d = 0; d < dimSize; ++d) { 1097 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1098 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1099 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); 1100 1101 Value lowMask; 1102 if (mask) 1103 lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()), 1104 iterIndex, d, rewriter); 1105 1106 Operation *lowContract = rewriter.create<vector::ContractionOp>( 1107 loc, lhs, rhs, acc, lowAffine, lowIter); 1108 lowContract = maskOperation(rewriter, lowContract, lowMask); 1109 result = reshapeStore(loc, lowContract->getResult(0), result, resType, 1110 resIndex, d, rewriter); 1111 } 1112 return result; 1113 } 1114 1115 // Lower one reduction dimension. 
1116 FailureOr<Value> ContractionOpLowering::lowerReduction( 1117 PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const { 1118 auto loc = op.getLoc(); 1119 VectorType lhsType = op.getLhsType(); 1120 VectorType rhsType = op.getRhsType(); 1121 Type resType = op.getResultType(); 1122 if (isa<VectorType>(resType)) 1123 return rewriter.notifyMatchFailure(op, 1124 "did not expect a VectorType result"); 1125 bool isInt = isa<IntegerType>(resType); 1126 // Use iterator index 0. 1127 int64_t iterIndex = 0; 1128 SmallVector<AffineMap> iMap = op.getIndexingMapsArray(); 1129 std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1130 std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1131 if (!lookupLhs.has_value()) 1132 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1133 diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension"; 1134 }); 1135 if (!lookupRhs.has_value()) 1136 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1137 diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension"; 1138 }); 1139 int64_t lhsIndex = *lookupLhs; 1140 int64_t rhsIndex = *lookupRhs; 1141 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1142 if (dimSize != rhsType.getDimSize(rhsIndex)) 1143 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1144 diag << "expect LHS dimension " << lhsIndex 1145 << " to have the same size as RHS dimension " << rhsIndex; 1146 }); 1147 // Base case. 1148 if (lhsType.getRank() == 1) { 1149 if (rhsType.getRank() != 1) 1150 return rewriter.notifyMatchFailure( 1151 op, "When LHS has rank 1, expected also RHS to have rank 1"); 1152 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); 1153 auto kind = vector::CombiningKind::ADD; 1154 1155 Value acc = op.getAcc(); 1156 Operation *reductionOp = 1157 acc ? 
rewriter.create<vector::ReductionOp>(loc, kind, m, acc) 1158 : rewriter.create<vector::ReductionOp>(loc, kind, m); 1159 return maskOperation(rewriter, reductionOp, mask)->getResult(0); 1160 } 1161 // Construct new iterator types and affine map array attribute. 1162 std::array<AffineMap, 3> lowIndexingMaps = { 1163 adjustMap(iMap[0], iterIndex, rewriter), 1164 adjustMap(iMap[1], iterIndex, rewriter), 1165 adjustMap(iMap[2], iterIndex, rewriter)}; 1166 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1167 auto lowIter = 1168 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1169 // Unroll into a series of lower dimensional vector.contract ops. 1170 // By feeding the initial accumulator into the first contraction, 1171 // and the result of each contraction into the next, eventually 1172 // the sum of all reductions is computed. 1173 Value result = op.getAcc(); 1174 for (int64_t d = 0; d < dimSize; ++d) { 1175 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1176 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1177 Value newMask; 1178 if (mask) 1179 newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()), 1180 iterIndex, d, rewriter); 1181 1182 Operation *newContract = rewriter.create<vector::ContractionOp>( 1183 loc, lhs, rhs, result, lowAffine, lowIter); 1184 result = maskOperation(rewriter, newContract, newMask)->getResult(0); 1185 } 1186 return result; 1187 } 1188 1189 /// Progressive lowering of OuterProductOp. 1190 /// One: 1191 /// %x = vector.outerproduct %lhs, %rhs, %acc 1192 /// is replaced by: 1193 /// %z = zero-result 1194 /// %0 = vector.extract %lhs[0] 1195 /// %1 = vector.broadcast %0 1196 /// %2 = vector.extract %acc[0] 1197 /// %3 = vector.fma %1, %rhs, %2 1198 /// %4 = vector.insert %3, %z[0] 1199 /// .. 
1200 /// %x = vector.insert %.., %..[N-1] 1201 /// 1202 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { 1203 public: 1204 using OpRewritePattern::OpRewritePattern; 1205 1206 LogicalResult matchAndRewrite(vector::OuterProductOp op, 1207 PatternRewriter &rewriter) const override { 1208 VectorType resType = op.getResultVectorType(); 1209 if ((resType.getShape().size() >= 2) && resType.allDimsScalable()) 1210 return failure(); 1211 1212 auto loc = op.getLoc(); 1213 1214 VectorType lhsType = op.getOperandVectorTypeLHS(); 1215 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS()); 1216 Type eltType = resType.getElementType(); 1217 bool isInt = isa<IntegerType, IndexType>(eltType); 1218 Value acc = op.getAcc(); 1219 vector::CombiningKind kind = op.getKind(); 1220 1221 // Vector mask setup. 1222 OpBuilder::InsertionGuard guard(rewriter); 1223 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation()); 1224 Operation *rootOp; 1225 Value mask; 1226 if (maskableOp.isMasked()) { 1227 rewriter.setInsertionPoint(maskableOp.getMaskingOp()); 1228 rootOp = maskableOp.getMaskingOp(); 1229 mask = maskableOp.getMaskingOp().getMask(); 1230 } else { 1231 rootOp = op; 1232 } 1233 1234 if (!rhsType) { 1235 // Special case: AXPY operation. 
1236 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs()); 1237 std::optional<Value> mult = createContractArithOp( 1238 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); 1239 if (!mult.has_value()) 1240 return failure(); 1241 rewriter.replaceOp(rootOp, *mult); 1242 return success(); 1243 } 1244 1245 Value result = rewriter.create<arith::ConstantOp>( 1246 loc, resType, rewriter.getZeroAttr(resType)); 1247 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { 1248 Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d); 1249 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x); 1250 Value r = nullptr; 1251 if (acc) 1252 r = rewriter.create<vector::ExtractOp>(loc, acc, d); 1253 Value extrMask; 1254 if (mask) 1255 extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d); 1256 1257 std::optional<Value> m = createContractArithOp( 1258 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); 1259 if (!m.has_value()) 1260 return failure(); 1261 result = rewriter.create<vector::InsertOp>(loc, *m, result, d); 1262 } 1263 1264 rewriter.replaceOp(rootOp, result); 1265 return success(); 1266 } 1267 }; 1268 1269 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 1270 /// semantics to: 1271 /// ``` 1272 /// %mta = maybe_transpose 1273 /// %mtb = maybe_transpose 1274 /// %flattened_a = vector.shape_cast %mta 1275 /// %flattened_b = vector.shape_cast %mtb 1276 /// %flattened_d = vector.matmul %flattened_a, %flattened_b 1277 /// %mtd = vector.shape_cast %flattened_d 1278 /// %d = maybe_untranspose %mtd 1279 /// %e = add %c, %d 1280 /// ``` 1281 /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 1282 // 1283 /// This only kicks in when VectorTransformsOptions is set to `Matmul`. 1284 /// vector.transpose operations are inserted if the vector.contract op is not a 1285 /// row-major matrix multiply. 1286 /// 1287 /// Scalable vectors are not supported. 
1288 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( 1289 vector::ContractionOp op, MaskingOpInterface maskOp, 1290 PatternRewriter &rew) const { 1291 // TODO: Support vector.mask. 1292 if (maskOp) 1293 return failure(); 1294 1295 if (vectorTransformOptions.vectorContractLowering != 1296 vector::VectorContractLowering::Matmul) 1297 return failure(); 1298 if (failed(filter(op))) 1299 return failure(); 1300 1301 auto iteratorTypes = op.getIteratorTypes().getValue(); 1302 if (!isParallelIterator(iteratorTypes[0]) || 1303 !isParallelIterator(iteratorTypes[1]) || 1304 !isReductionIterator(iteratorTypes[2])) 1305 return failure(); 1306 1307 Type opResType = op.getType(); 1308 VectorType vecType = dyn_cast<VectorType>(opResType); 1309 if (vecType && vecType.isScalable()) { 1310 // Note - this is sufficient to reject all cases with scalable vectors. 1311 return failure(); 1312 } 1313 1314 Type elementType = op.getLhsType().getElementType(); 1315 if (!elementType.isIntOrFloat()) 1316 return failure(); 1317 1318 Type dstElementType = vecType ? vecType.getElementType() : opResType; 1319 if (elementType != dstElementType) 1320 return failure(); 1321 1322 // Perform lhs + rhs transpositions to conform to matmul row-major semantics. 1323 // Bail out if the contraction cannot be put in this form. 1324 MLIRContext *ctx = op.getContext(); 1325 Location loc = op.getLoc(); 1326 AffineExpr m, n, k; 1327 bindDims(rew.getContext(), m, n, k); 1328 // LHS must be A(m, k) or A(k, m). 1329 Value lhs = op.getLhs(); 1330 auto lhsMap = op.getIndexingMapsArray()[0]; 1331 if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) 1332 lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0}); 1333 else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) 1334 return failure(); 1335 1336 // RHS must be B(k, n) or B(n, k). 
1337 Value rhs = op.getRhs(); 1338 auto rhsMap = op.getIndexingMapsArray()[1]; 1339 if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) 1340 rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0}); 1341 else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) 1342 return failure(); 1343 1344 // At this point lhs and rhs are in row-major. 1345 VectorType lhsType = cast<VectorType>(lhs.getType()); 1346 VectorType rhsType = cast<VectorType>(rhs.getType()); 1347 int64_t lhsRows = lhsType.getDimSize(0); 1348 int64_t lhsColumns = lhsType.getDimSize(1); 1349 int64_t rhsColumns = rhsType.getDimSize(1); 1350 1351 Type flattenedLHSType = 1352 VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); 1353 lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs); 1354 1355 Type flattenedRHSType = 1356 VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); 1357 rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs); 1358 1359 Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns, 1360 rhsColumns); 1361 mul = rew.create<vector::ShapeCastOp>( 1362 loc, 1363 VectorType::get({lhsRows, rhsColumns}, 1364 getElementTypeOrSelf(op.getAcc().getType())), 1365 mul); 1366 1367 // ACC must be C(m, n) or C(n, m). 1368 auto accMap = op.getIndexingMapsArray()[2]; 1369 if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) 1370 mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0}); 1371 else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) 1372 llvm_unreachable("invalid contraction semantics"); 1373 1374 Value res = 1375 isa<IntegerType>(elementType) 1376 ? 
static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul)) 1377 : static_cast<Value>( 1378 rew.create<arith::AddFOp>(loc, op.getAcc(), mul)); 1379 1380 return res; 1381 } 1382 } // namespace 1383 1384 void mlir::vector::populateVectorContractLoweringPatterns( 1385 RewritePatternSet &patterns, VectorTransformsOptions options, 1386 PatternBenefit benefit, bool disableOuterProductLowering) { 1387 if (!disableOuterProductLowering) 1388 patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit); 1389 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 1390 ContractionOpToOuterProductOpLowering>( 1391 options, patterns.getContext(), benefit); 1392 } 1393 1394 void mlir::vector::populateVectorOuterProductLoweringPatterns( 1395 RewritePatternSet &patterns, PatternBenefit benefit) { 1396 patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit); 1397 } 1398