//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <numeric>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//

bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if the op has no loops.
    return linalgOp.getNumLoops() == 0;
  }
  return inversePermutation(concatAffineMaps(
             indexingMaps, linalgOp.getContext())) != AffineMap();
}

//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

bool linalg::isaCopyOpInterface(LinalgOp op) {
  // Check all loops are parallel and linalgOp is single input and output.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
    return false;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
    return false;
  }
  // Region.
  return llvm::hasSingleElement(op.getBlock()->getOperations());
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
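/// As an illustrative (hypothetical) example, a generic of the following form
/// would be recognized as a fill and the returned value would be %cst:
///   %0 = linalg.generic
///          {indexing_maps = [affine_map<(d0) -> ()>,
///                            affine_map<(d0) -> (d0)>],
///           iterator_types = ["parallel"]}
///          ins(%cst : f32) outs(%t : tensor<8xf32>) {
///        ^bb0(%in: f32, %out: f32):
///          linalg.yield %in : f32
///        } -> tensor<8xf32>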
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  // Input should be referenced and init should not.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
      op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
    return std::nullopt;

  OpOperand *value = op.getDpsInputOperand(0);
  if (!op.isScalar(value))
    return std::nullopt;
  return value->get();
}

//===----------------------------------------------------------------------===//
// BroadcastOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
linalg::isaBroadcastOpInterface(GenericOp op) {
  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto srcTy = op.getDpsInputOperand(0)->get().getType();
  auto dstTy = op.getDpsInitOperand(0)->get().getType();
  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
      !isa<MemRefType, RankedTensorType>(dstTy))
    return std::nullopt;

  // Check output is identity map. Broadcast could additionally be
  // employing permutation of indices and that would be expressible
  // in linalg.generic but is not expressible for named broadcast op.
  auto dstMap = op.getIndexingMapsArray()[1];
  if (!dstMap.isIdentity())
    return std::nullopt;

  SmallVector<int64_t> position;
  auto srcMap = op.getIndexingMapsArray()[0];

  if (srcMap.getResults().size() >= dstMap.getResults().size())
    return std::nullopt;

  // Check input map is monotonically increasing DimIds.
  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
    if (!expr)
      return std::nullopt;
    int64_t pos = expr.getPosition();
    if (i > 0 && pos <= position[i - 1])
      return std::nullopt;
    position.push_back(expr.getPosition());
  }

  SmallVector<int64_t> broadcastedDims;
  auto numDims = srcMap.getNumDims();
  // This is quadratic but number of items is generally small.
  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
    if (!llvm::is_contained(position, dim))
      broadcastedDims.push_back(dim);
  }
  return broadcastedDims;
}

//===----------------------------------------------------------------------===//
// TransposeOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
linalg::isaTransposeOpInterface(GenericOp op) {
  // To specialize as a transpose op, the genericOp must be
  // all parallel loops, single input, single output, and its body
  // should be just a yield op, yielding input as output as is (no compute).
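  // As an illustrative example, a generic with indexing maps
  //   [(d0, d1) -> (d1, d0), (d0, d1) -> (d0, d1)]
  // and a body that just yields the input would be matched here, returning
  // the permutation [1, 0].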
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2)
    return std::nullopt;

  auto mapOfInput = mapRange.front();
  auto mapOfResult = mapRange.back();

  // linalg.transpose permutes the dimensions of input using this
  // rule: dim(result, i) = dim(input, permutation[i])
  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
    return std::nullopt;

  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
    permutation[expr.getPosition()] = i;
  }
  return permutation;
}

//===----------------------------------------------------------------------===//
// Elementwise Single Unary/Binary-OpInterface implementation
//===----------------------------------------------------------------------===//
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
                                                      unsigned arity) {
  // Check all loops are parallel.
  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
    return false;

  // Check there are arity-inputs, 1-output and all are identity-maps.
  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
      !llvm::all_of(op.getIndexingMapsArray(),
                    [](AffineMap map) { return map.isIdentity(); }))
    return false;

  // Init should not be referenced for elementwise operations.
  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
    return false;

  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
  // resulting from producer-consumer fusion. Here, we restrict to two ops in
  // the body, where the first is the elementwise single op and the second a
  // yield.
  Block *body = op.getBody();
  if (body->getOperations().size() != 2)
    return false;

  Operation *oper = &body->front();
  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
    return false;

  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0).getDefiningOp() != oper)
    return false;
  return true;
}

bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
  // All basic elemwise checks.
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
    return false;

  // Check input is actually used.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
    return false;
  return true;
}
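/// As an illustrative example, a generic with two identity-indexed inputs
/// whose body is a single `arith.addf` followed by a `linalg.yield` of its
/// result matches this predicate.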
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
    return false;

  // Check both inputs are used (elementwise).
  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
      !op.payloadUsesValueFromOperand(inputOpOperand1))
    return false;
  return true;
}

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// If the value is defined by a chain of unary side effect-free operations,
/// go up the use-def chain until the first value that isn't defined by such
/// an op.
// TODO: relax to multi-operands with constants, which are technically unary
// ops as needed (e.g. add5).
static Value getSourceSkipUnary(Value value) {
  Operation *op = value.getDefiningOp();
  while (op && op->getNumOperands() == 1) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface || !iface.hasNoEffect())
      break;
    value = op->getOperand(0);
    op = value.getDefiningOp();
  }
  return value;
}

bool mlir::linalg::detail::isContractionBody(
    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
    llvm::raw_ostream &errs) {
  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
    errs << "no terminator in the block";
    return false;
  }

  if (block.getNumArguments() != 3) {
    errs << "expected block with 3 arguments";
    return false;
  }

  Operation *terminator = block.getTerminator();
  if (terminator->getNumOperands() != 1) {
    errs << "expected terminator with 1 operand";
    return false;
  }

  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
    errs << "expected reduction op to be binary";
    return false;
  }

  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));

  if (reductionLHS != block.getArgument(2) &&
      reductionRHS != block.getArgument(2)) {
    errs << "expected reduction to take block argument #2 as one of the "
            "operands (modulo unary casts)";
    return false;
  }

  Value contributed = getSourceSkipUnary(
      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
  Operation *elementwiseOp = contributed.getDefiningOp();
  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
      elementwiseOp->getNumOperands() != 2) {
    errs << "expected elementwise op to be binary";
    return false;
  }

  if (!isaPair(elementwiseOp, reductionOp)) {
    errs << "expected reduction/elementwise op kind not satisfied";
    return false;
  }

  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
  if ((elementwiseLHS == block.getArgument(0) &&
       elementwiseRHS == block.getArgument(1)) ||
      (elementwiseLHS == block.getArgument(1) &&
       elementwiseRHS == block.getArgument(0))) {
    return true;
  }

  errs << "expected elementwise op to apply to block arguments (modulo unary "
          "casts)";
  return false;
}

/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
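/// For example, isPairTemplateImpl<arith::MulFOp, arith::AddFOp>(opA, opB)
/// returns true iff `opA` is an arith.mulf and `opB` is an arith.addf.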
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}

/// Returns true if the block is a body of a contraction with the kinds of
/// operations given pairwise by template arguments.
template <typename... Args>
static bool isContractionBody(Block &block) {
  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}

/// Given an `indexingMap` and its corresponding `iterators`, returns
/// the positions of the iterators of type `iter` that are indexed by
/// the `indexingMap` as a permutation. This is useful to infer various
/// subcomputations on a `LinalgOp`. This is performed by looking up
/// each result in the `indexingMap` and determining whether:
///   - It is a single AffineDimExpr.
///   - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
                                ArrayRef<utils::IteratorType> iterators,
                                utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      if (iterators[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

/// Infer the iterator types from the init affine map. This looks at which dims
/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
static FailureOr<SmallVector<utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();
  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = par;
  return iterators;
}

/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that
/// form a matmul subcomputation within `linalgOp`. These dimensions are such
/// that:
///   1. The m dimension is involved in an outer-product along LHS
///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
///   2. The n dimension is involved in an outer-product along RHS
///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
///   3. The k dimension appears as a permutation on LHS and RHS.
///   4. m, n and k appear only once in any given indexing.
///   5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
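/// As an illustrative example, a plain matmul with indexing maps
///   LHS: (d0, d1, d2) -> (d0, d2)
///   RHS: (d0, d1, d2) -> (d2, d1)
///   RES: (d0, d1, d2) -> (d0, d1)
/// and iterators [parallel, parallel, reduction] infers batch = {},
/// m = {0}, n = {1} and k = {2}.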
static FailureOr<ContractionDimensions>
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
                         ArrayRef<utils::IteratorType> iterators) {
  llvm::SmallDenseSet<int64_t> a =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
  llvm::SmallDenseSet<int64_t> b =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
  llvm::SmallDenseSet<int64_t> c =
      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);

  // A & C - B are the iterators involved in an outer-product along A (the
  // LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the
  // RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // Dims indexed as reductions in both A and B are the reduction (k)
  // dimensions.
  llvm::SmallDenseSet<int64_t> ra =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
  llvm::SmallDenseSet<int64_t> rb =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
  llvm::set_intersect(ra, rb);

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.m.begin(), dimensions.m.end());
  llvm::sort(dimensions.n.begin(), dimensions.n.end());
  llvm::sort(dimensions.k.begin(), dimensions.k.end());
  return dimensions;
}

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();
  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
                                  linalgOp.getIteratorTypesArray());
}

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
  if (indexingMaps.size() != 3)
    return failure();
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, iterators.value());
}

namespace mlir::linalg::detail {
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
    Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  // clang-format off
  if (!::isContractionBody<
        arith::MulFOp, arith::AddFOp,
        arith::MulIOp, arith::AddIOp,
        complex::MulOp, complex::AddOp,
        arith::AndIOp, arith::OrIOp>(
      *linalgOp.getBlock())) {
    return MatchContractionResult::NotAddMul;
  }
  // clang-format on

  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
  return MatchContractionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
  switch (res) {
  case MatchContractionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchContractionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchContractionResult::NoReduction:
    return "expected at least 1 reduction";
  case MatchContractionResult::NotProjectedPermutations:
    return "expected indexing maps to be projected permutations";
  case MatchContractionResult::NotAddMul:
    return "expected add/mul op in the body";
  case MatchContractionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
          mlir::linalg::detail::MatchContractionResult::Success);
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. Has 2 input and 1 output shapes.
///   2. Has at least one reduction dimension.
///   3. Has only projected permutation indexing maps.
///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
///      (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar
///      unary operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise
/// and constant operations that do not involve the reduction dimension(s).
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res != MatchContractionResult::Success)
    return op->emitError(getMatchContractionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the given two expressions returns one that is of type T (`lhs` gets
/// preference over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///   (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// classifies the AffineDimExpr as convolved dimensions or unconvolved
/// dimensions and verifies each dimension occurs only once.
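/// As an illustrative example, visiting the expression d0 * 2 + d1 records
/// both d0 and d1 as convolved dims with coefficients 2 and 1 respectively,
/// while visiting a bare d2 records d2 as an unconvolved dim.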
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the
        // duplicate dim must be removed from the map as well.
        auto it = convolvedDimMapping.find(dimPos);
        if (it != convolvedDimMapping.end()) {
          int64_t pairedDim = it->second;
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      strideAndDilationMapping[dim] =
          getAffineConstantExpr(1, expr.getContext());
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
} // namespace

static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  llvm::SmallDenseSet<int64_t> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
  return preservedDims;
}

static SmallVector<int64_t, 2>
getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
  SmallVector<int64_t, 2> vals;
  for (auto e : exprs) {
    auto constantExpr = dyn_cast<AffineConstantExpr>(e);
    assert(constantExpr && "Found non-constant stride/dilation");
    vals.push_back(constantExpr.getValue());
  }
  return vals;
}

/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
/// least one convolved dimension pair (output image + filter loop).
/// Convolution dimensions are specified in sorted order; strides match the
/// order of the output image dimensions, while dilations match the order of
/// the filter loop dimensions.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel
  // iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(filterMap,
                                      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());

  // Use the op-carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}

/// Find at least 1 parallel (output_image) and reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed to be the filter.
/// These dimensions are such that:
///   1. Optional batch dimensions that appear in the input and filter.
///   2. The output_image dimension is involved in a cross-correlation along
///      LHS (i.e. it is a permutation on RES and LHS and has an associated
///      filter_loop in RHS).
///   3. Optional output_channel dimension is involved in an outer-product
///      along RHS (i.e. it is a permutation on RES and RHS and does not
///      appear in LHS).
///   4. Optional input_channel dimension appears as a permutation on LHS and
///      RHS.
///   5. The filter_loop dimension appears as a permutation on the RHS and
///      represents the shape of the kernel cross-correlated along a
///      corresponding output_image dim.
///   6. The input_channel dimension appears as a permutation on LHS and RHS.
///   7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is
/// empty.
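/// As an illustrative example, a 1-D convolution with indexing maps
///   input:  (d0, d1) -> (d0 + d1)
///   filter: (d0, d1) -> (d1)
///   output: (d0, d1) -> (d0)
/// and iterators [parallel, reduction] infers outputImage = {0},
/// filterLoop = {1}, strides = [1], dilations = [1], and all other dimension
/// lists empty.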
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  for (AffineExpr expr : indexingMaps[0].getResults())
    (void)inputExprWalker.visit(expr);
  inputExprWalker.clearMultiUseDims(indexingMaps[0]);

  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                  /*allowEmptyConvolvedDims=*/false);
}

namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions,
    bool allowEmptyConvolvedDims) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutations.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present
  //   in filter.
  // - Output image dimension : present in output, convolved dims in input,
  //   not present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. Already seen, continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
    return MatchConvolutionResult::EmptyConvolvedDims;

  if (dimensions) {
    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
  switch (res) {
  case MatchConvolutionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchConvolutionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchConvolutionResult::WrongInputIndexingMap:
    return "unexpected input index map for convolutions";
  case MatchConvolutionResult::NotProjectedPermutations:
    return "expected output/filter indexing maps to be projected permutations";
  case MatchConvolutionResult::NonConvolutionLoop:
    return "unexpected loop dimension for convolution op";
  case MatchConvolutionResult::OutputDimsNotParallel:
    return "expected all iterators used to access outputs to be parallel";
  case MatchConvolutionResult::NonOutputDimNotReduction:
    return "expected all iterators not used to access outputs to be reduction";
  case MatchConvolutionResult::EmptyConvolvedDims:
    return "expected convolved dim to be non-empty";
  case MatchConvolutionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchConvolutionResult case");
}

bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
                                             bool allowEmptyConvolvedDims) {
  return linalg::detail::isConvolutionInterfaceImpl(
             linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
         linalg::detail::MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
  if (res != MatchConvolutionResult::Success)
    return op->emitError(getMatchConvolutionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};

static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return MatchFillResult::WrongNumOperands;

  OpOperand *value = linalgOp.getDpsInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> res;
  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    llvm::append_range(res, getShape(&opOperand));
  return res;
}

SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result)) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] =
          Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
    }
  }
  return res;
}
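/// As an illustrative example, for a static matmul A(4x8) * B(8x16) -> C(4x16)
/// the loops-to-shapes map is (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1), so the
/// static loop sizes computed below are [4, 16, 8] (i.e. m, n, k).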
SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result))
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
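/// For example, with positions = {0, 2}, visiting d0 + s0 returns true (d0 is
/// used) while visiting d1 * 2 returns false.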
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};

static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
  int64_t inputRankSum = 0;
  int64_t outputRankSum = 0;
  for (OpOperand *input : op.getDpsInputOperands())
    inputRankSum += op.getRank(input);
  for (OpOperand &output : op.getDpsInitsMutable())
    outputRankSum += op.getRank(&output);
  return {inputRankSum, inputRankSum + outputRankSum};
}

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the
  // inputs. This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes =
  //       subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`.
  return cast<DestinationStyleOpInterface>(*this->getOperation())
             .getNumDpsInputs() +
         operandNumber - start;
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError("expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  // Check each operand's indexing map; the conditions below assume default
  // indexing maps.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);

    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Check if the given shapes match the inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
  // Verify only static cases since we can't get exact dimension sizes and
  // loop ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // Since the inferred index range is monotonically increasing or
        // decreasing, its extrema are attained at the first or last index,
        // and the operand's dimension size should equal the maximum inferred
        // value + 1. For now, however, we only check that the inferred range
        // stays within the bounds of the operand's size, since complicated
        // affine expressions such as d0 * 3 + d1 are not easy to handle
        // exactly. There are also cases this check cannot catch, e.g.
        // (d0, d1) -> (d1 - d0).
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg
  // is not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}