xref: /llvm-project/mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h (revision 9f77909a5e07b7973fe13d4ea6391c29ff1b46b5)
1 //===- ShapeMappingAnalysis.h - Preserve shape Info  ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
10 #define MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
11 
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <utility>
17 
18 namespace mlir {
19 
20 namespace shape {
21 
22 /// ShapeMappingValue works as the value of ShapeMappingAnalysis table, where
23 /// `funcSymbol` is the symbol of mapping function, and `inputs` are the actual
24 /// parameters for the function.
25 struct ShapeMappingValue {
26   ShapeMappingValue() = default;
ShapeMappingValueShapeMappingValue27   ShapeMappingValue(FlatSymbolRefAttr symbol, llvm::SmallVector<Value> &&inps)
28       : funcSymbol(symbol), inputs(inps) {}
29 
30   FlatSymbolRefAttr funcSymbol;
31   llvm::SmallVector<Value> inputs;
32 };
33 
34 /// ShapeMappingAnalysis is used together with OutlineShapeComputationPass to
35 /// preserve Value and corresponding shape function / arguments mapping
36 /// information
37 struct ShapeMappingAnalysis {
ShapeMappingAnalysisShapeMappingAnalysis38   ShapeMappingAnalysis(Operation *op) : operation(op) { (void)operation; }
39 
40   /// Dumps the shape mapping information to the given stream.
printShapeMappingAnalysis41   void print(raw_ostream &os) const {
42     os << "// ---- Shape Mapping Information -----\n";
43     for (const auto &it : shapeMapping) {
44       const ShapeMappingValue &mappingValue = it.second;
45       os << "// Shape for " << it.first << " :: " << mappingValue.funcSymbol;
46       llvm::interleaveComma(mappingValue.inputs, os << "(");
47       os << ")\n";
48     }
49   }
50 
51   llvm::DenseMap<Value, ShapeMappingValue> shapeMapping;
52 
53 private:
54   Operation *operation;
55 };
56 
57 } // namespace shape
58 } // namespace mlir
59 
60 #endif // MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
61