//===---- Query.cpp - MLIR query command implementations ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Query/Query.h"
10 #include "QueryParser.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/Query/Matcher/MatchFinder.h"
14 #include "mlir/Query/QuerySession.h"
15 #include "llvm/Support/SourceMgr.h"
16 #include "llvm/Support/raw_ostream.h"
17
18 namespace mlir::query {
19
parse(llvm::StringRef line,const QuerySession & qs)20 QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
21 return QueryParser::parse(line, qs);
22 }
23
24 std::vector<llvm::LineEditor::Completion>
complete(llvm::StringRef line,size_t pos,const QuerySession & qs)25 complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
26 return QueryParser::complete(line, pos, qs);
27 }
28
printMatch(llvm::raw_ostream & os,QuerySession & qs,Operation * op,const std::string & binding)29 static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
30 const std::string &binding) {
31 auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
32 auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
33 qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
34 qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
35 "\"" + binding + "\" binds here");
36 }
37
38 // TODO: Extract into a helper function that can be reused outside query
39 // context.
extractFunction(std::vector<Operation * > & ops,MLIRContext * context,llvm::StringRef functionName)40 static Operation *extractFunction(std::vector<Operation *> &ops,
41 MLIRContext *context,
42 llvm::StringRef functionName) {
43 context->loadDialect<func::FuncDialect>();
44 OpBuilder builder(context);
45
46 // Collect data for function creation
47 std::vector<Operation *> slice;
48 std::vector<Value> values;
49 std::vector<Type> outputTypes;
50
51 for (auto *op : ops) {
52 // Return op's operands are propagated, but the op itself isn't needed.
53 if (!isa<func::ReturnOp>(op))
54 slice.push_back(op);
55
56 // All results are returned by the extracted function.
57 outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
58 op->getResults().getTypes().end());
59
60 // Track all values that need to be taken as input to function.
61 values.insert(values.end(), op->getOperands().begin(),
62 op->getOperands().end());
63 }
64
65 // Create the function
66 FunctionType funcType =
67 builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
68 auto loc = builder.getUnknownLoc();
69 func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
70
71 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
72
73 // Map original values to function arguments
74 IRMapping mapper;
75 for (const auto &arg : llvm::enumerate(values))
76 mapper.map(arg.value(), funcOp.getArgument(arg.index()));
77
78 // Clone operations and build function body
79 std::vector<Operation *> clonedOps;
80 std::vector<Value> clonedVals;
81 for (Operation *slicedOp : slice) {
82 Operation *clonedOp =
83 clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
84 clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
85 clonedOp->result_end());
86 }
87 // Add return operation
88 builder.create<func::ReturnOp>(loc, clonedVals);
89
90 // Remove unused function arguments
91 size_t currentIndex = 0;
92 while (currentIndex < funcOp.getNumArguments()) {
93 if (funcOp.getArgument(currentIndex).use_empty())
94 funcOp.eraseArgument(currentIndex);
95 else
96 ++currentIndex;
97 }
98
99 return funcOp;
100 }
101
// Out-of-line defaulted destructor for the Query base class.
// NOTE(review): presumably defined here rather than in the header so the
// class's vtable is emitted in this translation unit. Confirm against the
// declaration in Query.h.
Query::~Query() = default;
103
run(llvm::raw_ostream & os,QuerySession & qs) const104 LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
105 os << errStr << "\n";
106 return mlir::failure();
107 }
108
run(llvm::raw_ostream & os,QuerySession & qs) const109 LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
110 return mlir::success();
111 }
112
run(llvm::raw_ostream & os,QuerySession & qs) const113 LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
114 os << "Available commands:\n\n"
115 " match MATCHER, m MATCHER "
116 "Match the mlir against the given matcher.\n"
117 " quit "
118 "Terminates the query session.\n\n";
119 return mlir::success();
120 }
121
run(llvm::raw_ostream & os,QuerySession & qs) const122 LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
123 qs.terminate = true;
124 return mlir::success();
125 }
126
run(llvm::raw_ostream & os,QuerySession & qs) const127 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
128 Operation *rootOp = qs.getRootOp();
129 int matchCount = 0;
130 std::vector<Operation *> matches =
131 matcher::MatchFinder().getMatches(rootOp, matcher);
132
133 // An extract call is recognized by considering if the matcher has a name.
134 // TODO: Consider making the extract more explicit.
135 if (matcher.hasFunctionName()) {
136 auto functionName = matcher.getFunctionName();
137 Operation *function =
138 extractFunction(matches, rootOp->getContext(), functionName);
139 os << "\n" << *function << "\n\n";
140 function->erase();
141 return mlir::success();
142 }
143
144 os << "\n";
145 for (Operation *op : matches) {
146 os << "Match #" << ++matchCount << ":\n\n";
147 // Placeholder "root" binding for the initial draft.
148 printMatch(os, qs, op, "root");
149 }
150 os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
151
152 return mlir::success();
153 }
154
155 } // namespace mlir::query
156