xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h (revision 15c06bc4af6e11af3dfac465cdb4e772e70505e7)
1 //===- IterationGraphSorter.h -----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines the iteration graph sorter (top-sort scheduling).
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
15 
16 #include "mlir/IR/AffineMap.h"
17 
18 namespace mlir {
19 
20 // Forward declarations.
21 class Value;
22 namespace utils {
23 enum class IteratorType : uint32_t;
24 } // namespace utils
25 namespace linalg {
26 class GenericOp;
27 } // namespace linalg
28 
29 namespace sparse_tensor {
30 
/// Iteration graph sorting mask. The two low bits are named individually;
/// the remaining enumerators are fixed combinations of mask bits.
enum class SortMask : unsigned {
  // The empty mask (no bit set).
  kSparseOnly = 0x0, // b000
  // The individual mask bits.
  kIncludeDenseOutput = 0x1, // b001
  kIncludeDenseInput = 0x2,  // b010
  // The combinations of mask bits.
  kIncludeDense = kIncludeDenseOutput | kIncludeDenseInput, // b011
  // Note: also sets bit b100, which has no individually named enumerator.
  kIncludeAll = 0x7, // b111
};
41 
42 class IterationGraphSorter {
43 public:
44   /// Factory method that construct an iteration graph sorter
45   /// for the given linalg.generic operation.
46   static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
47 
48   /// Returns a permutation that represents the scheduled loop order.
49   /// Note that the returned AffineMap could be null if the kernel
50   /// cannot be scheduled due to cyclic iteration graph.
51   [[nodiscard]] AffineMap sort(SortMask mask, Value ignored = nullptr);
52 
53   /// Returns the number of loops in the iteration graph.
getNumLoops()54   unsigned getNumLoops() const { return loop2OutLvl.getNumDims(); }
55 
56 private:
57   // Private constructor.
58   IterationGraphSorter(SmallVector<Value> &&ins,
59                        SmallVector<AffineMap> &&loop2InsLvl, Value out,
60                        AffineMap loop2OutLvl,
61                        SmallVector<utils::IteratorType> &&iterTypes);
62 
63   // Adds all the constraints in the given loop to level map.
64   void addConstraints(Value t, AffineMap loop2LvlMap);
65 
66   /// A helper to compute a topological sort. The method has an
67   /// O(n^2) time complexity since we use an adjacency matrix
68   /// representation for the iteration graph.
69   AffineMap topoSort();
70 
71   // Input tensors and associated loop to level maps.
72   SmallVector<Value> ins;
73   SmallVector<AffineMap> loop2InsLvl;
74 
75   // Output tensor and associated loop to level map.
76   Value out;
77   AffineMap loop2OutLvl;
78 
79   // Loop itation types;
80   SmallVector<utils::IteratorType> iterTypes;
81 
82   // Adjacency matrix that represents the iteration graph.
83   std::vector<std::vector<bool>> itGraph;
84 
85   // InDegree used for topo sort.
86   std::vector<unsigned> inDegree;
87 };
88 
89 } // namespace sparse_tensor
90 } // namespace mlir
91 
92 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
93