xref: /llvm-project/mlir/lib/Analysis/DataLayoutAnalysis.cpp (revision 285a229f205ae67dca48c8eac8206a115320c677)
//===- DataLayoutAnalysis.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8c59ce1f6SAlex Zinenko 
9c59ce1f6SAlex Zinenko #include "mlir/Analysis/DataLayoutAnalysis.h"
10c59ce1f6SAlex Zinenko #include "mlir/IR/BuiltinOps.h"
11c59ce1f6SAlex Zinenko #include "mlir/IR/Operation.h"
12c59ce1f6SAlex Zinenko #include "mlir/Interfaces/DataLayoutInterfaces.h"
13*285a229fSMehdi Amini #include "mlir/Support/LLVM.h"
14*285a229fSMehdi Amini #include <memory>
15c59ce1f6SAlex Zinenko 
16c59ce1f6SAlex Zinenko using namespace mlir;
17c59ce1f6SAlex Zinenko 
DataLayoutAnalysis(Operation * root)18c59ce1f6SAlex Zinenko DataLayoutAnalysis::DataLayoutAnalysis(Operation *root)
19c59ce1f6SAlex Zinenko     : defaultLayout(std::make_unique<DataLayout>(DataLayoutOpInterface())) {
20c59ce1f6SAlex Zinenko   // Construct a DataLayout if possible from the op.
21c59ce1f6SAlex Zinenko   auto computeLayout = [this](Operation *op) {
22c59ce1f6SAlex Zinenko     if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
23c59ce1f6SAlex Zinenko       layouts[op] = std::make_unique<DataLayout>(iface);
24c59ce1f6SAlex Zinenko     if (auto module = dyn_cast<ModuleOp>(op))
25c59ce1f6SAlex Zinenko       layouts[op] = std::make_unique<DataLayout>(module);
26c59ce1f6SAlex Zinenko   };
27c59ce1f6SAlex Zinenko 
28c59ce1f6SAlex Zinenko   // Compute layouts for both ancestors and descendants.
29c59ce1f6SAlex Zinenko   root->walk(computeLayout);
30c59ce1f6SAlex Zinenko   for (Operation *ancestor = root->getParentOp(); ancestor != nullptr;
31c59ce1f6SAlex Zinenko        ancestor = ancestor->getParentOp()) {
32c59ce1f6SAlex Zinenko     computeLayout(ancestor);
33c59ce1f6SAlex Zinenko   }
34c59ce1f6SAlex Zinenko }
35c59ce1f6SAlex Zinenko 
getAbove(Operation * operation) const36c59ce1f6SAlex Zinenko const DataLayout &DataLayoutAnalysis::getAbove(Operation *operation) const {
37c59ce1f6SAlex Zinenko   for (Operation *ancestor = operation->getParentOp(); ancestor != nullptr;
38c59ce1f6SAlex Zinenko        ancestor = ancestor->getParentOp()) {
39c59ce1f6SAlex Zinenko     auto it = layouts.find(ancestor);
40c59ce1f6SAlex Zinenko     if (it != layouts.end())
41c59ce1f6SAlex Zinenko       return *it->getSecond();
42c59ce1f6SAlex Zinenko   }
43c59ce1f6SAlex Zinenko 
44c59ce1f6SAlex Zinenko   // Fallback to the default layout.
45c59ce1f6SAlex Zinenko   return *defaultLayout;
46c59ce1f6SAlex Zinenko }
47c59ce1f6SAlex Zinenko 
getAtOrAbove(Operation * operation) const48c59ce1f6SAlex Zinenko const DataLayout &DataLayoutAnalysis::getAtOrAbove(Operation *operation) const {
49c59ce1f6SAlex Zinenko   auto it = layouts.find(operation);
50c59ce1f6SAlex Zinenko   if (it != layouts.end())
51c59ce1f6SAlex Zinenko     return *it->getSecond();
52c59ce1f6SAlex Zinenko   return getAbove(operation);
53c59ce1f6SAlex Zinenko }
54