1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file 10 /// Implements a verifier for AMDGPU HSA metadata. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h" 15 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/ADT/STLForwardCompat.h" 18 #include "llvm/ADT/StringSwitch.h" 19 #include "llvm/BinaryFormat/MsgPackDocument.h" 20 21 #include <map> 22 #include <utility> 23 24 namespace llvm { 25 namespace AMDGPU { 26 namespace HSAMD { 27 namespace V3 { 28 29 bool MetadataVerifier::verifyScalar( 30 msgpack::DocNode &Node, msgpack::Type SKind, 31 function_ref<bool(msgpack::DocNode &)> verifyValue) { 32 if (!Node.isScalar()) 33 return false; 34 if (Node.getKind() != SKind) { 35 if (Strict) 36 return false; 37 // If we are not strict, we interpret string values as "implicitly typed" 38 // and attempt to coerce them to the expected type here. 39 if (Node.getKind() != msgpack::Type::String) 40 return false; 41 StringRef StringValue = Node.getString(); 42 Node.fromString(StringValue); 43 if (Node.getKind() != SKind) 44 return false; 45 } 46 if (verifyValue) 47 return verifyValue(Node); 48 return true; 49 } 50 51 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) { 52 if (!verifyScalar(Node, msgpack::Type::UInt)) 53 if (!verifyScalar(Node, msgpack::Type::Int)) 54 return false; 55 return true; 56 } 57 58 bool MetadataVerifier::verifyArray( 59 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode, 60 Optional<size_t> Size) { 61 if (!Node.isArray()) 62 return false; 63 auto &Array = Node.getArray(); 64 if (Size && Array.size() != *Size) 65 return false; 66 return llvm::all_of(Array, verifyNode); 67 } 68 69 bool MetadataVerifier::verifyEntry( 70 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 71 function_ref<bool(msgpack::DocNode &)> verifyNode) { 72 auto Entry = MapNode.find(Key); 73 if (Entry == MapNode.end()) 74 return !Required; 75 return verifyNode(Entry->second); 76 } 77 78 bool MetadataVerifier::verifyScalarEntry( 79 msgpack::MapDocNode &MapNode, StringRef Key, bool Required, 80 msgpack::Type SKind, 81 function_ref<bool(msgpack::DocNode &)> verifyValue) { 82 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) { 83 return verifyScalar(Node, SKind, verifyValue); 84 }); 85 } 86 87 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode, 88 StringRef Key, bool Required) { 89 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) { 90 return verifyInteger(Node); 91 }); 92 } 93 94 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) { 95 if (!Node.isMap()) 96 return false; 97 auto &ArgsMap = Node.getMap(); 98 99 if (!verifyScalarEntry(ArgsMap, ".name", false, 100 msgpack::Type::String)) 101 return false; 102 if (!verifyScalarEntry(ArgsMap, ".type_name", false, 103 msgpack::Type::String)) 104 return false; 105 if (!verifyIntegerEntry(ArgsMap, ".size", true)) 106 return false; 107 if (!verifyIntegerEntry(ArgsMap, ".offset", true)) 108 return false; 109 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, 110 msgpack::Type::String, 111 [](msgpack::DocNode &SNode) { 112 return StringSwitch<bool>(SNode.getString()) 113 .Case("by_value", true) 114 .Case("global_buffer", true) 115 .Case("dynamic_shared_pointer", true) 116 .Case("sampler", true) 117 .Case("image", true) 118 .Case("pipe", true) 119 .Case("queue", true) 120 .Case("hidden_global_offset_x", true) 121 .Case("hidden_global_offset_y", true) 122 .Case("hidden_global_offset_z", true) 123 .Case("hidden_none", true) 124 .Case("hidden_printf_buffer", true) 125 .Case("hidden_hostcall_buffer", true) 126 .Case("hidden_default_queue", true) 127 .Case("hidden_completion_action", true) 128 .Case("hidden_multigrid_sync_arg", true) 129 .Default(false); 130 })) 131 return false; 132 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false)) 133 return false; 134 if (!verifyScalarEntry(ArgsMap, ".address_space", false, 135 msgpack::Type::String, 136 [](msgpack::DocNode &SNode) { 137 return StringSwitch<bool>(SNode.getString()) 138 .Case("private", true) 139 .Case("global", true) 140 .Case("constant", true) 141 .Case("local", true) 142 .Case("generic", true) 143 .Case("region", true) 144 .Default(false); 145 })) 146 return false; 147 if (!verifyScalarEntry(ArgsMap, ".access", false, 148 msgpack::Type::String, 149 [](msgpack::DocNode &SNode) { 150 return StringSwitch<bool>(SNode.getString()) 151 .Case("read_only", true) 152 .Case("write_only", true) 153 .Case("read_write", true) 154 .Default(false); 155 })) 156 return false; 157 if (!verifyScalarEntry(ArgsMap, ".actual_access", false, 158 msgpack::Type::String, 159 [](msgpack::DocNode &SNode) { 160 return StringSwitch<bool>(SNode.getString()) 161 .Case("read_only", true) 162 .Case("write_only", true) 163 .Case("read_write", true) 164 .Default(false); 165 })) 166 return false; 167 if (!verifyScalarEntry(ArgsMap, ".is_const", false, 168 msgpack::Type::Boolean)) 169 return false; 170 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false, 171 msgpack::Type::Boolean)) 172 return false; 173 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false, 174 msgpack::Type::Boolean)) 175 return false; 176 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false, 177 msgpack::Type::Boolean)) 178 return false; 179 180 return true; 181 } 182 183 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) { 184 if (!Node.isMap()) 185 return false; 186 auto &KernelMap = Node.getMap(); 187 188 if (!verifyScalarEntry(KernelMap, ".name", true, 189 msgpack::Type::String)) 190 return false; 191 if (!verifyScalarEntry(KernelMap, ".symbol", true, 192 msgpack::Type::String)) 193 return false; 194 if (!verifyScalarEntry(KernelMap, ".language", false, 195 msgpack::Type::String, 196 [](msgpack::DocNode &SNode) { 197 return StringSwitch<bool>(SNode.getString()) 198 .Case("OpenCL C", true) 199 .Case("OpenCL C++", true) 200 .Case("HCC", true) 201 .Case("HIP", true) 202 .Case("OpenMP", true) 203 .Case("Assembler", true) 204 .Default(false); 205 })) 206 return false; 207 if (!verifyEntry( 208 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) { 209 return verifyArray( 210 Node, 211 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 212 })) 213 return false; 214 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) { 215 return verifyArray(Node, [this](msgpack::DocNode &Node) { 216 return verifyKernelArgs(Node); 217 }); 218 })) 219 return false; 220 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false, 221 [this](msgpack::DocNode &Node) { 222 return verifyArray(Node, 223 [this](msgpack::DocNode &Node) { 224 return verifyInteger(Node); 225 }, 226 3); 227 })) 228 return false; 229 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false, 230 [this](msgpack::DocNode &Node) { 231 return verifyArray(Node, 232 [this](msgpack::DocNode &Node) { 233 return verifyInteger(Node); 234 }, 235 3); 236 })) 237 return false; 238 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false, 239 msgpack::Type::String)) 240 return false; 241 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false, 242 msgpack::Type::String)) 243 return false; 244 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true)) 245 return false; 246 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true)) 247 return false; 248 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true)) 249 return false; 250 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true)) 251 return false; 252 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true)) 253 return false; 254 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true)) 255 return false; 256 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true)) 257 return false; 258 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true)) 259 return false; 260 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false)) 261 return false; 262 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false)) 263 return false; 264 265 return true; 266 } 267 268 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) { 269 if (!HSAMetadataRoot.isMap()) 270 return false; 271 auto &RootMap = HSAMetadataRoot.getMap(); 272 273 if (!verifyEntry( 274 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) { 275 return verifyArray( 276 Node, 277 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2); 278 })) 279 return false; 280 if (!verifyEntry( 281 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) { 282 return verifyArray(Node, [this](msgpack::DocNode &Node) { 283 return verifyScalar(Node, msgpack::Type::String); 284 }); 285 })) 286 return false; 287 if (!verifyEntry(RootMap, "amdhsa.kernels", true, 288 [this](msgpack::DocNode &Node) { 289 return verifyArray(Node, [this](msgpack::DocNode &Node) { 290 return verifyKernel(Node); 291 }); 292 })) 293 return false; 294 295 return true; 296 } 297 298 } // end namespace V3 299 } // end namespace HSAMD 300 } // end namespace AMDGPU 301 } // end namespace llvm 302