xref: /llvm-project/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp (revision 9467645547f99ba8fa8152d514f06e76e0be8585)
//===- LLVMIRIntrinsicGen.cpp - TableGen utility for converting intrinsics ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a TableGen generator that converts TableGen definitions for LLVM
10 // intrinsics to TableGen definitions for MLIR operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 
16 #include "llvm/ADT/SmallBitVector.h"
17 #include "llvm/CodeGenTypes/MachineValueType.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/PrettyStackTrace.h"
20 #include "llvm/Support/Regex.h"
21 #include "llvm/Support/Signals.h"
22 #include "llvm/TableGen/Error.h"
23 #include "llvm/TableGen/Main.h"
24 #include "llvm/TableGen/Record.h"
25 #include "llvm/TableGen/TableGenBackend.h"
26 
27 using llvm::Record;
28 using llvm::RecordKeeper;
29 using llvm::Regex;
30 using namespace mlir;
31 
32 static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options");
33 
34 static llvm::cl::opt<std::string>
35     nameFilter("llvmir-intrinsics-filter",
36                llvm::cl::desc("Only keep the intrinsics with the specified "
37                               "substring in their record name"),
38                llvm::cl::cat(intrinsicGenCat));
39 
40 static llvm::cl::opt<std::string>
41     opBaseClass("dialect-opclass-base",
42                 llvm::cl::desc("The base class for the ops in the dialect we "
43                                "are planning to emit"),
44                 llvm::cl::init("LLVM_IntrOp"), llvm::cl::cat(intrinsicGenCat));
45 
46 static llvm::cl::opt<std::string> accessGroupRegexp(
47     "llvmir-intrinsics-access-group-regexp",
48     llvm::cl::desc("Mark intrinsics that match the specified "
49                    "regexp as taking an access group metadata"),
50     llvm::cl::cat(intrinsicGenCat));
51 
52 static llvm::cl::opt<std::string> aliasAnalysisRegexp(
53     "llvmir-intrinsics-alias-analysis-regexp",
54     llvm::cl::desc("Mark intrinsics that match the specified "
55                    "regexp as taking alias.scopes, noalias, and tbaa metadata"),
56     llvm::cl::cat(intrinsicGenCat));
57 
58 // Used to represent the indices of overloadable operands/results.
59 using IndicesTy = llvm::SmallBitVector;
60 
61 /// Return a CodeGen value type entry from a type record.
62 static llvm::MVT::SimpleValueType getValueType(const Record *rec) {
63   return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
64       "Value");
65 }
66 
67 /// Return the indices of the definitions in a list of definitions that
68 /// represent overloadable types
69 static IndicesTy getOverloadableTypeIdxs(const Record &record,
70                                          const char *listName) {
71   auto results = record.getValueAsListOfDefs(listName);
72   IndicesTy overloadedOps(results.size());
73   for (const auto &r : llvm::enumerate(results)) {
74     llvm::MVT::SimpleValueType vt = getValueType(r.value());
75     switch (vt) {
76     case llvm::MVT::iAny:
77     case llvm::MVT::fAny:
78     case llvm::MVT::Any:
79     case llvm::MVT::pAny:
80     case llvm::MVT::vAny:
81       overloadedOps.set(r.index());
82       break;
83     default:
84       continue;
85     }
86   }
87   return overloadedOps;
88 }
89 
90 namespace {
91 /// A wrapper for LLVM's Tablegen class `Intrinsic` that provides accessors to
92 /// the fields of the record.
93 class LLVMIntrinsic {
94 public:
95   LLVMIntrinsic(const Record &record) : record(record) {}
96 
97   /// Get the name of the operation to be used in MLIR.  Uses the appropriate
98   /// field if not empty, constructs a name by replacing underscores with dots
99   /// in the record name otherwise.
100   std::string getOperationName() const {
101     StringRef name = record.getValueAsString(fieldName);
102     if (!name.empty())
103       return name.str();
104 
105     name = record.getName();
106     assert(name.starts_with("int_") &&
107            "LLVM intrinsic names are expected to start with 'int_'");
108     name = name.drop_front(4);
109     SmallVector<StringRef, 8> chunks;
110     StringRef targetPrefix = record.getValueAsString("TargetPrefix");
111     name.split(chunks, '_');
112     auto *chunksBegin = chunks.begin();
113     // Remove the target prefix from target specific intrinsics.
114     if (!targetPrefix.empty()) {
115       assert(targetPrefix == *chunksBegin &&
116              "Intrinsic has TargetPrefix, but "
117              "record name doesn't begin with it");
118       assert(chunks.size() >= 2 &&
119              "Intrinsic has TargetPrefix, but "
120              "chunks has only one element meaning the intrinsic name is empty");
121       ++chunksBegin;
122     }
123     return llvm::join(chunksBegin, chunks.end(), ".");
124   }
125 
126   /// Get the name of the record without the "intrinsic" prefix.
127   StringRef getProperRecordName() const {
128     StringRef name = record.getName();
129     assert(name.starts_with("int_") &&
130            "LLVM intrinsic names are expected to start with 'int_'");
131     return name.drop_front(4);
132   }
133 
134   /// Get the number of operands.
135   unsigned getNumOperands() const {
136     auto operands = record.getValueAsListOfDefs(fieldOperands);
137     assert(llvm::all_of(
138                operands,
139                [](const Record *r) { return r->isSubClassOf("LLVMType"); }) &&
140            "expected operands to be of LLVM type");
141     return operands.size();
142   }
143 
144   /// Get the number of results.  Note that LLVM does not support multi-value
145   /// operations so, in fact, multiple results will be returned as a value of
146   /// structure type.
147   unsigned getNumResults() const {
148     auto results = record.getValueAsListOfDefs(fieldResults);
149     for (const Record *r : results) {
150       (void)r;
151       assert(r->isSubClassOf("LLVMType") &&
152              "expected operands to be of LLVM type");
153     }
154     return results.size();
155   }
156 
157   /// Return true if the intrinsic may have side effects, i.e. does not have the
158   /// `IntrNoMem` property.
159   bool hasSideEffects() const {
160     return llvm::none_of(
161         record.getValueAsListOfDefs(fieldTraits),
162         [](const Record *r) { return r->getName() == "IntrNoMem"; });
163   }
164 
165   /// Return true if the intrinsic is commutative, i.e. has the respective
166   /// property.
167   bool isCommutative() const {
168     return llvm::any_of(
169         record.getValueAsListOfDefs(fieldTraits),
170         [](const Record *r) { return r->getName() == "Commutative"; });
171   }
172 
173   IndicesTy getOverloadableOperandsIdxs() const {
174     return getOverloadableTypeIdxs(record, fieldOperands);
175   }
176 
177   IndicesTy getOverloadableResultsIdxs() const {
178     return getOverloadableTypeIdxs(record, fieldResults);
179   }
180 
181 private:
182   /// Names of the fields in the Intrinsic LLVM Tablegen class.
183   const char *fieldName = "LLVMName";
184   const char *fieldOperands = "ParamTypes";
185   const char *fieldResults = "RetTypes";
186   const char *fieldTraits = "IntrProperties";
187 
188   const Record &record;
189 };
190 } // namespace
191 
192 /// Prints the elements in "range" separated by commas and surrounded by "[]".
193 template <typename Range>
194 void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
195   os << '[';
196   llvm::interleaveComma(range, os);
197   os << ']';
198 }
199 
200 /// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic.
201 /// Returns true on error, false on success.
202 static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
203   LLVMIntrinsic intr(record);
204 
205   Regex accessGroupMatcher(accessGroupRegexp);
206   bool requiresAccessGroup =
207       !accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
208 
209   Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
210   bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() &&
211                                aliasAnalysisMatcher.match(record.getName());
212 
213   // Prepare strings for traits, if any.
214   SmallVector<StringRef, 2> traits;
215   if (intr.isCommutative())
216     traits.push_back("Commutative");
217   if (!intr.hasSideEffects())
218     traits.push_back("NoMemoryEffect");
219 
220   // Prepare strings for operands.
221   SmallVector<StringRef, 8> operands(intr.getNumOperands(), "LLVM_Type");
222   if (requiresAccessGroup)
223     operands.push_back(
224         "OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups");
225   if (requiresAliasAnalysis) {
226     operands.push_back("OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes");
227     operands.push_back(
228         "OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes");
229     operands.push_back("OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa");
230   }
231 
232   // Emit the definition.
233   os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass
234      << "<\"" << intr.getOperationName() << "\", ";
235   printBracketedRange(intr.getOverloadableResultsIdxs().set_bits(), os);
236   os << ", ";
237   printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
238   os << ", ";
239   printBracketedRange(traits, os);
240   os << ", " << intr.getNumResults() << ", "
241      << (requiresAccessGroup ? "1" : "0") << ", "
242      << (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
243      << (operands.empty() ? "" : " ");
244   llvm::interleaveComma(operands, os);
245   os << ")>;\n\n";
246 
247   return false;
248 }
249 
250 /// Traverses the list of TableGen definitions derived from the "Intrinsic"
251 /// class and generates MLIR ODS definitions for those intrinsics that have
252 /// the name matching the filter.
253 static bool emitIntrinsics(const RecordKeeper &records, llvm::raw_ostream &os) {
254   llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records);
255   os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n";
256   os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n";
257 
258   auto defs = records.getAllDerivedDefinitions("Intrinsic");
259   for (const Record *r : defs) {
260     if (!nameFilter.empty() && !r->getName().contains(nameFilter))
261       continue;
262     if (emitIntrinsic(*r, os))
263       return true;
264   }
265 
266   return false;
267 }
268 
269 static mlir::GenRegistration genLLVMIRIntrinsics("gen-llvmir-intrinsics",
270                                                  "Generate LLVM IR intrinsics",
271                                                  emitIntrinsics);
272