xref: /llvm-project/mlir/lib/Query/Query.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
102d9f4d1SDevajith //===---- Query.cpp - -----------------------------------------------------===//
202d9f4d1SDevajith //
302d9f4d1SDevajith // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
402d9f4d1SDevajith // See https://llvm.org/LICENSE.txt for license information.
502d9f4d1SDevajith // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
602d9f4d1SDevajith //
702d9f4d1SDevajith //===----------------------------------------------------------------------===//
802d9f4d1SDevajith 
902d9f4d1SDevajith #include "mlir/Query/Query.h"
1002d9f4d1SDevajith #include "QueryParser.h"
1158b44c81SJacques Pienaar #include "mlir/Dialect/Func/IR/FuncOps.h"
1258b44c81SJacques Pienaar #include "mlir/IR/IRMapping.h"
1302d9f4d1SDevajith #include "mlir/Query/Matcher/MatchFinder.h"
1402d9f4d1SDevajith #include "mlir/Query/QuerySession.h"
1502d9f4d1SDevajith #include "llvm/Support/SourceMgr.h"
1602d9f4d1SDevajith #include "llvm/Support/raw_ostream.h"
1702d9f4d1SDevajith 
1802d9f4d1SDevajith namespace mlir::query {
1902d9f4d1SDevajith 
parse(llvm::StringRef line,const QuerySession & qs)2002d9f4d1SDevajith QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
2102d9f4d1SDevajith   return QueryParser::parse(line, qs);
2202d9f4d1SDevajith }
2302d9f4d1SDevajith 
2402d9f4d1SDevajith std::vector<llvm::LineEditor::Completion>
complete(llvm::StringRef line,size_t pos,const QuerySession & qs)2502d9f4d1SDevajith complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
2602d9f4d1SDevajith   return QueryParser::complete(line, pos, qs);
2702d9f4d1SDevajith }
2802d9f4d1SDevajith 
printMatch(llvm::raw_ostream & os,QuerySession & qs,Operation * op,const std::string & binding)2902d9f4d1SDevajith static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
3002d9f4d1SDevajith                        const std::string &binding) {
3102d9f4d1SDevajith   auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
3202d9f4d1SDevajith   auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
3302d9f4d1SDevajith       qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
3402d9f4d1SDevajith   qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
3502d9f4d1SDevajith                                      "\"" + binding + "\" binds here");
3602d9f4d1SDevajith }
3702d9f4d1SDevajith 
3858b44c81SJacques Pienaar // TODO: Extract into a helper function that can be reused outside query
3958b44c81SJacques Pienaar // context.
extractFunction(std::vector<Operation * > & ops,MLIRContext * context,llvm::StringRef functionName)4058b44c81SJacques Pienaar static Operation *extractFunction(std::vector<Operation *> &ops,
4158b44c81SJacques Pienaar                                   MLIRContext *context,
4258b44c81SJacques Pienaar                                   llvm::StringRef functionName) {
4358b44c81SJacques Pienaar   context->loadDialect<func::FuncDialect>();
4458b44c81SJacques Pienaar   OpBuilder builder(context);
4558b44c81SJacques Pienaar 
4658b44c81SJacques Pienaar   // Collect data for function creation
4758b44c81SJacques Pienaar   std::vector<Operation *> slice;
4858b44c81SJacques Pienaar   std::vector<Value> values;
4958b44c81SJacques Pienaar   std::vector<Type> outputTypes;
5058b44c81SJacques Pienaar 
5158b44c81SJacques Pienaar   for (auto *op : ops) {
5258b44c81SJacques Pienaar     // Return op's operands are propagated, but the op itself isn't needed.
5358b44c81SJacques Pienaar     if (!isa<func::ReturnOp>(op))
5458b44c81SJacques Pienaar       slice.push_back(op);
5558b44c81SJacques Pienaar 
5658b44c81SJacques Pienaar     // All results are returned by the extracted function.
5758b44c81SJacques Pienaar     outputTypes.insert(outputTypes.end(), op->getResults().getTypes().begin(),
5858b44c81SJacques Pienaar                        op->getResults().getTypes().end());
5958b44c81SJacques Pienaar 
6058b44c81SJacques Pienaar     // Track all values that need to be taken as input to function.
6158b44c81SJacques Pienaar     values.insert(values.end(), op->getOperands().begin(),
6258b44c81SJacques Pienaar                   op->getOperands().end());
6358b44c81SJacques Pienaar   }
6458b44c81SJacques Pienaar 
6558b44c81SJacques Pienaar   // Create the function
6658b44c81SJacques Pienaar   FunctionType funcType =
67db76af28SJacques Pienaar       builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
6858b44c81SJacques Pienaar   auto loc = builder.getUnknownLoc();
6958b44c81SJacques Pienaar   func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
7058b44c81SJacques Pienaar 
7158b44c81SJacques Pienaar   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
7258b44c81SJacques Pienaar 
7358b44c81SJacques Pienaar   // Map original values to function arguments
7458b44c81SJacques Pienaar   IRMapping mapper;
7558b44c81SJacques Pienaar   for (const auto &arg : llvm::enumerate(values))
7658b44c81SJacques Pienaar     mapper.map(arg.value(), funcOp.getArgument(arg.index()));
7758b44c81SJacques Pienaar 
7858b44c81SJacques Pienaar   // Clone operations and build function body
7958b44c81SJacques Pienaar   std::vector<Operation *> clonedOps;
8058b44c81SJacques Pienaar   std::vector<Value> clonedVals;
8158b44c81SJacques Pienaar   for (Operation *slicedOp : slice) {
8258b44c81SJacques Pienaar     Operation *clonedOp =
8358b44c81SJacques Pienaar         clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
8458b44c81SJacques Pienaar     clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
8558b44c81SJacques Pienaar                       clonedOp->result_end());
8658b44c81SJacques Pienaar   }
8758b44c81SJacques Pienaar   // Add return operation
8858b44c81SJacques Pienaar   builder.create<func::ReturnOp>(loc, clonedVals);
8958b44c81SJacques Pienaar 
9058b44c81SJacques Pienaar   // Remove unused function arguments
9158b44c81SJacques Pienaar   size_t currentIndex = 0;
9258b44c81SJacques Pienaar   while (currentIndex < funcOp.getNumArguments()) {
9358b44c81SJacques Pienaar     if (funcOp.getArgument(currentIndex).use_empty())
9458b44c81SJacques Pienaar       funcOp.eraseArgument(currentIndex);
9558b44c81SJacques Pienaar     else
9658b44c81SJacques Pienaar       ++currentIndex;
9758b44c81SJacques Pienaar   }
9858b44c81SJacques Pienaar 
9958b44c81SJacques Pienaar   return funcOp;
10058b44c81SJacques Pienaar }
10158b44c81SJacques Pienaar 
// Out-of-line defaulted destructor for the Query base class.
Query::~Query() = default;
10302d9f4d1SDevajith 
run(llvm::raw_ostream & os,QuerySession & qs) const104*db791b27SRamkumar Ramachandra LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
10502d9f4d1SDevajith   os << errStr << "\n";
10602d9f4d1SDevajith   return mlir::failure();
10702d9f4d1SDevajith }
10802d9f4d1SDevajith 
run(llvm::raw_ostream & os,QuerySession & qs) const109*db791b27SRamkumar Ramachandra LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
11002d9f4d1SDevajith   return mlir::success();
11102d9f4d1SDevajith }
11202d9f4d1SDevajith 
run(llvm::raw_ostream & os,QuerySession & qs) const113*db791b27SRamkumar Ramachandra LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
11402d9f4d1SDevajith   os << "Available commands:\n\n"
11502d9f4d1SDevajith         "  match MATCHER, m MATCHER      "
11602d9f4d1SDevajith         "Match the mlir against the given matcher.\n"
11702d9f4d1SDevajith         "  quit                              "
11802d9f4d1SDevajith         "Terminates the query session.\n\n";
11902d9f4d1SDevajith   return mlir::success();
12002d9f4d1SDevajith }
12102d9f4d1SDevajith 
run(llvm::raw_ostream & os,QuerySession & qs) const122*db791b27SRamkumar Ramachandra LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
12302d9f4d1SDevajith   qs.terminate = true;
12402d9f4d1SDevajith   return mlir::success();
12502d9f4d1SDevajith }
12602d9f4d1SDevajith 
run(llvm::raw_ostream & os,QuerySession & qs) const127*db791b27SRamkumar Ramachandra LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
12858b44c81SJacques Pienaar   Operation *rootOp = qs.getRootOp();
12902d9f4d1SDevajith   int matchCount = 0;
13002d9f4d1SDevajith   std::vector<Operation *> matches =
13158b44c81SJacques Pienaar       matcher::MatchFinder().getMatches(rootOp, matcher);
13258b44c81SJacques Pienaar 
13358b44c81SJacques Pienaar   // An extract call is recognized by considering if the matcher has a name.
13458b44c81SJacques Pienaar   // TODO: Consider making the extract more explicit.
13558b44c81SJacques Pienaar   if (matcher.hasFunctionName()) {
13658b44c81SJacques Pienaar     auto functionName = matcher.getFunctionName();
13758b44c81SJacques Pienaar     Operation *function =
13858b44c81SJacques Pienaar         extractFunction(matches, rootOp->getContext(), functionName);
13958b44c81SJacques Pienaar     os << "\n" << *function << "\n\n";
14058b44c81SJacques Pienaar     function->erase();
14158b44c81SJacques Pienaar     return mlir::success();
14258b44c81SJacques Pienaar   }
14358b44c81SJacques Pienaar 
14402d9f4d1SDevajith   os << "\n";
14502d9f4d1SDevajith   for (Operation *op : matches) {
14602d9f4d1SDevajith     os << "Match #" << ++matchCount << ":\n\n";
14702d9f4d1SDevajith     // Placeholder "root" binding for the initial draft.
14802d9f4d1SDevajith     printMatch(os, qs, op, "root");
14902d9f4d1SDevajith   }
15002d9f4d1SDevajith   os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
15102d9f4d1SDevajith 
15202d9f4d1SDevajith   return mlir::success();
15302d9f4d1SDevajith }
15402d9f4d1SDevajith 
15502d9f4d1SDevajith } // namespace mlir::query
156