xref: /llvm-project/llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp (revision ed0b9af9973e9f714a6e35d858a55bca5c7529b6)
1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 #include "llvm/Support/AMDGPUMetadata.h"
16 
17 namespace llvm {
18 namespace AMDGPU {
19 namespace HSAMD {
20 namespace V3 {
21 
22 bool MetadataVerifier::verifyScalar(
23     msgpack::DocNode &Node, msgpack::Type SKind,
24     function_ref<bool(msgpack::DocNode &)> verifyValue) {
25   if (!Node.isScalar())
26     return false;
27   if (Node.getKind() != SKind) {
28     if (Strict)
29       return false;
30     // If we are not strict, we interpret string values as "implicitly typed"
31     // and attempt to coerce them to the expected type here.
32     if (Node.getKind() != msgpack::Type::String)
33       return false;
34     StringRef StringValue = Node.getString();
35     Node.fromString(StringValue);
36     if (Node.getKind() != SKind)
37       return false;
38   }
39   if (verifyValue)
40     return verifyValue(Node);
41   return true;
42 }
43 
44 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
45   if (!verifyScalar(Node, msgpack::Type::UInt))
46     if (!verifyScalar(Node, msgpack::Type::Int))
47       return false;
48   return true;
49 }
50 
51 bool MetadataVerifier::verifyArray(
52     msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
53     Optional<size_t> Size) {
54   if (!Node.isArray())
55     return false;
56   auto &Array = Node.getArray();
57   if (Size && Array.size() != *Size)
58     return false;
59   for (auto &Item : Array)
60     if (!verifyNode(Item))
61       return false;
62 
63   return true;
64 }
65 
66 bool MetadataVerifier::verifyEntry(
67     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
68     function_ref<bool(msgpack::DocNode &)> verifyNode) {
69   auto Entry = MapNode.find(Key);
70   if (Entry == MapNode.end())
71     return !Required;
72   return verifyNode(Entry->second);
73 }
74 
75 bool MetadataVerifier::verifyScalarEntry(
76     msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
77     msgpack::Type SKind,
78     function_ref<bool(msgpack::DocNode &)> verifyValue) {
79   return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
80     return verifyScalar(Node, SKind, verifyValue);
81   });
82 }
83 
84 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
85                                           StringRef Key, bool Required) {
86   return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
87     return verifyInteger(Node);
88   });
89 }
90 
91 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
92   if (!Node.isMap())
93     return false;
94   auto &ArgsMap = Node.getMap();
95 
96   if (!verifyScalarEntry(ArgsMap, ".name", false,
97                          msgpack::Type::String))
98     return false;
99   if (!verifyScalarEntry(ArgsMap, ".type_name", false,
100                          msgpack::Type::String))
101     return false;
102   if (!verifyIntegerEntry(ArgsMap, ".size", true))
103     return false;
104   if (!verifyIntegerEntry(ArgsMap, ".offset", true))
105     return false;
106   if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
107                          msgpack::Type::String,
108                          [](msgpack::DocNode &SNode) {
109                            return StringSwitch<bool>(SNode.getString())
110                                .Case("by_value", true)
111                                .Case("global_buffer", true)
112                                .Case("dynamic_shared_pointer", true)
113                                .Case("sampler", true)
114                                .Case("image", true)
115                                .Case("pipe", true)
116                                .Case("queue", true)
117                                .Case("hidden_global_offset_x", true)
118                                .Case("hidden_global_offset_y", true)
119                                .Case("hidden_global_offset_z", true)
120                                .Case("hidden_none", true)
121                                .Case("hidden_printf_buffer", true)
122                                .Case("hidden_default_queue", true)
123                                .Case("hidden_completion_action", true)
124                                .Default(false);
125                          }))
126     return false;
127   if (!verifyScalarEntry(ArgsMap, ".value_type", true,
128                          msgpack::Type::String,
129                          [](msgpack::DocNode &SNode) {
130                            return StringSwitch<bool>(SNode.getString())
131                                .Case("struct", true)
132                                .Case("i8", true)
133                                .Case("u8", true)
134                                .Case("i16", true)
135                                .Case("u16", true)
136                                .Case("f16", true)
137                                .Case("i32", true)
138                                .Case("u32", true)
139                                .Case("f32", true)
140                                .Case("i64", true)
141                                .Case("u64", true)
142                                .Case("f64", true)
143                                .Default(false);
144                          }))
145     return false;
146   if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
147     return false;
148   if (!verifyScalarEntry(ArgsMap, ".address_space", false,
149                          msgpack::Type::String,
150                          [](msgpack::DocNode &SNode) {
151                            return StringSwitch<bool>(SNode.getString())
152                                .Case("private", true)
153                                .Case("global", true)
154                                .Case("constant", true)
155                                .Case("local", true)
156                                .Case("generic", true)
157                                .Case("region", true)
158                                .Default(false);
159                          }))
160     return false;
161   if (!verifyScalarEntry(ArgsMap, ".access", false,
162                          msgpack::Type::String,
163                          [](msgpack::DocNode &SNode) {
164                            return StringSwitch<bool>(SNode.getString())
165                                .Case("read_only", true)
166                                .Case("write_only", true)
167                                .Case("read_write", true)
168                                .Default(false);
169                          }))
170     return false;
171   if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
172                          msgpack::Type::String,
173                          [](msgpack::DocNode &SNode) {
174                            return StringSwitch<bool>(SNode.getString())
175                                .Case("read_only", true)
176                                .Case("write_only", true)
177                                .Case("read_write", true)
178                                .Default(false);
179                          }))
180     return false;
181   if (!verifyScalarEntry(ArgsMap, ".is_const", false,
182                          msgpack::Type::Boolean))
183     return false;
184   if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
185                          msgpack::Type::Boolean))
186     return false;
187   if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
188                          msgpack::Type::Boolean))
189     return false;
190   if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
191                          msgpack::Type::Boolean))
192     return false;
193 
194   return true;
195 }
196 
197 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
198   if (!Node.isMap())
199     return false;
200   auto &KernelMap = Node.getMap();
201 
202   if (!verifyScalarEntry(KernelMap, ".name", true,
203                          msgpack::Type::String))
204     return false;
205   if (!verifyScalarEntry(KernelMap, ".symbol", true,
206                          msgpack::Type::String))
207     return false;
208   if (!verifyScalarEntry(KernelMap, ".language", false,
209                          msgpack::Type::String,
210                          [](msgpack::DocNode &SNode) {
211                            return StringSwitch<bool>(SNode.getString())
212                                .Case("OpenCL C", true)
213                                .Case("OpenCL C++", true)
214                                .Case("HCC", true)
215                                .Case("HIP", true)
216                                .Case("OpenMP", true)
217                                .Case("Assembler", true)
218                                .Default(false);
219                          }))
220     return false;
221   if (!verifyEntry(
222           KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
223             return verifyArray(
224                 Node,
225                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
226           }))
227     return false;
228   if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
229         return verifyArray(Node, [this](msgpack::DocNode &Node) {
230           return verifyKernelArgs(Node);
231         });
232       }))
233     return false;
234   if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
235                    [this](msgpack::DocNode &Node) {
236                      return verifyArray(Node,
237                                         [this](msgpack::DocNode &Node) {
238                                           return verifyInteger(Node);
239                                         },
240                                         3);
241                    }))
242     return false;
243   if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
244                    [this](msgpack::DocNode &Node) {
245                      return verifyArray(Node,
246                                         [this](msgpack::DocNode &Node) {
247                                           return verifyInteger(Node);
248                                         },
249                                         3);
250                    }))
251     return false;
252   if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
253                          msgpack::Type::String))
254     return false;
255   if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
256                          msgpack::Type::String))
257     return false;
258   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
259     return false;
260   if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
261     return false;
262   if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
263     return false;
264   if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
265     return false;
266   if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
267     return false;
268   if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
269     return false;
270   if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
271     return false;
272   if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
273     return false;
274   if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
275     return false;
276   if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
277     return false;
278 
279   return true;
280 }
281 
282 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
283   if (!HSAMetadataRoot.isMap())
284     return false;
285   auto &RootMap = HSAMetadataRoot.getMap();
286 
287   if (!verifyEntry(
288           RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
289             return verifyArray(
290                 Node,
291                 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
292           }))
293     return false;
294   if (!verifyEntry(
295           RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
296             return verifyArray(Node, [this](msgpack::DocNode &Node) {
297               return verifyScalar(Node, msgpack::Type::String);
298             });
299           }))
300     return false;
301   if (!verifyEntry(RootMap, "amdhsa.kernels", true,
302                    [this](msgpack::DocNode &Node) {
303                      return verifyArray(Node, [this](msgpack::DocNode &Node) {
304                        return verifyKernel(Node);
305                      });
306                    }))
307     return false;
308 
309   return true;
310 }
311 
312 } // end namespace V3
313 } // end namespace HSAMD
314 } // end namespace AMDGPU
315 } // end namespace llvm
316