xref: /llvm-project/mlir/lib/Query/Query.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===---- Query.cpp - -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Query/Query.h"
#include "QueryParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
17 
18 namespace mlir::query {
19 
parse(llvm::StringRef line,const QuerySession & qs)20 QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
21   return QueryParser::parse(line, qs);
22 }
23 
24 std::vector<llvm::LineEditor::Completion>
complete(llvm::StringRef line,size_t pos,const QuerySession & qs)25 complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
26   return QueryParser::complete(line, pos, qs);
27 }
28 
printMatch(llvm::raw_ostream & os,QuerySession & qs,Operation * op,const std::string & binding)29 static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
30                        const std::string &binding) {
31   auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
32   auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
33       qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
34   qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
35                                      "\"" + binding + "\" binds here");
36 }
37 
38 // TODO: Extract into a helper function that can be reused outside query
39 // context.
extractFunction(std::vector<Operation * > & ops,MLIRContext * context,llvm::StringRef functionName)40 static Operation *extractFunction(std::vector<Operation *> &ops,
41                                   MLIRContext *context,
42                                   llvm::StringRef functionName) {
43   context->loadDialect<func::FuncDialect>();
44   OpBuilder builder(context);
45 
46   // Collect data for function creation
47   std::vector<Operation *> slice;
48   std::vector<Value> values;
49   std::vector<Type> outputTypes;
50 
51   for (auto *op : ops) {
52     // Return op's operands are propagated, but the op itself isn't needed.
53     if (!isa<func::ReturnOp>(op))
54       slice.push_back(op);
55 
56     // All results are returned by the extracted function.
57     outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
58                        op->getResults().getTypes().end());
59 
60     // Track all values that need to be taken as input to function.
61     values.insert(values.end(), op->getOperands().begin(),
62                   op->getOperands().end());
63   }
64 
65   // Create the function
66   FunctionType funcType =
67       builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
68   auto loc = builder.getUnknownLoc();
69   func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
70 
71   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
72 
73   // Map original values to function arguments
74   IRMapping mapper;
75   for (const auto &arg : llvm::enumerate(values))
76     mapper.map(arg.value(), funcOp.getArgument(arg.index()));
77 
78   // Clone operations and build function body
79   std::vector<Operation *> clonedOps;
80   std::vector<Value> clonedVals;
81   for (Operation *slicedOp : slice) {
82     Operation *clonedOp =
83         clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
84     clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
85                       clonedOp->result_end());
86   }
87   // Add return operation
88   builder.create<func::ReturnOp>(loc, clonedVals);
89 
90   // Remove unused function arguments
91   size_t currentIndex = 0;
92   while (currentIndex < funcOp.getNumArguments()) {
93     if (funcOp.getArgument(currentIndex).use_empty())
94       funcOp.eraseArgument(currentIndex);
95     else
96       ++currentIndex;
97   }
98 
99   return funcOp;
100 }
101 
102 Query::~Query() = default;
103 
run(llvm::raw_ostream & os,QuerySession & qs) const104 LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
105   os << errStr << "\n";
106   return mlir::failure();
107 }
108 
run(llvm::raw_ostream & os,QuerySession & qs) const109 LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
110   return mlir::success();
111 }
112 
run(llvm::raw_ostream & os,QuerySession & qs) const113 LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
114   os << "Available commands:\n\n"
115         "  match MATCHER, m MATCHER      "
116         "Match the mlir against the given matcher.\n"
117         "  quit                              "
118         "Terminates the query session.\n\n";
119   return mlir::success();
120 }
121 
run(llvm::raw_ostream & os,QuerySession & qs) const122 LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
123   qs.terminate = true;
124   return mlir::success();
125 }
126 
run(llvm::raw_ostream & os,QuerySession & qs) const127 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
128   Operation *rootOp = qs.getRootOp();
129   int matchCount = 0;
130   std::vector<Operation *> matches =
131       matcher::MatchFinder().getMatches(rootOp, matcher);
132 
133   // An extract call is recognized by considering if the matcher has a name.
134   // TODO: Consider making the extract more explicit.
135   if (matcher.hasFunctionName()) {
136     auto functionName = matcher.getFunctionName();
137     Operation *function =
138         extractFunction(matches, rootOp->getContext(), functionName);
139     os << "\n" << *function << "\n\n";
140     function->erase();
141     return mlir::success();
142   }
143 
144   os << "\n";
145   for (Operation *op : matches) {
146     os << "Match #" << ++matchCount << ":\n\n";
147     // Placeholder "root" binding for the initial draft.
148     printMatch(os, qs, op, "root");
149   }
150   os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
151 
152   return mlir::success();
153 }
154 
155 } // namespace mlir::query
156