xref: /llvm-project/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp (revision 2946cd701067404b99c39fb29dc9c74bd7193eb3)
//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 #include "llvm/Support/AMDGPUMetadata.h"
16 
17 namespace llvm {
18 namespace AMDGPU {
19 namespace HSAMD {
20 namespace V3 {
21 
22 bool MetadataVerifier::verifyScalar(
23     msgpack::Node &Node, msgpack::ScalarNode::ScalarKind SKind,
24     function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
25   auto ScalarPtr = dyn_cast<msgpack::ScalarNode>(&Node);
26   if (!ScalarPtr)
27     return false;
28   auto &Scalar = *ScalarPtr;
29   // Do not output extraneous tags for types we know from the spec.
30   Scalar.IgnoreTag = true;
31   if (Scalar.getScalarKind() != SKind) {
32     if (Strict)
33       return false;
34     // If we are not strict, we interpret string values as "implicitly typed"
35     // and attempt to coerce them to the expected type here.
36     if (Scalar.getScalarKind() != msgpack::ScalarNode::SK_String)
37       return false;
38     std::string StringValue = Scalar.getString();
39     Scalar.setScalarKind(SKind);
40     if (Scalar.inputYAML(StringValue) != StringRef())
41       return false;
42   }
43   if (verifyValue)
44     return verifyValue(Scalar);
45   return true;
46 }
47 
48 bool MetadataVerifier::verifyInteger(msgpack::Node &Node) {
49   if (!verifyScalar(Node, msgpack::ScalarNode::SK_UInt))
50     if (!verifyScalar(Node, msgpack::ScalarNode::SK_Int))
51       return false;
52   return true;
53 }
54 
55 bool MetadataVerifier::verifyArray(
56     msgpack::Node &Node, function_ref<bool(msgpack::Node &)> verifyNode,
57     Optional<size_t> Size) {
58   auto ArrayPtr = dyn_cast<msgpack::ArrayNode>(&Node);
59   if (!ArrayPtr)
60     return false;
61   auto &Array = *ArrayPtr;
62   if (Size && Array.size() != *Size)
63     return false;
64   for (auto &Item : Array)
65     if (!verifyNode(*Item.get()))
66       return false;
67 
68   return true;
69 }
70 
71 bool MetadataVerifier::verifyEntry(
72     msgpack::MapNode &MapNode, StringRef Key, bool Required,
73     function_ref<bool(msgpack::Node &)> verifyNode) {
74   auto Entry = MapNode.find(Key);
75   if (Entry == MapNode.end())
76     return !Required;
77   return verifyNode(*Entry->second.get());
78 }
79 
80 bool MetadataVerifier::verifyScalarEntry(
81     msgpack::MapNode &MapNode, StringRef Key, bool Required,
82     msgpack::ScalarNode::ScalarKind SKind,
83     function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
84   return verifyEntry(MapNode, Key, Required, [=](msgpack::Node &Node) {
85     return verifyScalar(Node, SKind, verifyValue);
86   });
87 }
88 
89 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapNode &MapNode,
90                                           StringRef Key, bool Required) {
91   return verifyEntry(MapNode, Key, Required, [this](msgpack::Node &Node) {
92     return verifyInteger(Node);
93   });
94 }
95 
96 bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
97   auto ArgsMapPtr = dyn_cast<msgpack::MapNode>(&Node);
98   if (!ArgsMapPtr)
99     return false;
100   auto &ArgsMap = *ArgsMapPtr;
101 
102   if (!verifyScalarEntry(ArgsMap, ".name", false,
103                          msgpack::ScalarNode::SK_String))
104     return false;
105   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
106                          msgpack::ScalarNode::SK_String))
107     return false;
108   if (!verifyIntegerEntry(ArgsMap, ".size", true))
109     return false;
110   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
111     return false;
112   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
113                          msgpack::ScalarNode::SK_String,
114                          [](msgpack::ScalarNode &SNode) {
115                            return StringSwitch<bool>(SNode.getString())
116                                .Case("by_value", true)
117                                .Case("global_buffer", true)
118                                .Case("dynamic_shared_pointer", true)
119                                .Case("sampler", true)
120                                .Case("image", true)
121                                .Case("pipe", true)
122                                .Case("queue", true)
123                                .Case("hidden_global_offset_x", true)
124                                .Case("hidden_global_offset_y", true)
125                                .Case("hidden_global_offset_z", true)
126                                .Case("hidden_none", true)
127                                .Case("hidden_printf_buffer", true)
128                                .Case("hidden_default_queue", true)
129                                .Case("hidden_completion_action", true)
130                                .Default(false);
131                          }))
132     return false;
133   if (!verifyScalarEntry(ArgsMap, ".value_type", true,
134                          msgpack::ScalarNode::SK_String,
135                          [](msgpack::ScalarNode &SNode) {
136                            return StringSwitch<bool>(SNode.getString())
137                                .Case("struct", true)
138                                .Case("i8", true)
139                                .Case("u8", true)
140                                .Case("i16", true)
141                                .Case("u16", true)
142                                .Case("f16", true)
143                                .Case("i32", true)
144                                .Case("u32", true)
145                                .Case("f32", true)
146                                .Case("i64", true)
147                                .Case("u64", true)
148                                .Case("f64", true)
149                                .Default(false);
150                          }))
151     return false;
152   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
153     return false;
154   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
155                          msgpack::ScalarNode::SK_String,
156                          [](msgpack::ScalarNode &SNode) {
157                            return StringSwitch<bool>(SNode.getString())
158                                .Case("private", true)
159                                .Case("global", true)
160                                .Case("constant", true)
161                                .Case("local", true)
162                                .Case("generic", true)
163                                .Case("region", true)
164                                .Default(false);
165                          }))
166     return false;
167   if (!verifyScalarEntry(ArgsMap, ".access", false,
168                          msgpack::ScalarNode::SK_String,
169                          [](msgpack::ScalarNode &SNode) {
170                            return StringSwitch<bool>(SNode.getString())
171                                .Case("read_only", true)
172                                .Case("write_only", true)
173                                .Case("read_write", true)
174                                .Default(false);
175                          }))
176     return false;
177   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
178                          msgpack::ScalarNode::SK_String,
179                          [](msgpack::ScalarNode &SNode) {
180                            return StringSwitch<bool>(SNode.getString())
181                                .Case("read_only", true)
182                                .Case("write_only", true)
183                                .Case("read_write", true)
184                                .Default(false);
185                          }))
186     return false;
187   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
188                          msgpack::ScalarNode::SK_Boolean))
189     return false;
190   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
191                          msgpack::ScalarNode::SK_Boolean))
192     return false;
193   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
194                          msgpack::ScalarNode::SK_Boolean))
195     return false;
196   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
197                          msgpack::ScalarNode::SK_Boolean))
198     return false;
199 
200   return true;
201 }
202 
203 bool MetadataVerifier::verifyKernel(msgpack::Node &Node) {
204   auto KernelMapPtr = dyn_cast<msgpack::MapNode>(&Node);
205   if (!KernelMapPtr)
206     return false;
207   auto &KernelMap = *KernelMapPtr;
208 
209   if (!verifyScalarEntry(KernelMap, ".name", true,
210                          msgpack::ScalarNode::SK_String))
211     return false;
212   if (!verifyScalarEntry(KernelMap, ".symbol", true,
213                          msgpack::ScalarNode::SK_String))
214     return false;
215   if (!verifyScalarEntry(KernelMap, ".language", false,
216                          msgpack::ScalarNode::SK_String,
217                          [](msgpack::ScalarNode &SNode) {
218                            return StringSwitch<bool>(SNode.getString())
219                                .Case("OpenCL C", true)
220                                .Case("OpenCL C++", true)
221                                .Case("HCC", true)
222                                .Case("HIP", true)
223                                .Case("OpenMP", true)
224                                .Case("Assembler", true)
225                                .Default(false);
226                          }))
227     return false;
228   if (!verifyEntry(
229           KernelMap, ".language_version", false, [this](msgpack::Node &Node) {
230             return verifyArray(
231                 Node,
232                 [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
233           }))
234     return false;
235   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::Node &Node) {
236         return verifyArray(Node, [this](msgpack::Node &Node) {
237           return verifyKernelArgs(Node);
238         });
239       }))
240     return false;
241   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
242                    [this](msgpack::Node &Node) {
243                      return verifyArray(Node,
244                                         [this](msgpack::Node &Node) {
245                                           return verifyInteger(Node);
246                                         },
247                                         3);
248                    }))
249     return false;
250   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
251                    [this](msgpack::Node &Node) {
252                      return verifyArray(Node,
253                                         [this](msgpack::Node &Node) {
254                                           return verifyInteger(Node);
255                                         },
256                                         3);
257                    }))
258     return false;
259   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
260                          msgpack::ScalarNode::SK_String))
261     return false;
262   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
263                          msgpack::ScalarNode::SK_String))
264     return false;
265   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
266     return false;
267   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
268     return false;
269   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
270     return false;
271   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
272     return false;
273   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
274     return false;
275   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
276     return false;
277   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
278     return false;
279   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
280     return false;
281   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
282     return false;
283   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
284     return false;
285 
286   return true;
287 }
288 
289 bool MetadataVerifier::verify(msgpack::Node &HSAMetadataRoot) {
290   auto RootMapPtr = dyn_cast<msgpack::MapNode>(&HSAMetadataRoot);
291   if (!RootMapPtr)
292     return false;
293   auto &RootMap = *RootMapPtr;
294 
295   if (!verifyEntry(
296           RootMap, "amdhsa.version", true, [this](msgpack::Node &Node) {
297             return verifyArray(
298                 Node,
299                 [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
300           }))
301     return false;
302   if (!verifyEntry(
303           RootMap, "amdhsa.printf", false, [this](msgpack::Node &Node) {
304             return verifyArray(Node, [this](msgpack::Node &Node) {
305               return verifyScalar(Node, msgpack::ScalarNode::SK_String);
306             });
307           }))
308     return false;
309   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
310                    [this](msgpack::Node &Node) {
311                      return verifyArray(Node, [this](msgpack::Node &Node) {
312                        return verifyKernel(Node);
313                      });
314                    }))
315     return false;
316 
317   return true;
318 }
319 
320 } // end namespace V3
321 } // end namespace HSAMD
322 } // end namespace AMDGPU
323 } // end namespace llvm
324