xref: /llvm-project/mlir/lib/Analysis/DataLayoutAnalysis.cpp (revision 285a229f205ae67dca48c8eac8206a115320c677)
1 //===- DataLayoutAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataLayoutAnalysis.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/Interfaces/DataLayoutInterfaces.h"
13 #include "mlir/Support/LLVM.h"
14 #include <memory>
15 
16 using namespace mlir;
17 
DataLayoutAnalysis(Operation * root)18 DataLayoutAnalysis::DataLayoutAnalysis(Operation *root)
19     : defaultLayout(std::make_unique<DataLayout>(DataLayoutOpInterface())) {
20   // Construct a DataLayout if possible from the op.
21   auto computeLayout = [this](Operation *op) {
22     if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
23       layouts[op] = std::make_unique<DataLayout>(iface);
24     if (auto module = dyn_cast<ModuleOp>(op))
25       layouts[op] = std::make_unique<DataLayout>(module);
26   };
27 
28   // Compute layouts for both ancestors and descendants.
29   root->walk(computeLayout);
30   for (Operation *ancestor = root->getParentOp(); ancestor != nullptr;
31        ancestor = ancestor->getParentOp()) {
32     computeLayout(ancestor);
33   }
34 }
35 
getAbove(Operation * operation) const36 const DataLayout &DataLayoutAnalysis::getAbove(Operation *operation) const {
37   for (Operation *ancestor = operation->getParentOp(); ancestor != nullptr;
38        ancestor = ancestor->getParentOp()) {
39     auto it = layouts.find(ancestor);
40     if (it != layouts.end())
41       return *it->getSecond();
42   }
43 
44   // Fallback to the default layout.
45   return *defaultLayout;
46 }
47 
getAtOrAbove(Operation * operation) const48 const DataLayout &DataLayoutAnalysis::getAtOrAbove(Operation *operation) const {
49   auto it = layouts.find(operation);
50   if (it != layouts.end())
51     return *it->getSecond();
52   return getAbove(operation);
53 }
54