xref: /llvm-project/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (revision db115ba3efee9c940539667842a1092d8d956850)
1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml  ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an ODS (and C++) generator from a YAML form
10 // derived from the mathematical expression of linalg named ops. Typically a
11 // math oriented DSL will be used to export the essential representation to
12 // this form, and maintaining the SOT at the math level (versus recreating it
13 // in MLIR) is deemed to have systemic value.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/AsmParser/AsmParser.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ToolOutputFile.h"
28 #include "llvm/Support/YAMLTraits.h"
29 #include <optional>
30 
31 using namespace mlir;
32 
33 using llvm::yaml::Input;
34 
35 #define DEBUG_TYPE "linalg-ods-gen"
36 
37 //===----------------------------------------------------------------------===//
38 // Mapping structs (correspond to data types in the YAML description).
39 // TODO: Since this is a schema/part of the contract, it should be moved to
40 // a real header.
41 //===----------------------------------------------------------------------===//
42 
namespace {

/// State handed to the YAML parser as its context pointer. The
/// `ScalarTraits<SerializedAffineMap>::input` hook casts the context back to
/// this type to reach the MLIRContext used to parse affine maps.
struct LinalgYAMLContext {
  MLIRContext *mlirContext;
};

/// Op-level metadata, read from the YAML `metadata` section.
struct LinalgOpMetadata {
  /// Op mnemonic (YAML `name`).
  std::string name;
  /// Name of the generated ODS/C++ class (YAML `cpp_class_name`).
  std::string cppClassName;
  /// Optional documentation (YAML `doc`). The ODS generator splits it at the
  /// first "\n\n" into a summary and a description.
  std::optional<std::string> doc;
  /// Interface names added to the op's interface list (YAML `implements`).
  SmallVector<std::string> implements;
  /// Names emitted as ODS `let <name> = 1;` fields (YAML `defines`).
  SmallVector<std::string> defines;
};

/// An AffineMapAttr that is (de)serialized through YAML as its printed form.
struct SerializedAffineMap {
  AffineMapAttr affineMapAttr;

  /// Returns the wrapped affine map; assumes `affineMapAttr` is non-null.
  AffineMap affineMap() { return affineMapAttr.getValue(); }
};

/// Role of a named op argument (YAML `kind`). The YAML spellings are given in
/// the matching ScalarEnumerationTraits specialization.
enum class LinalgOperandDefKind {
  InputTensor,
  Scalar,
  OutputTensor,
  IndexAttr,
  UnaryFnAttr,
  BinaryFnAttr,
  TernaryFnAttr,
  TypeFnAttr
};

/// A named argument of a structured op (one entry of the YAML `args` list).
struct LinalgOperandDef {
  /// Argument name (YAML `name`).
  std::string name;
  /// Argument role (YAML `kind`).
  LinalgOperandDefKind kind;
  /// Symbolic element/self type variable (YAML `type_var`); optional.
  std::optional<std::string> typeVar;
  /// Map from op symbols to the argument shape (YAML `shape_map`); optional.
  std::optional<SerializedAffineMap> shapeMap;
  /// Map from op symbols to index attribute symbols (YAML `index_attr_map`);
  /// the ODS generator requires it for index attributes.
  std::optional<SerializedAffineMap> indexAttrMap;
  /// Default values of an index attribute (YAML `default_indices`); the ODS
  /// generator requires one value per `indexAttrMap` result.
  std::optional<SmallVector<int64_t>> defaultIndices;
  /// Default function of a function attribute (YAML `default_fn`).
  std::optional<std::string> defaultFn;
};

/// Iterator type of one loop dimension (YAML `iterator_types` entries).
enum class LinalgIteratorTypeDef {
  parallel,
  reduction,
};

/// How the op's indexing maps are specified (YAML `indexing_maps`).
struct LinalgIndexingMapsConfig {
  /// A static list of indexing maps (YAML `static_indexing_maps`).
  std::optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
};

struct ScalarExpression;

/// Category of a scalar function (YAML `kind` of a `scalar_fn`).
enum class ScalarFnKind { Unary, Binary, Ternary, Type };

/// Application of a scalar function to operand expressions.
struct ScalarFn {
  /// Function category (YAML `kind`).
  ScalarFnKind kind;
  /// Function name (YAML `fn_name`); optional.
  std::optional<std::string> fnName;
  /// Name of a function attribute (YAML `attr_name`); optional.
  std::optional<std::string> attrName;
  /// Type variable (YAML `type_var`); optional.
  std::optional<std::string> typeVar;
  // NOTE: The operand count presumably follows the arity implied by `kind`
  // (unary, binary, ternary, ...). ScalarExpression is still incomplete here,
  // so a heap allocated std::vector is used to break the self-referential
  // cycle between ScalarFn and ScalarExpression.
  std::vector<ScalarExpression> operands;
};

/// Right-hand side of an assignment. Exactly one member is expected to be set
/// (see the ScalarExpression mapping traits).
struct ScalarExpression {
  /// Reference to a named op argument (YAML `scalar_arg`).
  std::optional<std::string> arg;
  /// Constant definition (YAML `scalar_const`).
  std::optional<std::string> constant;
  /// Iteration index (YAML `scalar_index`).
  std::optional<int64_t> index;
  /// Function application (YAML `scalar_fn`).
  std::optional<ScalarFn> scalarFn;
};

/// Assignment of a scalar expression to a named output (YAML `assignments`).
struct ScalarAssign {
  /// Name of the assigned output (YAML `arg`).
  std::string arg;
  /// Assigned value (YAML `value`).
  ScalarExpression value;
};

/// Configuration of a structured op (YAML `structured_op` section).
struct LinalgStructuredOpConfig {
  SmallVector<LinalgOperandDef> args;
  LinalgIndexingMapsConfig indexingMaps;
  SmallVector<LinalgIteratorTypeDef> iteratorTypes;
  std::vector<ScalarAssign> assignments;
};

/// One YAML document: op metadata plus the concrete op configuration.
struct LinalgOpConfig {
  std::optional<LinalgOpMetadata> metadata;
  std::optional<LinalgStructuredOpConfig> structuredOp;
};

} // namespace
132 
133 //===----------------------------------------------------------------------===//
134 // Mapping traits.
135 //===----------------------------------------------------------------------===//
136 
// Declare which vector element types are mapped as YAML sequences, and that a
// vector of LinalgOpConfig is mapped as a stream of YAML documents (one op
// configuration per document).
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
143 
144 namespace llvm {
145 namespace yaml {
146 
/// Top-level type containing op metadata and one concrete op type. Both keys
/// are optional at the YAML level. Currently, the only defined op type is
/// `structured_op` (maps to `LinalgStructuredOpConfig`).
template <>
struct MappingTraits<LinalgOpConfig> {
  // The call order below is also the key order used when writing YAML.
  static void mapping(IO &io, LinalgOpConfig &info) {
    io.mapOptional("metadata", info.metadata);
    io.mapOptional("structured_op", info.structuredOp);
  }
};
157 
/// A structured op models (at most) a single contraction by modeling
///   - A list of named arguments (`LinalgOperandDef`), which can be inputs,
///     scalars, outputs, or index/function attributes.
///   - The indexing maps (see `LinalgIndexingMapsConfig`).
///   - Iterator types (see `LinalgIteratorTypeDef`).
///   - A list of scalar level assignments (see `ScalarAssign`).
/// All four keys are required.
template <>
struct MappingTraits<LinalgStructuredOpConfig> {
  static void mapping(IO &io, LinalgStructuredOpConfig &info) {
    io.mapRequired("args", info.args);
    io.mapRequired("indexing_maps", info.indexingMaps);
    io.mapRequired("iterator_types", info.iteratorTypes);
    io.mapRequired("assignments", info.assignments);
  }
};
173 
/// Maps a named tensor, scalar or attribute argument to an operation,
/// consisting of:
///   - `name`: Must be unique within the operation (required).
///   - `kind`: How the argument is used (input, output, attribute, etc);
///     required.
///   - `type_var`: The symbolic type variable that binds to the element or self
///     type of the tensor or scalar argument, respectively.
///   - `shape_map`: An optional AffineMap from all op symbols to the shape of
///     the argument. Only tensor arguments have a `shape_map`. Each shape must
///     be normalized over the same list of symbols and have no dimension
///     inputs.
///   - `index_attr_map`: An optional AffineMap from all op symbols to the
///     index attribute symbols. During op creation these symbols are replaced
///     by the corresponding `name` index attribute values. Only index attribute
///     arguments have an `index_attr_map`.
///   - `default_indices`: An optional default initialization for index
///     attribute arguments.
///   - `default_fn`: An optional default initialization for function attribute
///     arguments.
template <>
struct MappingTraits<LinalgOperandDef> {
  static void mapping(IO &io, LinalgOperandDef &info) {
    io.mapRequired("name", info.name);
    io.mapRequired("kind", info.kind);
    io.mapOptional("type_var", info.typeVar);
    io.mapOptional("shape_map", info.shapeMap);
    io.mapOptional("index_attr_map", info.indexAttrMap);
    io.mapOptional("default_indices", info.defaultIndices);
    io.mapOptional("default_fn", info.defaultFn);
  }
};
204 
205 /// Usage enum for a named argument.
206 template <>
207 struct ScalarEnumerationTraits<LinalgOperandDefKind> {
208   static void enumeration(IO &io, LinalgOperandDefKind &value) {
209     io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor);
210     io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
211     io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
212     io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
213     io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
214     io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
215     io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr);
216     io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
217   }
218 };
219 
220 /// Iterator type enum.
221 template <>
222 struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
223   static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
224     io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
225     io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
226   }
227 };
228 
/// Metadata about the op: required `name` and `cpp_class_name`, plus optional
/// `doc`, `implements` (interface names) and `defines` (ODS fields set to 1).
template <>
struct MappingTraits<LinalgOpMetadata> {
  static void mapping(IO &io, LinalgOpMetadata &info) {
    io.mapRequired("name", info.name);
    io.mapRequired("cpp_class_name", info.cppClassName);
    io.mapOptional("doc", info.doc);
    io.mapOptional("implements", info.implements);
    io.mapOptional("defines", info.defines);
  }
};
240 
/// How the op's indexing maps are produced. Must be one of:
///   - static_indexing_maps: A static list of AffineMaps, possibly with
///     some symbols that bind to attributes of the op. Each indexing map must
///     be normalized over the same list of dimensions, and its symbols must
///     match the symbols for argument shapes.
/// The key is optional at the YAML level.
template <>
struct MappingTraits<LinalgIndexingMapsConfig> {
  static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
    io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
  }
};
252 
/// Models an assignment to a named output. Both keys are required.
///   - The `arg` name must match a named output.
///   - The `value` is a scalar expression for computing the value to
///     assign (see `ScalarExpression`).
template <>
struct MappingTraits<ScalarAssign> {
  static void mapping(IO &io, ScalarAssign &info) {
    io.mapRequired("arg", info.arg);
    io.mapRequired("value", info.value);
  }
};
264 
/// A scalar expression (RHS of an assignment). Must be one of:
///   - `scalar_arg`: An operation argument.
///   - `scalar_const`: A constant definition.
///   - `scalar_index`: An iteration index.
///   - `scalar_fn`: A named function (see `ScalarFn`).
/// All keys are mapped as optional; the "exactly one" rule is not enforced
/// by this mapping.
template <>
struct MappingTraits<ScalarExpression> {
  static void mapping(IO &io, ScalarExpression &info) {
    io.mapOptional("scalar_arg", info.arg);
    io.mapOptional("scalar_const", info.constant);
    io.mapOptional("scalar_index", info.index);
    io.mapOptional("scalar_fn", info.scalarFn);
  }
};
279 
280 /// Scalar function kind enum.
281 template <>
282 struct ScalarEnumerationTraits<ScalarFnKind> {
283   static void enumeration(IO &io, ScalarFnKind &value) {
284     io.enumCase(value, "unary", ScalarFnKind::Unary);
285     io.enumCase(value, "binary", ScalarFnKind::Binary);
286     io.enumCase(value, "ternary", ScalarFnKind::Ternary);
287     io.enumCase(value, "type", ScalarFnKind::Type);
288   }
289 };
290 
/// A scalar expression that evaluates a function. Functions are generally
/// "math" level and type polymorphic (e.g. `add(lhs, rhs)`, `mul(lhs, rhs)`).
///   - `kind`: unary, binary, ternary or type function (required).
///   - `fn_name`: optional name of the function to apply.
///   - `attr_name`: optional name of a function attribute of the op.
///   - `type_var`: optional symbolic type variable.
///   - `operands`: the operand expressions (required).
template <>
struct MappingTraits<ScalarFn> {
  static void mapping(IO &io, ScalarFn &info) {
    io.mapRequired("kind", info.kind);
    io.mapOptional("fn_name", info.fnName);
    io.mapOptional("attr_name", info.attrName);
    io.mapOptional("type_var", info.typeVar);
    io.mapRequired("operands", info.operands);
  }
};
306 
307 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
308 /// the same.
309 template <>
310 struct ScalarTraits<SerializedAffineMap> {
311   static void output(const SerializedAffineMap &value, void *rawYamlContext,
312                      raw_ostream &out) {
313     assert(value.affineMapAttr);
314     value.affineMapAttr.print(out);
315   }
316   static StringRef input(StringRef scalar, void *rawYamlContext,
317                          SerializedAffineMap &value) {
318     assert(rawYamlContext);
319     auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
320     if (auto attr = dyn_cast_or_null<AffineMapAttr>(
321             mlir::parseAttribute(scalar, yamlContext->mlirContext)))
322       value.affineMapAttr = attr;
323     else if (!value.affineMapAttr || !isa<AffineMapAttr>(value.affineMapAttr))
324       return "could not parse as an affine map attribute";
325     return StringRef();
326   }
327   static QuotingType mustQuote(StringRef) { return QuotingType::None; }
328 };
329 
330 } // namespace yaml
331 } // namespace llvm
332 
333 namespace {
334 
335 //===----------------------------------------------------------------------===//
336 // Generation utilities
337 //===----------------------------------------------------------------------===//
338 
339 class GenerationContext {
340 public:
341   GenerationContext(MLIRContext *context, raw_ostream *odsOut,
342                     raw_ostream *defnOut)
343       : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
344         defnOut(defnOut) {}
345 
346   MLIRContext *getContext() { return context; }
347 
348   void setLoc(Location loc) { this->loc = loc; }
349   Location getLoc() { return loc; }
350 
351   bool shouldGenerateOds() { return odsOut; }
352   bool shouldGenerateDefns() { return defnOut; }
353 
354   raw_ostream &odss() {
355     assert(odsOut && "ODS stream not defined");
356     return *odsOut;
357   }
358 
359   raw_ostream &defns() {
360     assert(defnOut && "Definition stream not defined");
361     return *defnOut;
362   }
363 
364 private:
365   MLIRContext *context;
366   Location loc;
367   raw_ostream *odsOut;
368   raw_ostream *defnOut;
369 };
370 
371 } // namespace
372 
373 static std::string generateCppExpression(SerializedAffineMap self,
374                                          StringRef contextName) {
375   std::string printedStr;
376   llvm::raw_string_ostream printedSs(printedStr);
377   self.affineMapAttr.print(printedSs);
378 
379   static const char exprFormat[] =
380       R"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT";
381   return llvm::formatv(exprFormat, printedStr, contextName);
382 }
383 
384 template <typename Container>
385 static std::string interleaveToString(Container &container,
386                                       StringRef separator) {
387   std::string result;
388   llvm::raw_string_ostream ss(result);
389   llvm::interleave(container, ss, separator);
390   return result;
391 }
392 
393 static std::optional<int>
394 findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
395   for (const auto &it : llvm::enumerate(args)) {
396     if (it.value().name == name)
397       return it.index();
398   }
399   return std::nullopt;
400 }
401 
402 // Try to map the TypeVar to a predefined or an argument type.
403 static std::optional<std::string>
404 findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
405   // Handle all predefined types.
406   if (typeVar == "I32")
407     return std::string("helper.getIntegerType(32)");
408   if (typeVar == "I64")
409     return std::string("helper.getIntegerType(64)");
410   if (typeVar == "F32")
411     return std::string("helper.getFloat32Type()");
412   if (typeVar == "F64")
413     return std::string("helper.getFloat64Type()");
414 
415   // Search all argument types.
416   for (const auto &it : llvm::enumerate(args)) {
417     if (it.value().kind != LinalgOperandDefKind::InputTensor &&
418         it.value().kind != LinalgOperandDefKind::Scalar &&
419         it.value().kind != LinalgOperandDefKind::OutputTensor)
420       continue;
421     if (*it.value().typeVar == typeVar)
422       return llvm::formatv("block.getArgument({0}).getType()", it.index())
423           .str();
424   }
425 
426   return std::nullopt;
427 }
428 
429 static ScalarAssign *findAssignment(StringRef name,
430                                     std::vector<ScalarAssign> &assignments) {
431   for (auto &assign : assignments) {
432     if (assign.arg == name)
433       return &assign;
434   }
435   return nullptr;
436 }
437 
438 // Return true if the operand is a function attribute.
439 static bool isFunctionAttribute(LinalgOperandDefKind kind) {
440   return kind == LinalgOperandDefKind::UnaryFnAttr ||
441          kind == LinalgOperandDefKind::BinaryFnAttr ||
442          kind == LinalgOperandDefKind::TernaryFnAttr ||
443          kind == LinalgOperandDefKind::TypeFnAttr;
444 }
445 
446 // Return true if the operand is an attribute.
447 static bool isAttribute(LinalgOperandDefKind kind) {
448   return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
449 }
450 
451 // Get the enum name for the given operand kind.
452 std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
453   switch (kind) {
454   case LinalgOperandDefKind::UnaryFnAttr:
455     return std::string("UnaryFn");
456   case LinalgOperandDefKind::BinaryFnAttr:
457     return std::string("BinaryFn");
458   case LinalgOperandDefKind::TernaryFnAttr:
459     return std::string("TernaryFn");
460   case LinalgOperandDefKind::TypeFnAttr:
461     return std::string("TypeFn");
462   default:
463     break;
464   }
465   llvm_unreachable("unsupported function attribute kind");
466 }
467 
468 // Get the enum name for the given function kind.
469 std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
470   switch (kind) {
471   case ScalarFnKind::Unary:
472     return std::string("UnaryFn");
473   case ScalarFnKind::Binary:
474     return std::string("BinaryFn");
475   case ScalarFnKind::Ternary:
476     return std::string("TernaryFn");
477   case ScalarFnKind::Type:
478     return std::string("TypeFn");
479   }
480   llvm_unreachable("unsupported function kind");
481 }
482 
483 //===----------------------------------------------------------------------===//
484 // Templates
485 //===----------------------------------------------------------------------===//
486 
// A single line banner format, e.g. emitted before the generated C++
// definitions of each op. Parameters:
// {0}: Single line comment
static const char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";
494 
495 //===----------------------------------------------------------------------===//
496 // Named generic op generation.
497 // These ops map at most a single contraction that complies with the limitations
498 // of a linalg.generic.
499 //===----------------------------------------------------------------------===//
500 
// Template for Linalg named ops' ODS definitions. Parameters:
// {0}: ODS/C++ op name
// {1}: assembly op mnemonic
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list; either empty or starting with ",\n" since it is
//      appended directly after `$outputs`
// {5}: builder methods taking standalone attribute parameters; either empty
//      or starting with "," since it follows the last default builder
// {6}: additional ODS field definitions (`let <name> = 1;` lines)
// {7}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
  /*extraInterfaces=*/[{2}])> {
    {3}
    let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs{4}
    );
    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
    let regions = (region AnyRegion:$region);

    let skipDefaultBuilders = 1;
    let builders = [
      OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
          attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
            "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, resultTensorTypes,
          inputs, outputs, attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>
      {5}
    ];
    let hasCustomAssemblyFormat = 1;
    let hasFolder = 1;
    {6}

    let extraClassDeclaration = structuredOpsBaseDecls # [{{
      // Auto-generated.
      SmallVector<utils::IteratorType> getIteratorTypesArray();
      ArrayAttr getIndexingMaps();
      static void regionBuilder(ImplicitLocOpBuilder &b,
                                Block &block, ArrayRef<NamedAttribute> attrs);
      static std::function<void(ImplicitLocOpBuilder &,
                                Block &, ArrayRef<NamedAttribute>)>
      getRegionBuilder() {{
        return regionBuilder;
      }

      ::mlir::MutableOperandRange getDpsInitsMutable() {{
        return getOutputsMutable();
      }

      // Generic methods.
      static unsigned getNumRegionArgs();
      std::string getLibraryCallName();
      {7}
    }];
}
)FMT";
580 
// Builder method taking attribute parameters. The text starts with ", " so it
// can be appended after the last default builder of the ODS template.
// Parameters:
// {0}: Class name
// {1}: Comma interleaved attribute parameters
// {2}: Attribute initialization statements
static const char structuredOpBuilderFormat[] = R"FMT(
  , OpBuilder<
  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, {1},
       CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
  [{{
    {2}
    buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
      attributes, {0}::getRegionBuilder());
  }]>
)FMT";
596 
// The getIteratorTypesArray() method for structured ops. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";

// The getIteratorTypesArray() method for rank polymorphic structured ops:
// every loop is parallel and the loop count is the rank of the first init
// (output) operand. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  int64_t rank = getRank(getDpsInitOperand(0));
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";
617 
// The getIndexingMaps() method for structured ops. The result is memoized in
// the `linalg.memoized_indexing_maps` attribute of the op. Parameters:
// {0}: Class name
// {1}: Statements that append the indexing maps to `maps`; they may use the
//      `context` and `symbolBindings` locals declared by the template.
static const char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
  if (cached)
    return cached;

  MLIRContext *context = getContext();
  auto symbolBindings = getSymbolBindings(*this);
  SmallVector<AffineMap> maps;
  {1}
  cached = Builder(context).getAffineMapArrayAttr(maps);
  getOperation()->setAttr(memoizeAttr, cached);
  return cached;
}
)FMT";

// The getIndexingMaps() method for rank polymorphic structured ops: rank-0
// operands get a map with no results, all others an identity map over the
// parallel loops. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  MLIRContext *context = getContext();
  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
    getNumParallelLoops(), context);
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
  return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
653 
// Implementations of fold, getEffects and getSpeculatability.
// (A namespace-scope `const` object already has internal linkage, even
// without `static`.)
// Parameters:
// {0}: Class name
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(FoldAdaptor,
                        SmallVectorImpl<OpFoldResult> &) {{
  return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
    SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
      if (hasPureTensorSemantics()) return;
      getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
Speculation::Speculatability {0}::getSpeculatability() {{
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
)FMT";

// Implementation of parse/print. The printer elides the operand segment sizes
// and the memoized indexing maps attribute.
// Parameters:
// {0}: Class name
static const char structuredOpParserFormat[] = R"FMT(
ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
  return ::parseNamedStructuredOp(parser, result,
    {0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
  SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
                                           "linalg.memoized_indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
}
)FMT";
687 
688 static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
689                                                GenerationContext &genContext) {
690   if (!genContext.shouldGenerateOds())
691     return success();
692 
693   raw_ostream &os = genContext.odss();
694 
695   std::string interfaceNameList;
696   std::string attrList;
697   std::string attrMethods;
698   std::string attrBuilder;
699 
700   std::string doc;
701   if (opConfig.metadata->doc) {
702     static const char structuredOpDocFmt[] = R"FMT(
703   let summary = [{{{0}}];
704   let description = [{{{1}}];
705 )FMT";
706     StringRef summary, description;
707     std::tie(summary, description) =
708         StringRef(*opConfig.metadata->doc).trim().split("\n\n");
709 
710     doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim());
711   }
712 
713   interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
714 
715   std::string definitionList;
716   for (const std::string &definition : opConfig.metadata->defines) {
717     static const char definitionFmt[] = "let {0} = 1;\n";
718     definitionList.append(llvm::formatv(definitionFmt, definition));
719   }
720 
721   if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
722         return isAttribute(arg.kind);
723       })) {
724     SmallVector<std::string> attrDefs;
725     SmallVector<std::string> attrParams;
726     SmallVector<std::string> attrStmts;
727     for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
728       static const char paramFmt[] = "\"Attribute\":${0}";
729       static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
730       // Add the type conversion attributes to the op definition and builders.
731       if (isFunctionAttribute(arg.kind)) {
732         assert(arg.defaultFn);
733         std::string enumName = convertOperandKindToEnumName(arg.kind);
734         static const char typeFmt[] = "{0}::{1}";
735         static const char defFmt[] =
736             "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
737         attrDefs.push_back(llvm::formatv(
738             defFmt, llvm::formatv("{0}Attr", enumName),
739             llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
740         attrParams.push_back(llvm::formatv(paramFmt, arg.name));
741         attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
742       }
743       // Add the index attributes to the op definition and builders.
744       if (arg.kind == LinalgOperandDefKind::IndexAttr) {
745         assert(arg.indexAttrMap.has_value());
746         assert(arg.defaultIndices.has_value());
747         size_t size = arg.indexAttrMap->affineMap().getNumResults();
748         assert(arg.defaultIndices->size() == size);
749         static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
750         static const char defFmt[] =
751             "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
752         std::string defaultVals;
753         llvm::raw_string_ostream ss(defaultVals);
754         llvm::interleave(
755             *arg.defaultIndices, ss,
756             [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
757             ", ");
758         attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
759                                          ss.str(), arg.name));
760         attrParams.push_back(llvm::formatv(paramFmt, arg.name));
761         attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
762       }
763     }
764     if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
765           return arg.kind == LinalgOperandDefKind::IndexAttr;
766         })) {
767       attrMethods = R"(
768         bool hasDynamicIndexingMaps();
769         LogicalResult verifyIndexingMapRequiredAttributes();
770       )";
771     }
772     attrList = ",\n" + llvm::join(attrDefs, ",\n");
773     attrBuilder = llvm::formatv(
774         structuredOpBuilderFormat, opConfig.metadata->cppClassName,
775         llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
776   }
777 
778   os << llvm::formatv(structuredOpOdsHeaderFormat,
779                       opConfig.metadata->cppClassName, opConfig.metadata->name,
780                       interfaceNameList, doc, attrList, attrBuilder,
781                       definitionList, attrMethods);
782 
783   return success();
784 }
785 
786 static LogicalResult
787 generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
788                             GenerationContext &genContext) {
789   if (!genContext.shouldGenerateDefns())
790     return success();
791 
792   raw_ostream &os = genContext.defns();
793   StringRef className = opConfig.metadata->cppClassName;
794 
795   // Implementation banner.
796   std::string bannerComment = llvm::formatv("Implementation of {0}", className);
797   os << llvm::formatv(bannerFormat, bannerComment);
798 
799   // Compute the number of scalar and tensor arguments.
800   int64_t numOfArgs =
801       llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
802         return arg.kind == LinalgOperandDefKind::InputTensor ||
803                arg.kind == LinalgOperandDefKind::Scalar ||
804                arg.kind == LinalgOperandDefKind::OutputTensor;
805       });
806 
807   // An operation that accesses only scalars and scalar/rank zero tensors is
808   // rank polymorhpic. We implement rank polymorphism by generating different
809   // indexing maps and iterators that match the rank of the first output tensor.
810   // An operation is rank polymorphic if the iteration domain has rank zero.
811   bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
812 
813   // Generate the iterator_types() method.
814   if (!isRankPolymorphic) {
815     std::string iteratorsStr;
816     llvm::raw_string_ostream ss(iteratorsStr);
817     llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
818                           [&](LinalgIteratorTypeDef it) {
819                             switch (it) {
820                             case LinalgIteratorTypeDef::parallel:
821                               ss << "utils::IteratorType::parallel";
822                               break;
823                             case LinalgIteratorTypeDef::reduction:
824                               ss << "utils::IteratorType::reduction";
825                               break;
826                             }
827                           });
828     os << llvm::formatv(structuredOpIteratorTypesFormat, className,
829                         iteratorsStr);
830   } else {
831     os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className);
832   }
833 
834   // Generating the getIndexingMaps() method.
835   if (auto &staticMaps =
836           opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
837     if (staticMaps->empty())
838       return emitError(genContext.getLoc()) << "op has no indexing maps";
839     if (!isRankPolymorphic) {
840       AffineMap firstMap = staticMaps->front().affineMap();
841 
842       // Symbol bindings.
843       {
844         // For each symbol, generate a declaration for it, either with an
845         // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
846         // an attribute).
847         // TODO: Possibly lift into a top-level method.
848         static const char structuredOpSymbolBindingsFormat[] = R"FMT(
849 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
850   MLIRContext *context = self.getContext();
851   SmallVector<AffineExpr> exprs;
852 {1}
853   return exprs;
854 }
855 )FMT";
856 
857         unsigned symbolCount = firstMap.getNumSymbols();
858         SmallVector<std::string> symbolBindings;
859         for (unsigned i = 0; i < symbolCount; ++i) {
860           symbolBindings.push_back(llvm::formatv(
861               "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
862         }
863 
864         // Access an index attribute. Parameters:
865         // {0}: Attribute name
866         // {1}: Symbol position
867         // {2}: Attribute index
868         static const char structuredOpAccessAttrFormat[] = R"FMT(
869 int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
870 exprs.push_back(getAffineConstantExpr(cst{1}, context));
871 )FMT";
872         // Update all symbol bindings mapped to an attribute.
873         for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
874           if (arg.kind != LinalgOperandDefKind::IndexAttr)
875             continue;
876           assert(arg.indexAttrMap);
877           for (auto [idx, result] :
878                llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
879             if (auto symbol = dyn_cast<AffineSymbolExpr>(result)) {
880               std::string argName = arg.name;
881               argName[0] = toupper(argName[0]);
882               symbolBindings[symbol.getPosition()] =
883                   llvm::formatv(structuredOpAccessAttrFormat, argName,
884                                 symbol.getPosition(), idx);
885             }
886           }
887         }
888 
889         std::string symbolBindingsStr;
890         llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
891         llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
892 
893         os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
894                             symbolBindingsStr);
895       }
896 
897       // Indexing maps.
898       {
899         unsigned dimCount = firstMap.getNumDims();
900 
901         // Generate a comma-separated list of dim identifiers to be passed to
902         // bindDims, ensuring tht AffineExpr identifiers are bound in the right
903         // order to the proper AffineDimExpr.
904         // This results in vars in scope like: d0, d1, d2...
905         SmallVector<unsigned> dimIndices;
906         for (unsigned i = 0; i < dimCount; ++i)
907           dimIndices.push_back(i);
908         std::string dimIdentsStr;
909         llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
910         llvm::interleaveComma(dimIndices, dimIdentsSs,
911                               [&](unsigned i) { dimIdentsSs << "d" << i; });
912 
913         // Statements to add and simplify each affine map.
914         SmallVector<std::string> stmts;
915         for (auto &indexingMap : *staticMaps) {
916           // TODO: Assert that dim and symbol count match the first.
917           stmts.push_back(
918               llvm::formatv("maps.push_back({0});",
919                             generateCppExpression(indexingMap, "context")));
920           stmts.push_back(llvm::formatv(
921               "maps.back() = "
922               "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
923               "symbolBindings, {0}, 0));",
924               dimCount));
925         }
926 
927         // TODO: This needs to be memoized and/or converted to non-parser based
928         // C++ codegen prior to real use.
929         os << llvm::formatv(structuredOpIndexingMapsFormat, className,
930                             interleaveToString(stmts, "\n  "));
931       }
932     } else {
933       os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className);
934     }
935   } else {
936     return emitError(genContext.getLoc())
937            << "generating code for non static indexing maps not currently "
938               "supported";
939   }
940 
941   // getNumRegionArgs()
942   {
943     // Generates a getNumRegionArgs() method. Parameters:
944     // {0}: Class name
945     // {1}: Number of region args
946     static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
947 unsigned {0}::getNumRegionArgs() {{ return {1}; }
948 )FMT";
949     os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
950                         numOfArgs);
951   }
952 
953   // getLibraryCallName()
954   {
955     // Generates a getLibraryCallName method. Parameters:
956     // {0}: Class name
957     static const char structuredOpGetLibraryCallFormat[] = R"FMT(
958 std::string {0}::getLibraryCallName() {{
959   return generateLibraryCallName(getOperation());
960 }
961 )FMT";
962     os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
963   }
964 
965   // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
966   if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
967         return arg.kind == LinalgOperandDefKind::IndexAttr;
968       })) {
969     std::vector<std::string> attrVerifications;
970     for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
971       if (arg.kind != LinalgOperandDefKind::IndexAttr)
972         continue;
973       assert(arg.indexAttrMap);
974       // Verify index attribute. Paramters:
975       // {0}: Attribute name
976       // {1}: Attribute size
977       static const char attrFmt[] = R"FMT(
978 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
979   if (!attr.getType().getElementType().isInteger(64))
980     return op->emitError("incorrect element type for index attribute '{0}'");
981   if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
982     return op->emitError("incorrect shape for index attribute '{0}'");
983 }
984 )FMT";
985       attrVerifications.push_back(llvm::formatv(
986           attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults()));
987     }
988 
989     // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
990     // {0}: Class name
991     // {1}: Attribute verification
992     static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
993 bool {0}::hasDynamicIndexingMaps() {{ return true; }
994 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
995   Operation *op = getOperation();
996   {1}
997   return success();
998 }
999 )FMT";
1000     os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes,
1001                         className, llvm::join(attrVerifications, "\n"));
1002   }
1003 
1004   // regionBuilder()
1005   {
1006     // Generates a regionBuilder method. Parameters.
1007     // {0}: Class name
1008     // {1}: Number of args
1009     // {2}: Attributes
1010     // {3}: Statements
1011     static const char structuredOpRegionBuilderFormat[] = R"FMT(
1012 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1013                         Block &block, ArrayRef<NamedAttribute> attrs) {{
1014   assert({1} > 0 && block.getNumArguments() == {1} &&
1015          "{0} regionBuilder expects {1} (>=0) args");
1016   RegionBuilderHelper helper(b, block);
1017   SmallVector<Value> yields;
1018   {2}
1019   {3}
1020   helper.yieldOutputs(yields);
1021 }
1022 )FMT";
1023     auto &args = opConfig.structuredOp->args;
1024     auto &assignments = opConfig.structuredOp->assignments;
1025     size_t generatedAssignmentCount = 0;
1026     int localCounter = 0;
1027     SmallVector<std::string> attrs;
1028     SmallVector<std::string> stmts;
1029     for (LinalgOperandDef &arg : args) {
1030       if (!isFunctionAttribute(arg.kind))
1031         continue;
1032       // Obtain the type function attribute values. Parameters.
1033       // {0}: enum name
1034       // {1}: attribute name
1035       // {2}: default type function name
1036       static const char attrDef[] = R"FMT(
1037   {0} {1}Val = {0}::{2};
1038   auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1039                                 return attr.getName() == "{1}"; });
1040   if ({1}Iter != attrs.end()) {{
1041     if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1042       {1}Val = attr.getValue();
1043   }
1044 )FMT";
1045       std::string enumName = convertOperandKindToEnumName(arg.kind);
1046       attrs.push_back(
1047           llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
1048     }
1049     for (LinalgOperandDef &arg : args) {
1050       if (arg.kind != LinalgOperandDefKind::OutputTensor)
1051         continue;
1052 
1053       // Find the assignment that correlates with the argument.
1054       ScalarAssign *assignment = findAssignment(arg.name, assignments);
1055       if (!assignment)
1056         return emitError(genContext.getLoc())
1057                << "no assignment found for output argument " << arg.name;
1058       ++generatedAssignmentCount;
1059 
1060       // Recursively generate the expression.
1061       std::function<std::optional<std::string>(ScalarExpression &)>
1062           generateExpression =
1063               [&](ScalarExpression &expression) -> std::optional<std::string> {
1064         if (expression.arg) {
1065           // Argument reference.
1066           std::optional<int> argIndex =
1067               findTensorDefArgIndex(*expression.arg, args);
1068           if (!argIndex) {
1069             emitError(genContext.getLoc())
1070                 << "scalar argument not defined on the op: " << *expression.arg;
1071             return std::nullopt;
1072           }
1073           return std::string(
1074               llvm::formatv("block.getArgument({0})", *argIndex));
1075         }
1076         if (expression.constant) {
1077           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1078           stmts.push_back(
1079               llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT",
1080                             cppIdent, expression.constant));
1081           return cppIdent;
1082         }
1083         if (expression.index) {
1084           // Access an iteration index.
1085           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1086           stmts.push_back(llvm::formatv("Value {0} = helper.index({1});",
1087                                         cppIdent, *expression.index));
1088           return cppIdent;
1089         }
1090         if (expression.scalarFn) {
1091           std::string enumName =
1092               convertFunctionKindToEnumName(expression.scalarFn->kind);
1093 
1094           // Get the function or attribute name.
1095           assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
1096           std::string funcType;
1097           if (expression.scalarFn->fnName) {
1098             funcType = llvm::formatv("{0}::{1}", enumName,
1099                                      *expression.scalarFn->fnName);
1100           }
1101           if (expression.scalarFn->attrName) {
1102             if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
1103                   return isFunctionAttribute(arg.kind) &&
1104                          arg.name == *expression.scalarFn->attrName;
1105                 })) {
1106               emitError(genContext.getLoc()) << "missing function attribute "
1107                                              << *expression.scalarFn->attrName;
1108             }
1109             funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
1110           }
1111           assert(!funcType.empty());
1112 
1113           // Add the optional type parameter to the operands.
1114           SmallVector<std::string> operandCppValues;
1115           if (expression.scalarFn->kind == ScalarFnKind::Type) {
1116             assert(expression.scalarFn->typeVar.has_value());
1117             std::optional<std::string> typeCppValue =
1118                 findTypeValue(*expression.scalarFn->typeVar, args);
1119             if (!typeCppValue) {
1120               emitError(genContext.getLoc())
1121                   << "type variable " << *expression.scalarFn->typeVar
1122                   << ", used in a type conversion, must map to a predefined or "
1123                   << "an argument type but it does not";
1124               return std::nullopt;
1125             }
1126             operandCppValues.push_back(*typeCppValue);
1127           }
1128 
1129           // Collect the scalar operands.
1130           for (ScalarExpression &operand : expression.scalarFn->operands) {
1131             auto operandCppValue = generateExpression(operand);
1132             if (!operandCppValue)
1133               return std::nullopt;
1134             operandCppValues.push_back(*operandCppValue);
1135           }
1136 
1137           // Call the function builder.
1138           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1139           stmts.push_back(llvm::formatv(
1140               "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
1141               funcType, interleaveToString(operandCppValues, ", ")));
1142           return cppIdent;
1143         }
1144         emitError(genContext.getLoc()) << "unknown ScalarExpression type";
1145         return std::nullopt;
1146       };
1147       std::optional<std::string> cppValue =
1148           generateExpression(assignment->value);
1149       if (!cppValue)
1150         return failure();
1151       stmts.push_back(llvm::formatv("yields.push_back({0});", *cppValue));
1152     }
1153 
1154     if (generatedAssignmentCount != assignments.size())
1155       return emitError(genContext.getLoc())
1156              << "mismatched number of assignments vs output arguments";
1157 
1158     os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
1159                         interleaveToString(attrs, "\n  "),
1160                         interleaveToString(stmts, "\n  "));
1161   }
1162 
1163   // Parser and printer.
1164   os << llvm::formatv(structuredOpParserFormat, className);
1165 
1166   // Canonicalizers and folders.
1167   os << llvm::formatv(structuredOpFoldersFormat, className);
1168 
1169   return success();
1170 }
1171 
1172 static LogicalResult generateOp(LinalgOpConfig &opConfig,
1173                                 GenerationContext &genContext) {
1174   // Switch on op type being generated.
1175   if (opConfig.structuredOp) {
1176     return success(
1177         succeeded(generateNamedGenericOpOds(opConfig, genContext)) &&
1178         succeeded(generateNamedGenericOpDefns(opConfig, genContext)));
1179   }
1180   return emitError(genContext.getLoc()) << "unsupported operation type";
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // Command line options and main
1185 //===----------------------------------------------------------------------===//
1186 
1187 static llvm::cl::opt<std::string>
1188     inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
1189                   llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
1190 
1191 static llvm::cl::opt<std::string>
1192     outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1193                           llvm::cl::value_desc("filename"), llvm::cl::init(""));
1194 
1195 static llvm::cl::opt<std::string>
1196     outputCppImplFilename("o-impl",
1197                           llvm::cl::desc("C++ implementation file name"),
1198                           llvm::cl::value_desc("filename"), llvm::cl::init(""));
1199 
1200 int main(int argc, char **argv) {
1201   llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML");
1202 
1203   // Set up the input file.
1204   std::string errorMessage;
1205   std::unique_ptr<llvm::MemoryBuffer> file =
1206       mlir::openInputFile(inputFilename, &errorMessage);
1207   if (!file) {
1208     llvm::errs() << errorMessage << "\n";
1209     return 1;
1210   }
1211 
1212   MLIRContext mlirContext;
1213   LinalgYAMLContext yamlContext{&mlirContext};
1214 
1215   std::vector<LinalgOpConfig> opConfigs;
1216 
1217   // Parse input.
1218   Input yin(file->getBuffer(), &yamlContext);
1219   yin >> opConfigs;
1220 
1221   if (yin.error())
1222     return 1;
1223 
1224   // Open output files.
1225   std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
1226   if (!outputOdsDeclFilename.empty()) {
1227     outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage);
1228     if (!outputOdsDecl) {
1229       llvm::errs() << errorMessage << "\n";
1230       return 1;
1231     }
1232   }
1233 
1234   std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
1235   if (!outputCppImplFilename.empty()) {
1236     outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage);
1237     if (!outputCppImpl) {
1238       llvm::errs() << errorMessage << "\n";
1239       return 1;
1240     }
1241   }
1242 
1243   if (!outputOdsDecl && !outputCppImpl) {
1244     llvm::errs() << "error: No output files specified\n";
1245     return 1;
1246   }
1247 
1248   // Generate.
1249   GenerationContext genContext(&mlirContext,
1250                                outputOdsDecl ? &outputOdsDecl->os() : nullptr,
1251                                outputCppImpl ? &outputCppImpl->os() : nullptr);
1252 
1253   for (auto &opConfig : opConfigs) {
1254     if (!opConfig.metadata) {
1255       emitError(genContext.getLoc())
1256           << "missing operation metadata on subsequent op";
1257       return 1;
1258     }
1259 
1260     genContext.setLoc(NameLoc::get(
1261         StringAttr::get(&mlirContext, opConfig.metadata->cppClassName)));
1262     if (failed(generateOp(opConfig, genContext))) {
1263       return 1;
1264     }
1265   }
1266 
1267   if (outputOdsDecl)
1268     outputOdsDecl->keep();
1269   if (outputCppImpl)
1270     outputCppImpl->keep();
1271 
1272   return 0;
1273 }
1274