xref: /llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/IR/AffineExpr.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/BuiltinTypeInterfaces.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetOperations.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <numeric>
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 
37 /// Include the definitions of the copy operation interface.
38 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
39 
40 //===----------------------------------------------------------------------===//
41 // Interface utility functions
42 //===----------------------------------------------------------------------===//
43 
44 bool linalg::detail::canOpOperandsBeDroppedImpl(
45     linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
46   SmallVector<AffineMap> indexingMaps;
47   for (auto &opOperand : linalgOp->getOpOperands()) {
48     if (llvm::is_contained(droppedOperands, &opOperand))
49       continue;
50     indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
51   }
52   if (indexingMaps.empty()) {
53     // If there are no indexing maps, the operand can only be dropped
54     // if the op has no loops.
55     return linalgOp.getNumLoops() == 0;
56   }
57   return inversePermutation(concatAffineMaps(
58              indexingMaps, linalgOp.getContext())) != AffineMap();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // CopyOpInterface implementation
63 //===----------------------------------------------------------------------===//
64 
65 bool linalg::isaCopyOpInterface(LinalgOp op) {
66   // Check all loops are parallel and linalgOp is single input and output.
67   if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
68     return false;
69 
70   auto mapRange = op.getIndexingMapsArray();
71   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
72       !mapRange.back().isIdentity()) {
73     return false;
74   }
75   // Region.
76   return llvm::hasSingleElement(op.getBlock()->getOperations());
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // FillOpInterface implementation
81 //===----------------------------------------------------------------------===//
82 std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
83   // Structural.
84   if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
85       !op.isSingleYieldOp())
86     return std::nullopt;
87 
88   // Input should be referenced and init should not.
89   if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
90       op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
91     return std::nullopt;
92 
93   OpOperand *value = op.getDpsInputOperand(0);
94   if (!op.isScalar(value))
95     return std::nullopt;
96   return value->get();
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // BroadcastOpInterface implementation
101 //===----------------------------------------------------------------------===//
102 std::optional<SmallVector<int64_t>>
103 linalg::isaBroadcastOpInterface(GenericOp op) {
104   // Structural.
105   if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
106       !op.isSingleYieldOp())
107     return std::nullopt;
108 
109   auto srcTy = op.getDpsInputOperand(0)->get().getType();
110   auto dstTy = op.getDpsInitOperand(0)->get().getType();
111   if (!isa<MemRefType, RankedTensorType>(srcTy) ||
112       !isa<MemRefType, RankedTensorType>(dstTy))
113     return std::nullopt;
114 
115   // Check output is identity map. Broadcast could additionally be
116   // employing permutation of indices and that would be expressible
117   // in linalg.generic but is not expressible for named broadcast op.
118   auto dstMap = op.getIndexingMapsArray()[1];
119   if (!dstMap.isIdentity())
120     return std::nullopt;
121 
122   SmallVector<int64_t> position;
123   auto srcMap = op.getIndexingMapsArray()[0];
124 
125   if (srcMap.getResults().size() >= dstMap.getResults().size())
126     return std::nullopt;
127 
128   // Check input map is monotonically increasing DimIds.
129   for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
130     auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
131     if (!expr)
132       return std::nullopt;
133     int64_t pos = expr.getPosition();
134     if (i > 0 && pos <= position[i - 1])
135       return std::nullopt;
136     position.push_back(expr.getPosition());
137   }
138 
139   SmallVector<int64_t> broadcastedDims;
140   auto numDims = srcMap.getNumDims();
141   // This is quadratic but number of items is generally small.
142   for (auto dim : llvm::seq<int64_t>(0, numDims)) {
143     if (!llvm::is_contained(position, dim))
144       broadcastedDims.push_back(dim);
145   }
146   return broadcastedDims;
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // TransposeOpInterface implementation
151 //===----------------------------------------------------------------------===//
152 std::optional<SmallVector<int64_t>>
153 linalg::isaTransposeOpInterface(GenericOp op) {
154   // To specialize as a transpose op, the genericOp must be
155   // all parallel loops, single input, single output, and its body
156   // should be just a yield op, yielding input as output as is (no compute).
157   if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
158       !op.isSingleYieldOp())
159     return std::nullopt;
160 
161   auto mapRange = op.getIndexingMapsArray();
162   if (mapRange.size() != 2)
163     return std::nullopt;
164 
165   auto mapOfInput = mapRange.front();
166   auto mapOfResult = mapRange.back();
167 
168   // linalg.transpose permutes the dimensions of input using this
169   // rule: dim(result, i) = dim(input, permutation[i])
170   if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
171     return std::nullopt;
172 
173   SmallVector<int64_t> permutation(mapOfInput.getNumDims());
174   for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
175     auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
176     permutation[expr.getPosition()] = i;
177   }
178   return permutation;
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // Elementwise Single Unary/Binary-OpInterface implementation
183 //===----------------------------------------------------------------------===//
184 static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
185                                                       unsigned arity) {
186   // Check all loops are parallel.
187   if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
188     return false;
189 
190   // Check there are arity-inputs, 1-output and all are identity-maps.
191   if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
192       !llvm::all_of(op.getIndexingMapsArray(),
193                     [](AffineMap map) { return map.isIdentity(); }))
194     return false;
195 
196   // Init should not be referenced for elementwise operations.
197   if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
198     return false;
199 
200   // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
201   // as resulting from producer-consumer fusion. Here, we restrict to two ops in
202   // the body, where the first is the elementwise single op and the second a
203   // yield.
204   Block *body = op.getBody();
205   if (body->getOperations().size() != 2)
206     return false;
207 
208   Operation *oper = &body->front();
209   if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
210     return false;
211 
212   auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
213   if (!yieldOp || yieldOp.getNumOperands() != 1 ||
214       yieldOp->getOperand(0).getDefiningOp() != oper)
215     return false;
216   return true;
217 }
218 
219 bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
220   // All basic elemwise checks.
221   if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
222     return false;
223 
224   // Check input is actully used.
225   if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
226     return false;
227   return true;
228 }
229 
230 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
231   if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
232     return false;
233 
234   // Check both inputs are used (elementwise).
235   OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
236   OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
237   if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
238       !op.payloadUsesValueFromOperand(inputOpOperand1))
239     return false;
240   return true;
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // ContractionOpInterface implementation
245 //===----------------------------------------------------------------------===//
246 
247 /// If the value is defined by a chain of unary side effect-free, go up the
248 /// use-def chain until the first value that isn't defined by such an op.
249 // TODO: relax to multi-operands with constants, which are technically unary ops
250 // as needed (e.g. add5).
251 static Value getSourceSkipUnary(Value value) {
252   Operation *op = value.getDefiningOp();
253   while (op && op->getNumOperands() == 1) {
254     auto iface = dyn_cast<MemoryEffectOpInterface>(op);
255     if (!iface || !iface.hasNoEffect())
256       break;
257     value = op->getOperand(0);
258     op = value.getDefiningOp();
259   }
260   return value;
261 }
262 
263 bool mlir::linalg::detail::isContractionBody(
264     Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
265     llvm::raw_ostream &errs) {
266   if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
267     errs << "no terminator in the block";
268     return false;
269   }
270 
271   if (block.getNumArguments() != 3) {
272     errs << "expected block with 3 arguments";
273     return false;
274   }
275 
276   Operation *terminator = block.getTerminator();
277   if (terminator->getNumOperands() != 1) {
278     errs << "expected terminator with 1 operand";
279     return false;
280   }
281 
282   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
283   Operation *reductionOp = yielded.getDefiningOp();
284   if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
285     errs << "expected reduction op to be binary";
286     return false;
287   }
288 
289   Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
290   Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
291 
292   if (reductionLHS != block.getArgument(2) &&
293       reductionRHS != block.getArgument(2)) {
294     errs << "expected reduction to take block argument #2 as one of the "
295             "operands (modulo unary casts)";
296     return false;
297   }
298 
299   Value contributed = getSourceSkipUnary(
300       isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
301   Operation *elementwiseOp = contributed.getDefiningOp();
302   if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
303       elementwiseOp->getNumOperands() != 2) {
304     errs << "expected elementwise op to be binary";
305     return false;
306   }
307 
308   if (!isaPair(elementwiseOp, reductionOp)) {
309     errs << "expected reduction/elementwise op kind not satisfied";
310     return false;
311   }
312 
313   Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
314   Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
315   if ((elementwiseLHS == block.getArgument(0) &&
316        elementwiseRHS == block.getArgument(1)) ||
317       (elementwiseLHS == block.getArgument(1) &&
318        elementwiseRHS == block.getArgument(0))) {
319     return true;
320   }
321 
322   errs << "expected elementwise op to apply to block arguments (modulo unary "
323           "casts)";
324   return false;
325 }
326 
/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments, e.g. `isPairTemplateImpl<AddF, MulF>`
/// matches (`add` is AddF, `mul` is MulF).
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  // Match against the first (add, mul) pair of template arguments.
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  // Otherwise recurse on the remaining pairs, if any.
  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}
341 
/// Returns true if the block is a body of a contraction with the kinds of
/// operations given pairwise by template arguments.
template <typename... Args>
static bool isContractionBody(Block &block) {
  // The detail overload's `errs` parameter takes its default value here, so
  // failure diagnostics are discarded.
  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
348 
349 /// Given an `indexingMap` and its corresponding `iterators`, returns
350 /// the positions of the iterators of type `iter` that are indexed by
351 /// the `indexingMap` as a permutation. This is useful to infer various
352 /// subcomputations on a `LinalgOp`. This is performed by looking up
353 /// each result in the `indexingMap` and determining whether:
354 ///   - It is a single AffineDimExpr.
355 ///   - It is the only result involving this AffineDimExpr.
356 static llvm::SmallDenseSet<int64_t>
357 findPermutationsIndexingOperand(AffineMap indexingMap,
358                                 ArrayRef<utils::IteratorType> iterators,
359                                 utils::IteratorType iter) {
360   assert(iterators.size() == indexingMap.getNumDims());
361   llvm::SmallDenseSet<int64_t> res;
362   for (AffineExpr e : indexingMap.getResults()) {
363     if (auto d = dyn_cast<AffineDimExpr>(e)) {
364       if (iterators[d.getPosition()] == iter &&
365           llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
366             return e.isFunctionOfDim(d.getPosition());
367           }) == 1)
368         res.insert(d.getPosition());
369     }
370   }
371   return res;
372 }
373 
374 namespace {
375 auto par = utils::IteratorType::parallel;
376 auto red = utils::IteratorType::reduction;
377 } // namespace
378 
379 /// Infer the iterator types from the init affine map. This looks at which dims
380 /// are present in the map results, and returns an iterator types array with
381 /// parallel types for dims that are present, and reduction types for dims that
382 /// are not present.
383 static FailureOr<SmallVector<utils::IteratorType>>
384 inferIteratorsFromOutMap(AffineMap map) {
385   if (!map.isProjectedPermutation())
386     return failure();
387   SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
388   for (auto expr : map.getResults())
389     if (auto dim = dyn_cast<AffineDimExpr>(expr))
390       iterators[dim.getPosition()] = par;
391   return iterators;
392 }
393 
394 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
395 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
396 ///   1. The m dimension is involved in an outer-product along LHS
397 ///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
398 ///   2. The n dimension is involved in an outer-product along RHS
399 ///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
400 ///   3. The k dimension appears as a permutation on LHS and RHS.
401 ///   4. m, n and k appear only once in any given indexing.
402 ///   5. Optional batch dimensions that appear in all operands are captured.
403 /// This allows e.g. detecting that some contraction is embedded within
404 /// `linalgOp` with some orthogonal heuristic.
405 static FailureOr<ContractionDimensions>
406 inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
407                          ArrayRef<utils::IteratorType> iterators) {
408   llvm::SmallDenseSet<int64_t> a =
409       findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
410   llvm::SmallDenseSet<int64_t> b =
411       findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
412   llvm::SmallDenseSet<int64_t> c =
413       findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
414 
415   // A & C - B are the iterators involved in an outer-product along A (the LHS).
416   llvm::SmallDenseSet<int64_t> ac = a;
417   llvm::set_intersect(ac, c);
418   llvm::set_subtract(ac, b);
419   // B & C - A are the iterators involved in an outer-product along B (the RHS).
420   llvm::SmallDenseSet<int64_t> bc = b;
421   llvm::set_intersect(bc, c);
422   llvm::set_subtract(bc, a);
423   // A & B & C are the "batch" dimensions.
424   llvm::SmallDenseSet<int64_t> batches = a;
425   llvm::set_intersect(batches, b);
426   llvm::set_intersect(batches, c);
427 
428   // A & B red are the reduction dimensions.
429   llvm::SmallDenseSet<int64_t> ra =
430       findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
431   llvm::SmallDenseSet<int64_t> rb =
432       findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
433   llvm::set_intersect(ra, rb);
434 
435   // Return each set in sorted order.
436   ContractionDimensions dimensions{
437       SmallVector<unsigned, 2>(batches.begin(), batches.end()),
438       SmallVector<unsigned, 2>(ac.begin(), ac.end()),
439       SmallVector<unsigned, 2>(bc.begin(), bc.end()),
440       SmallVector<unsigned, 2>(ra.begin(), ra.end())};
441   llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
442   llvm::sort(dimensions.m.begin(), dimensions.m.end());
443   llvm::sort(dimensions.n.begin(), dimensions.n.end());
444   llvm::sort(dimensions.k.begin(), dimensions.k.end());
445   return dimensions;
446 }
447 
448 FailureOr<ContractionDimensions>
449 mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
450   if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
451     return failure();
452   return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
453                                   linalgOp.getIteratorTypesArray());
454 }
455 
456 FailureOr<ContractionDimensions>
457 mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
458   if (indexingMaps.size() != 3)
459     return failure();
460   auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
461   if (failed(iterators))
462     return failure();
463   return inferContractionDimsImpl(indexingMaps, iterators.value());
464 }
465 
namespace mlir::linalg::detail {
/// Outcome of structurally matching an operation against the contraction
/// interface; see `getMatchContractionMessage` for the user-facing text of
/// each failure case.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
476 
477 mlir::linalg::detail::MatchContractionResult
478 mlir::linalg::detail::isContractionInterfaceImpl(
479     Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
480   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
481   if (!linalgOp)
482     return MatchContractionResult::NotLinalgOp;
483   if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
484     return MatchContractionResult::WrongNumOperands;
485   auto mapRange = linalgOp.getIndexingMapsArray();
486   if (linalgOp.getNumReductionLoops() == 0)
487     return MatchContractionResult::NoReduction;
488   if (llvm::any_of(mapRange,
489                    [](AffineMap m) { return !m.isProjectedPermutation(); }))
490     return MatchContractionResult::NotProjectedPermutations;
491   // TODO: more fields than add/mul.
492   // clang-format off
493   if (!::isContractionBody<
494         arith::MulFOp, arith::AddFOp,
495         arith::MulIOp, arith::AddIOp,
496         complex::MulOp, complex::AddOp,
497         arith::AndIOp, arith::OrIOp>(
498       *linalgOp.getBlock())) {
499     return MatchContractionResult::NotAddMul;
500   }
501   // clang-format on
502 
503   if (dimensions) {
504     FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
505     assert(succeeded(res) && "unexpected failure to infer contraction dims");
506     *dimensions = *res;
507   }
508   return MatchContractionResult::Success;
509 }
510 
511 StringRef
512 mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
513   switch (res) {
514   case MatchContractionResult::NotLinalgOp:
515     return "expected a LinalgOp";
516   case MatchContractionResult::WrongNumOperands:
517     return "expected op with 2 inputs and 1 output";
518   case MatchContractionResult::NoReduction:
519     return "expected at least 1 reduction";
520   case MatchContractionResult::NotProjectedPermutations:
521     return "expected indexing maps to be projected permutations";
522   case MatchContractionResult::NotAddMul:
523     return "expected add/mul op in the body";
524   case MatchContractionResult::Success:
525     return "";
526   }
527   llvm_unreachable("unhandled MatchContractionResult case");
528 }
529 
530 bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
531   if (!linalgOp)
532     return false;
533   Operation *op = linalgOp.getOperation();
534   return isa<ContractionOpInterface>(op) ||
535          (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
536           mlir::linalg::detail::MatchContractionResult::Success);
537 }
538 
539 /// Verify that a LinalgOp `op` is a contraction.
540 /// A Linalg contraction is defined in general terms:
541 ///   1. Has 2 input and 1 output shapes.
542 ///   2. Has at least one reduction dimension.
543 ///   3. Has only projected permutation indexing maps.
544 ///   4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
545 ///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
546 ///   operations that may change the type (e.g. for mixed-precision).
547 /// As a consequence, when vectorization of such an op occurs, the only special
548 /// behavior is that the (unique) MulOpType is vectorized into a
549 /// `vector.contract`. All other ops are handled in a generic fashion.
550 /// In the future, we may wish to allow more input arguments and elementwise and
551 /// constant operations that do not involve the reduction dimension(s).
552 LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
553   auto res = isContractionInterfaceImpl(op);
554   if (res != MatchContractionResult::Success)
555     return op->emitError(getMatchContractionMessage(res));
556   return success();
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // ConvolutionOpInterface implementation
561 //===----------------------------------------------------------------------===//
562 
563 /// Of the given two expressions returns one that is of type T (`lhs` gets
564 /// preference over `rhs`)
565 template <typename T>
566 static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
567   return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
568 }
569 
namespace {
/// Walk the indexing expressions for input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///      (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// classifies the AffineDimExpr as convolved dimensions or unconvolved
/// dimensions and verifies each dimension occurs only once.
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      // A dim referenced by more than one map result is ambiguous: drop it
      // from every tracking set.
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the duplicate
        // dim must be removed from the map as well.
        auto it = convolvedDimMapping.find(dimPos);
        if (it != convolvedDimMapping.end()) {
          int64_t pairedDim = it->second;
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  // A bare dim expression is an unconvolved dim; fail on repeated uses.
  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  // Symbols and constants may only appear inside a multiplication handled by
  // getDimExprOrMulExprDimPos; standalone they do not match the grammar.
  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    // Record the two sides of the addition as convolution partners of each
    // other (e.g. output-image dim and filter-loop dim).
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  // Matches `dim` or `dim * (symbol|constant)` (in either operand order),
  // records the dim as convolved together with its stride/dilation
  // coefficient, and returns the dim position.
  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      strideAndDilationMapping[dim] =
          getAffineConstantExpr(1, expr.getContext());
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
} // namespace
677 
678 static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
679   assert(map.isProjectedPermutation() &&
680          "expected map to have projected permutations");
681   llvm::SmallDenseSet<int64_t> preservedDims;
682   for (auto expr : map.getResults())
683     preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
684   return preservedDims;
685 }
686 
687 static SmallVector<int64_t, 2>
688 getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
689   SmallVector<int64_t, 2> vals;
690   for (auto e : exprs) {
691     auto constantExpr = dyn_cast<AffineConstantExpr>(e);
692     assert(constantExpr && "Found non-constant stride/dilation");
693     vals.push_back(constantExpr.getValue());
694   }
695   return vals;
696 }
697 
698 /// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not
/// at least one convolved dimension pair (output image + filter loop). Convolution
701 /// at least convolved dimension pair (output image + filter loop). Convolution
702 /// dimensions are specified in sorted order, and strides match the order of
703 /// the filter loop dimensions, while the dilations match the order of the
704 /// output image dimensions.
705 static FailureOr<ConvolutionDimensions>
706 inferConvolutionDimsImpl(LinalgOp linalgOp,
707                          ConvAccessExprWalker &inputExprWalker,
708                          bool allowEmptyConvolvedDims) {
709   auto filterMap =
710       linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
711   auto outputMap =
712       linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
713   llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
714       filterMap, linalgOp.getIteratorTypesArray(), par);
715   llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
716       outputMap, linalgOp.getIteratorTypesArray(), par);
717 
718   // unConvolvedDims & outputDims - filterDims are the batch iterators.
719   llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
720   llvm::set_intersect(batch, outputDims);
721   llvm::set_subtract(batch, filterDims);
722 
723   // convolvedDims & outputDims are the output image iterators.
724   llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
725   llvm::set_intersect(oi, outputDims);
726 
727   // filterDims & outputDims - unConvolvedDims are the output channel iterators.
728   llvm::SmallDenseSet<int64_t> oc = filterDims;
729   llvm::set_intersect(oc, outputDims);
730   llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
731 
732   // filterDims & outputDims & unConvolvedDims are the depth iterators.
733   llvm::SmallDenseSet<int64_t> depth = filterDims;
734   llvm::set_intersect(depth, outputDims);
735   llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
736 
737   llvm::SmallDenseSet<int64_t> filterReducedDims =
738       findPermutationsIndexingOperand(filterMap,
739                                       linalgOp.getIteratorTypesArray(), red);
740 
741   // convolvedDims & filterReducedDims are the filter loop iterators.
742   llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
743   llvm::set_intersect(fl, filterReducedDims);
744 
745   // unConvolvedDims & filterReducedDims are the input channel iterators.
746   llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
747   llvm::set_intersect(ic, filterReducedDims);
748 
749   if (oi.empty() && !allowEmptyConvolvedDims)
750     return failure();
751 
752   // Return each set in sorted order.
753   ConvolutionDimensions dimensions{
754       SmallVector<unsigned, 2>(batch.begin(), batch.end()),
755       SmallVector<unsigned, 2>(oi.begin(), oi.end()),
756       SmallVector<unsigned, 2>(oc.begin(), oc.end()),
757       SmallVector<unsigned, 2>(fl.begin(), fl.end()),
758       SmallVector<unsigned, 2>(ic.begin(), ic.end()),
759       SmallVector<unsigned, 2>(depth.begin(), depth.end()),
760       /*strides=*/SmallVector<int64_t, 2>{},
761       /*dilations=*/SmallVector<int64_t, 2>{}};
762   llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
763   llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
764   llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
765   llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
766   llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
767   llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
768 
769   // Use the op carried strides/dilations attribute if present.
770   auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
771   if (!nativeStrides) {
772     SmallVector<AffineExpr, 2> strideExprs;
773     for (unsigned oiDim : dimensions.outputImage)
774       strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
775     dimensions.strides = getConstantsFromExprList(strideExprs);
776   } else {
777     dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
778   }
779   auto nativeDilations =
780       linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
781   if (!nativeDilations) {
782     SmallVector<AffineExpr, 2> dilationExprs;
783     for (unsigned flDim : dimensions.filterLoop)
784       dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
785     dimensions.dilations = getConstantsFromExprList(dilationExprs);
786   } else {
787     dimensions.dilations =
788         llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
789   }
790   return dimensions;
791 }
792 
793 /// Find at least 1 parallel (output_image) and reduction (filter_loop)
794 /// dimension candidates that form a convolution subcomputation within
795 /// `linalgOp`. The LHS is assumed to be the convolution input while the
796 /// RHS is assumed as the filter.
797 /// These dimensions are such that:
798 ///   1. Optional batch dimensions that appear in the input and filter.
799 ///   2. The output_image dimension is involved in a cross-correlation along LHS
800 ///      (i.e. it is a permutation on RES and LHS and has an associated
801 ///      filter_loop in RHS).
802 ///   3. Optional output_channel dimension is involved in an outer-product along
803 ///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
804 ///      LHS).
805 ///   4. Optional input_channel dimension appears as a permutation on LHS and
806 ///      RHS.
807 ///   5. The filter_loop dimension appears as a permutation on the RHS and
808 ///      represents the shape of the kernel cross-correlated along a
809 ///      corresponding output_image dim.
810 ///   6. The input_channel dimension appears as a permutation on LHS and RHS.
811 ///   7. All dimensions appear only once in any given indexing map.
812 /// This allows e.g. detecting that some convolution is embedded within
813 /// `linalgOp` with some orthogonal heuristic.
814 /// When multiple dimension occurrences exist that match any classification
815 /// indices are returned in sorted order.
816 /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
817 FailureOr<ConvolutionDimensions>
818 mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
819   if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
820     return failure();
821 
822   auto indexingMaps = linalgOp.getIndexingMapsArray();
823 
824   // Check the input indexing map has the right form.
825   ConvAccessExprWalker inputExprWalker;
826   for (AffineExpr expr : indexingMaps[0].getResults())
827     (void)inputExprWalker.visit(expr);
828   inputExprWalker.clearMultiUseDims(indexingMaps[0]);
829 
830   return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
831                                   /*allowEmptyConvolvedDims=*/false);
832 }
833 
namespace mlir::linalg::detail {
/// Result codes for matching an operation against the convolution interface.
/// Any value other than `Success` names the first structural check that
/// failed; `getMatchConvolutionMessage` maps each code to a diagnostic.
enum class MatchConvolutionResult {
  /// The op matches the convolution interface.
  Success = 0,
  /// The op does not implement `linalg::LinalgOp`.
  NotLinalgOp,
  /// The op does not have at least 2 inputs and exactly 1 output.
  WrongNumOperands,
  /// The input (image) indexing map is not of the expected convolved form.
  WrongInputIndexingMap,
  /// The filter and/or output indexing map is not a projected permutation.
  NotProjectedPermutations,
  /// A loop dimension could not be classified into any convolution role.
  NonConvolutionLoop,
  /// An iterator accessing the output is not parallel.
  OutputDimsNotParallel,
  /// An iterator not accessing the output is not a reduction.
  NonOutputDimNotReduction,
  /// There are no convolved dims and `allowEmptyConvolvedDims` was not set.
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
847 
/// Match `op` against the convolution interface: every loop dimension must be
/// classifiable into exactly one convolution role (batch, output image,
/// output channel, filter loop, input channel, depth multiplier) with the
/// appropriate iterator type. If `dimensions` is non-null, the inferred
/// classification is written to it on success.
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions,
    bool allowEmptyConvolvedDims) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  // At least two inputs (image, filter) and exactly one output are required.
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  // First pass: classify every dim appearing in the output map. All of these
  // must be parallel iterators.
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image Loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // Second pass: classify the dims appearing in the filter map. Dims not in
  // the output must be reduction iterators, and must not have been classified
  // already in the first pass.
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This is already seen, continue;
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
    return MatchConvolutionResult::EmptyConvolvedDims;

  if (dimensions) {
    // The structural checks above guarantee inference cannot fail here.
    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}
981 
982 StringRef
983 mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
984   switch (res) {
985   case MatchConvolutionResult::NotLinalgOp:
986     return "expected a LinalgOp";
987   case MatchConvolutionResult::WrongNumOperands:
988     return "expected op with 2 inputs and 1 output";
989   case MatchConvolutionResult::WrongInputIndexingMap:
990     return "unexpected input index map for convolutions";
991   case MatchConvolutionResult::NotProjectedPermutations:
992     return "expected output/filter indexing maps to be projected permutations";
993   case MatchConvolutionResult::NonConvolutionLoop:
994     return "unexpected loop dimension for convolution op";
995   case MatchConvolutionResult::OutputDimsNotParallel:
996     return "expected all iterators used to access outputs to be parallel";
997   case MatchConvolutionResult::NonOutputDimNotReduction:
998     return "expected all iterators not used to access outputs to be reduction";
999   case MatchConvolutionResult::EmptyConvolvedDims:
1000     return "expected convolved dim to be non-empty";
1001   case MatchConvolutionResult::Success:
1002     return "";
1003   }
1004   llvm_unreachable("unhandled MatchConvolutionResult case");
1005 }
1006 
1007 bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
1008                                              bool allowEmptyConvolvedDims) {
1009   return linalg::detail::isConvolutionInterfaceImpl(
1010              linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1011          linalg::detail::MatchConvolutionResult::Success;
1012 }
1013 
1014 LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
1015   MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
1016   if (res != MatchConvolutionResult::Success)
1017     return op->emitError(getMatchConvolutionMessage(res));
1018   return success();
1019 }
1020 
1021 //===----------------------------------------------------------------------===//
1022 // FillOpInterface implementation
1023 //===----------------------------------------------------------------------===//
1024 
/// Result codes for matching an operation against the fill interface.
enum class MatchFillResult {
  /// The op matches the fill interface.
  Success = 0,
  /// The op does not implement `linalg::LinalgOp`.
  NotLinalgOp,
  /// The op does not have exactly 1 input and 1 output.
  WrongNumOperands,
  /// The single input is not a scalar.
  NotScalarInput
};
1031 
1032 static MatchFillResult isFillInterfaceImpl(Operation *op) {
1033   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1034   if (!linalgOp)
1035     return MatchFillResult::NotLinalgOp;
1036   if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1037     return MatchFillResult::WrongNumOperands;
1038 
1039   OpOperand *value = linalgOp.getDpsInputOperand(0);
1040   if (!linalgOp.isScalar(value))
1041     return MatchFillResult::NotScalarInput;
1042 
1043   return MatchFillResult::Success;
1044 }
1045 
1046 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
1047   auto res = isFillInterfaceImpl(op);
1048   if (res == MatchFillResult::NotLinalgOp)
1049     return op->emitError("expected a LinalgOp");
1050   if (res == MatchFillResult::WrongNumOperands)
1051     return op->emitError("expected op with 1 input and 1 output");
1052   if (res == MatchFillResult::NotScalarInput)
1053     return op->emitError("expected op with scalar input");
1054 
1055   return success();
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // StructuredOpInterface implementation
1060 //===----------------------------------------------------------------------===//
1061 
1062 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1063                                                                 Location loc) {
1064   SmallVector<OpFoldResult> res;
1065   for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1066     for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1067       res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1068   }
1069   return res;
1070 }
1071 
1072 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1073   SmallVector<int64_t, 4> res;
1074   assert(!hasDynamicShape() && "expected operands to have static shapes");
1075   for (OpOperand &opOperand : getOperation()->getOpOperands())
1076     llvm::append_range(res, getShape(&opOperand));
1077   return res;
1078 }
1079 
1080 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1081   AffineMap map = getLoopsToShapesMap();
1082   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1083   auto viewSizes = createFlatListOfOperandDims(b, loc);
1084   SmallVector<Range, 4> res(numDims);
1085   for (unsigned idx = 0; idx < numRes; ++idx) {
1086     auto result = map.getResult(idx);
1087     if (auto d = dyn_cast<AffineDimExpr>(result)) {
1088       if (res[d.getPosition()].offset)
1089         continue;
1090       res[d.getPosition()] =
1091           Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1092     }
1093   }
1094   return res;
1095 }
1096 
1097 SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
1098   AffineMap map = getLoopsToShapesMap();
1099   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1100   SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
1101   SmallVector<int64_t, 4> res(numDims, 0);
1102   for (unsigned idx = 0; idx < numRes; ++idx) {
1103     auto result = map.getResult(idx);
1104     if (auto d = dyn_cast<AffineDimExpr>(result))
1105       res[d.getPosition()] = allShapeSizes[idx];
1106   }
1107   return res;
1108 }
1109 
1110 /// Visitor to check if any of the given set of positions from AffineDimExprs
1111 /// are used within an AffineExpr.
1112 struct HasAffineDimExprVisitor
1113     : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1114   HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1115       : positions(std::move(positions)) {}
1116 
1117   bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
1118     return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1119   }
1120 
1121   bool visitDimExpr(AffineDimExpr dimExpr) {
1122     return positions.test(dimExpr.getPosition());
1123   }
1124 
1125   bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1126 
1127   bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1128 
1129 private:
1130   llvm::SmallBitVector positions;
1131 };
1132 
1133 static std::pair<int64_t, int64_t>
1134 getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
1135   int64_t inputRankSum = 0;
1136   int64_t outputRankSum = 0;
1137   for (OpOperand *input : op.getDpsInputOperands())
1138     inputRankSum += op.getRank(input);
1139   for (OpOperand &output : op.getDpsInitsMutable())
1140     outputRankSum += op.getRank(&output);
1141   return {inputRankSum, inputRankSum + outputRankSum};
1142 }
1143 
/// Reify the shape of every init/output as OpFoldResults, expressing dynamic
/// dims in terms of the input shapes where possible.
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  // Fold the composed map over the concrete operand dims in one shot.
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  // `pos` walks the flat list of output shape expressions across all inits.
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value. If the shape expression still references
        // an output dim, the shape cannot be derived from the inputs; fall
        // back to querying the init operand itself.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}
1203 
1204 /// Return the index in the indexingMaps vector that corresponds to this
1205 /// `opOperand`.
1206 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1207   auto operandNumber = opOperand->getOperandNumber();
1208   auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1209   if (!dpsIface.isDpsInput(opOperand))
1210     return operandNumber;
1211   unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1212   assert(!dpsIface.isDpsInit(opOperand));
1213   // Account for potential inputs that are not DPS and may not appear in
1214   // `indexingMaps`.
1215   return cast<DestinationStyleOpInterface>(*this->getOperation())
1216              .getNumDpsInputs() +
1217          operandNumber - start;
1218 }
1219 
/// Shared verifier for all ops implementing the structured op interface:
/// checks operand semantics, indexing maps, static shape consistency, and the
/// region/block-argument contract.
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError("expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  // Each operand's indexing map must be well-formed: no symbols, a domain
  // matching the loop count, and one result per operand dimension.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);

    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }
  // NOTE(review): `redDims` is computed but never read below — looks like a
  // leftover; confirm before removing.
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Check if given shapes match to inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
  // Verify only static cases since we can't get exact dimension sizes and
  // loop ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    // Make the upper bounds inclusive so composing the indexing map yields
    // the largest accessed index per dimension.
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimension or the case that the dimension size is 0
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // The first index or last index should be the maximum or the minimum in
        // the inferred index ranges since the range is increasing or
        // decreasing. The size of dimensions of input/output operands and the
        // maximum value + 1 in the inferred range should be the same. But, for
        // now we check if the inferred ranges are in boundary of input/output
        // operands' size or not in case that Affine Expressions are complicated
        // such as d0 * 3
        // + d1 since it is not easy to handle the issues.
        // Found the case that this solution can't check, for example, (d0, d1)
        // -> (d1 - d0)
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
          // A plain dim expression must match the operand size exactly.
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          // A composite expression only yields a lower bound on the size.
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  // Each block argument must carry the element type (for shaped operands) or
  // the self type (for scalars) of its matching operand.
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}
1364