//===- AMDGPUMetadataVerifier.cpp - HSA Metadata Verifier -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/BinaryFormat/MsgPackDocument.h"
19
20 #include <utility>
21
22 namespace llvm {
23 namespace AMDGPU {
24 namespace HSAMD {
25 namespace V3 {
26
verifyScalar(msgpack::DocNode & Node,msgpack::Type SKind,function_ref<bool (msgpack::DocNode &)> verifyValue)27 bool MetadataVerifier::verifyScalar(
28 msgpack::DocNode &Node, msgpack::Type SKind,
29 function_ref<bool(msgpack::DocNode &)> verifyValue) {
30 if (!Node.isScalar())
31 return false;
32 if (Node.getKind() != SKind) {
33 if (Strict)
34 return false;
35 // If we are not strict, we interpret string values as "implicitly typed"
36 // and attempt to coerce them to the expected type here.
37 if (Node.getKind() != msgpack::Type::String)
38 return false;
39 StringRef StringValue = Node.getString();
40 Node.fromString(StringValue);
41 if (Node.getKind() != SKind)
42 return false;
43 }
44 if (verifyValue)
45 return verifyValue(Node);
46 return true;
47 }
48
verifyInteger(msgpack::DocNode & Node)49 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
50 if (!verifyScalar(Node, msgpack::Type::UInt))
51 if (!verifyScalar(Node, msgpack::Type::Int))
52 return false;
53 return true;
54 }
55
verifyArray(msgpack::DocNode & Node,function_ref<bool (msgpack::DocNode &)> verifyNode,std::optional<size_t> Size)56 bool MetadataVerifier::verifyArray(
57 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
58 std::optional<size_t> Size) {
59 if (!Node.isArray())
60 return false;
61 auto &Array = Node.getArray();
62 if (Size && Array.size() != *Size)
63 return false;
64 return llvm::all_of(Array, verifyNode);
65 }
66
verifyEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required,function_ref<bool (msgpack::DocNode &)> verifyNode)67 bool MetadataVerifier::verifyEntry(
68 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
69 function_ref<bool(msgpack::DocNode &)> verifyNode) {
70 auto Entry = MapNode.find(Key);
71 if (Entry == MapNode.end())
72 return !Required;
73 return verifyNode(Entry->second);
74 }
75
verifyScalarEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required,msgpack::Type SKind,function_ref<bool (msgpack::DocNode &)> verifyValue)76 bool MetadataVerifier::verifyScalarEntry(
77 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
78 msgpack::Type SKind,
79 function_ref<bool(msgpack::DocNode &)> verifyValue) {
80 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
81 return verifyScalar(Node, SKind, verifyValue);
82 });
83 }
84
verifyIntegerEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required)85 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
86 StringRef Key, bool Required) {
87 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
88 return verifyInteger(Node);
89 });
90 }
91
verifyKernelArgs(msgpack::DocNode & Node)92 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
93 if (!Node.isMap())
94 return false;
95 auto &ArgsMap = Node.getMap();
96
97 if (!verifyScalarEntry(ArgsMap, ".name", false,
98 msgpack::Type::String))
99 return false;
100 if (!verifyScalarEntry(ArgsMap, ".type_name", false,
101 msgpack::Type::String))
102 return false;
103 if (!verifyIntegerEntry(ArgsMap, ".size", true))
104 return false;
105 if (!verifyIntegerEntry(ArgsMap, ".offset", true))
106 return false;
107 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String,
108 [](msgpack::DocNode &SNode) {
109 return StringSwitch<bool>(SNode.getString())
110 .Case("by_value", true)
111 .Case("global_buffer", true)
112 .Case("dynamic_shared_pointer", true)
113 .Case("sampler", true)
114 .Case("image", true)
115 .Case("pipe", true)
116 .Case("queue", true)
117 .Case("hidden_block_count_x", true)
118 .Case("hidden_block_count_y", true)
119 .Case("hidden_block_count_z", true)
120 .Case("hidden_group_size_x", true)
121 .Case("hidden_group_size_y", true)
122 .Case("hidden_group_size_z", true)
123 .Case("hidden_remainder_x", true)
124 .Case("hidden_remainder_y", true)
125 .Case("hidden_remainder_z", true)
126 .Case("hidden_global_offset_x", true)
127 .Case("hidden_global_offset_y", true)
128 .Case("hidden_global_offset_z", true)
129 .Case("hidden_grid_dims", true)
130 .Case("hidden_none", true)
131 .Case("hidden_printf_buffer", true)
132 .Case("hidden_hostcall_buffer", true)
133 .Case("hidden_heap_v1", true)
134 .Case("hidden_default_queue", true)
135 .Case("hidden_completion_action", true)
136 .Case("hidden_multigrid_sync_arg", true)
137 .Case("hidden_dynamic_lds_size", true)
138 .Case("hidden_private_base", true)
139 .Case("hidden_shared_base", true)
140 .Case("hidden_queue_ptr", true)
141 .Default(false);
142 }))
143 return false;
144 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
145 return false;
146 if (!verifyScalarEntry(ArgsMap, ".address_space", false,
147 msgpack::Type::String,
148 [](msgpack::DocNode &SNode) {
149 return StringSwitch<bool>(SNode.getString())
150 .Case("private", true)
151 .Case("global", true)
152 .Case("constant", true)
153 .Case("local", true)
154 .Case("generic", true)
155 .Case("region", true)
156 .Default(false);
157 }))
158 return false;
159 if (!verifyScalarEntry(ArgsMap, ".access", false,
160 msgpack::Type::String,
161 [](msgpack::DocNode &SNode) {
162 return StringSwitch<bool>(SNode.getString())
163 .Case("read_only", true)
164 .Case("write_only", true)
165 .Case("read_write", true)
166 .Default(false);
167 }))
168 return false;
169 if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
170 msgpack::Type::String,
171 [](msgpack::DocNode &SNode) {
172 return StringSwitch<bool>(SNode.getString())
173 .Case("read_only", true)
174 .Case("write_only", true)
175 .Case("read_write", true)
176 .Default(false);
177 }))
178 return false;
179 if (!verifyScalarEntry(ArgsMap, ".is_const", false,
180 msgpack::Type::Boolean))
181 return false;
182 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
183 msgpack::Type::Boolean))
184 return false;
185 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
186 msgpack::Type::Boolean))
187 return false;
188 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
189 msgpack::Type::Boolean))
190 return false;
191
192 return true;
193 }
194
verifyKernel(msgpack::DocNode & Node)195 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
196 if (!Node.isMap())
197 return false;
198 auto &KernelMap = Node.getMap();
199
200 if (!verifyScalarEntry(KernelMap, ".name", true,
201 msgpack::Type::String))
202 return false;
203 if (!verifyScalarEntry(KernelMap, ".symbol", true,
204 msgpack::Type::String))
205 return false;
206 if (!verifyScalarEntry(KernelMap, ".language", false,
207 msgpack::Type::String,
208 [](msgpack::DocNode &SNode) {
209 return StringSwitch<bool>(SNode.getString())
210 .Case("OpenCL C", true)
211 .Case("OpenCL C++", true)
212 .Case("HCC", true)
213 .Case("HIP", true)
214 .Case("OpenMP", true)
215 .Case("Assembler", true)
216 .Default(false);
217 }))
218 return false;
219 if (!verifyEntry(
220 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
221 return verifyArray(
222 Node,
223 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
224 }))
225 return false;
226 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
227 return verifyArray(Node, [this](msgpack::DocNode &Node) {
228 return verifyKernelArgs(Node);
229 });
230 }))
231 return false;
232 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
233 [this](msgpack::DocNode &Node) {
234 return verifyArray(Node,
235 [this](msgpack::DocNode &Node) {
236 return verifyInteger(Node);
237 },
238 3);
239 }))
240 return false;
241 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
242 [this](msgpack::DocNode &Node) {
243 return verifyArray(Node,
244 [this](msgpack::DocNode &Node) {
245 return verifyInteger(Node);
246 },
247 3);
248 }))
249 return false;
250 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
251 msgpack::Type::String))
252 return false;
253 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
254 msgpack::Type::String))
255 return false;
256 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
257 return false;
258 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
259 return false;
260 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
261 return false;
262 if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false,
263 msgpack::Type::Boolean))
264 return false;
265 if (!verifyIntegerEntry(KernelMap, ".workgroup_processor_mode", false))
266 return false;
267 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
268 return false;
269 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
270 return false;
271 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
272 return false;
273 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
274 return false;
275 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
276 return false;
277 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
278 return false;
279 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
280 return false;
281 if (!verifyIntegerEntry(KernelMap, ".uniform_work_group_size", false))
282 return false;
283
284
285 return true;
286 }
287
verify(msgpack::DocNode & HSAMetadataRoot)288 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
289 if (!HSAMetadataRoot.isMap())
290 return false;
291 auto &RootMap = HSAMetadataRoot.getMap();
292
293 if (!verifyEntry(
294 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
295 return verifyArray(
296 Node,
297 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
298 }))
299 return false;
300 if (!verifyEntry(
301 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
302 return verifyArray(Node, [this](msgpack::DocNode &Node) {
303 return verifyScalar(Node, msgpack::Type::String);
304 });
305 }))
306 return false;
307 if (!verifyEntry(RootMap, "amdhsa.kernels", true,
308 [this](msgpack::DocNode &Node) {
309 return verifyArray(Node, [this](msgpack::DocNode &Node) {
310 return verifyKernel(Node);
311 });
312 }))
313 return false;
314
315 return true;
316 }
317
318 } // end namespace V3
319 } // end namespace HSAMD
320 } // end namespace AMDGPU
321 } // end namespace llvm
322