xref: /llvm-project/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
1 //===- DebugExtensionOps.cpp - Debug extension for the Transform dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h"
10 
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
13 #include "mlir/IR/OpImplementation.h"
14 
15 using namespace mlir;
16 
17 #define GET_OP_CLASSES
18 #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"
19 
20 DiagnosedSilenceableFailure
21 transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
22                                       transform::TransformResults &results,
23                                       transform::TransformState &state) {
24   if (isa<TransformHandleTypeInterface>(getAt().getType())) {
25     auto payload = state.getPayloadOps(getAt());
26     for (Operation *op : payload)
27       op->emitRemark() << getMessage();
28     return DiagnosedSilenceableFailure::success();
29   }
30 
31   assert(isa<transform::TransformValueHandleTypeInterface>(getAt().getType()) &&
32          "unhandled kind of transform type");
33 
34   auto describeValue = [](Diagnostic &os, Value value) {
35     os << "value handle points to ";
36     if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
37       os << "a block argument #" << arg.getArgNumber() << " in block #"
38          << std::distance(arg.getOwner()->getParent()->begin(),
39                           arg.getOwner()->getIterator())
40          << " in region #" << arg.getOwner()->getParent()->getRegionNumber();
41     } else {
42       os << "an op result #" << llvm::cast<OpResult>(value).getResultNumber();
43     }
44   };
45 
46   for (Value value : state.getPayloadValues(getAt())) {
47     InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage();
48     describeValue(diag.attachNote(), value);
49   }
50 
51   return DiagnosedSilenceableFailure::success();
52 }
53 
54 DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply(
55     transform::TransformRewriter &rewriter,
56     transform::TransformResults &results, transform::TransformState &state) {
57   std::string str;
58   llvm::raw_string_ostream os(str);
59   if (getMessage())
60     os << *getMessage() << " ";
61   llvm::interleaveComma(state.getParams(getParam()), os);
62   if (!getAnchor()) {
63     emitRemark() << str;
64     return DiagnosedSilenceableFailure::success();
65   }
66   for (Operation *payload : state.getPayloadOps(getAnchor()))
67     ::mlir::emitRemark(payload->getLoc()) << str;
68   return DiagnosedSilenceableFailure::success();
69 }
70