1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file 10 /// Implements a verifier for AMDGPU HSA metadata. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h" 15 #include "llvm/Support/AMDGPUMetadata.h" 16 17 namespace llvm { 18 namespace AMDGPU { 19 namespace HSAMD { 20 namespace V3 { 21 22 bool MetadataVerifier::verifyScalar( 23 msgpack::Node &Node, msgpack::ScalarNode::ScalarKind SKind, 24 function_ref<bool(msgpack::ScalarNode &)> verifyValue) { 25 auto ScalarPtr = dyn_cast<msgpack::ScalarNode>(&Node); 26 if (!ScalarPtr) 27 return false; 28 auto &Scalar = *ScalarPtr; 29 // Do not output extraneous tags for types we know from the spec. 30 Scalar.IgnoreTag = true; 31 if (Scalar.getScalarKind() != SKind) { 32 if (Strict) 33 return false; 34 // If we are not strict, we interpret string values as "implicitly typed" 35 // and attempt to coerce them to the expected type here. 36 if (Scalar.getScalarKind() != msgpack::ScalarNode::SK_String) 37 return false; 38 std::string StringValue = Scalar.getString(); 39 Scalar.setScalarKind(SKind); 40 if (Scalar.inputYAML(StringValue) != StringRef()) 41 return false; 42 } 43 if (verifyValue) 44 return verifyValue(Scalar); 45 return true; 46 } 47 48 bool MetadataVerifier::verifyInteger(msgpack::Node &Node) { 49 if (!verifyScalar(Node, msgpack::ScalarNode::SK_UInt)) 50 if (!verifyScalar(Node, msgpack::ScalarNode::SK_Int)) 51 return false; 52 return true; 53 } 54 55 bool MetadataVerifier::verifyArray( 56 msgpack::Node &Node, function_ref<bool(msgpack::Node &)> verifyNode, 57 Optional<size_t> Size) { 58 auto ArrayPtr = dyn_cast<msgpack::ArrayNode>(&Node); 59 if (!ArrayPtr) 60 return false; 61 auto &Array = *ArrayPtr; 62 if (Size && Array.size() != *Size) 63 return false; 64 for (auto &Item : Array) 65 if (!verifyNode(*Item.get())) 66 return false; 67 68 return true; 69 } 70 71 bool MetadataVerifier::verifyEntry( 72 msgpack::MapNode &MapNode, StringRef Key, bool Required, 73 function_ref<bool(msgpack::Node &)> verifyNode) { 74 auto Entry = MapNode.find(Key); 75 if (Entry == MapNode.end()) 76 return !Required; 77 return verifyNode(*Entry->second.get()); 78 } 79 80 bool MetadataVerifier::verifyScalarEntry( 81 msgpack::MapNode &MapNode, StringRef Key, bool Required, 82 msgpack::ScalarNode::ScalarKind SKind, 83 function_ref<bool(msgpack::ScalarNode &)> verifyValue) { 84 return verifyEntry(MapNode, Key, Required, [=](msgpack::Node &Node) { 85 return verifyScalar(Node, SKind, verifyValue); 86 }); 87 } 88 89 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapNode &MapNode, 90 StringRef Key, bool Required) { 91 return verifyEntry(MapNode, Key, Required, [this](msgpack::Node &Node) { 92 return verifyInteger(Node); 93 }); 94 } 95 96 bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) { 97 auto ArgsMapPtr = dyn_cast<msgpack::MapNode>(&Node); 98 if (!ArgsMapPtr) 99 return false; 100 auto &ArgsMap = *ArgsMapPtr; 101 102 if (!verifyScalarEntry(ArgsMap, ".name", false, 103 msgpack::ScalarNode::SK_String)) 104 return false; 105 if (!verifyScalarEntry(ArgsMap, ".type_name", false, 106 msgpack::ScalarNode::SK_String)) 107 return false; 108 if (!verifyIntegerEntry(ArgsMap, ".size", true)) 109 return false; 110 if (!verifyIntegerEntry(ArgsMap, ".offset", true)) 111 return false; 112 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, 113 msgpack::ScalarNode::SK_String, 114 [](msgpack::ScalarNode &SNode) { 115 return StringSwitch<bool>(SNode.getString()) 116 .Case("by_value", true) 117 .Case("global_buffer", true) 118 .Case("dynamic_shared_pointer", true) 119 .Case("sampler", true) 120 .Case("image", true) 121 .Case("pipe", true) 122 .Case("queue", true) 123 .Case("hidden_global_offset_x", true) 124 .Case("hidden_global_offset_y", true) 125 .Case("hidden_global_offset_z", true) 126 .Case("hidden_none", true) 127 .Case("hidden_printf_buffer", true) 128 .Case("hidden_default_queue", true) 129 .Case("hidden_completion_action", true) 130 .Default(false); 131 })) 132 return false; 133 if (!verifyScalarEntry(ArgsMap, ".value_type", true, 134 msgpack::ScalarNode::SK_String, 135 [](msgpack::ScalarNode &SNode) { 136 return StringSwitch<bool>(SNode.getString()) 137 .Case("struct", true) 138 .Case("i8", true) 139 .Case("u8", true) 140 .Case("i16", true) 141 .Case("u16", true) 142 .Case("f16", true) 143 .Case("i32", true) 144 .Case("u32", true) 145 .Case("f32", true) 146 .Case("i64", true) 147 .Case("u64", true) 148 .Case("f64", true) 149 .Default(false); 150 })) 151 return false; 152 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false)) 153 return false; 154 if (!verifyScalarEntry(ArgsMap, ".address_space", false, 155 msgpack::ScalarNode::SK_String, 156 [](msgpack::ScalarNode &SNode) { 157 return StringSwitch<bool>(SNode.getString()) 158 .Case("private", true) 159 .Case("global", true) 160 .Case("constant", true) 161 .Case("local", true) 162 .Case("generic", true) 163 .Case("region", true) 164 .Default(false); 165 })) 166 return false; 167 if (!verifyScalarEntry(ArgsMap, ".access", false, 168 msgpack::ScalarNode::SK_String, 169 [](msgpack::ScalarNode &SNode) { 170 return StringSwitch<bool>(SNode.getString()) 171 .Case("read_only", true) 172 .Case("write_only", true) 173 .Case("read_write", true) 174 .Default(false); 175 })) 176 return false; 177 if (!verifyScalarEntry(ArgsMap, ".actual_access", false, 178 msgpack::ScalarNode::SK_String, 179 [](msgpack::ScalarNode &SNode) { 180 return StringSwitch<bool>(SNode.getString()) 181 .Case("read_only", true) 182 .Case("write_only", true) 183 .Case("read_write", true) 184 .Default(false); 185 })) 186 return false; 187 if (!verifyScalarEntry(ArgsMap, ".is_const", false, 188 msgpack::ScalarNode::SK_Boolean)) 189 return false; 190 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false, 191 msgpack::ScalarNode::SK_Boolean)) 192 return false; 193 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false, 194 msgpack::ScalarNode::SK_Boolean)) 195 return false; 196 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false, 197 msgpack::ScalarNode::SK_Boolean)) 198 return false; 199 200 return true; 201 } 202 203 bool MetadataVerifier::verifyKernel(msgpack::Node &Node) { 204 auto KernelMapPtr = dyn_cast<msgpack::MapNode>(&Node); 205 if (!KernelMapPtr) 206 return false; 207 auto &KernelMap = *KernelMapPtr; 208 209 if (!verifyScalarEntry(KernelMap, ".name", true, 210 msgpack::ScalarNode::SK_String)) 211 return false; 212 if (!verifyScalarEntry(KernelMap, ".symbol", true, 213 msgpack::ScalarNode::SK_String)) 214 return false; 215 if (!verifyScalarEntry(KernelMap, ".language", false, 216 msgpack::ScalarNode::SK_String, 217 [](msgpack::ScalarNode &SNode) { 218 return StringSwitch<bool>(SNode.getString()) 219 .Case("OpenCL C", true) 220 .Case("OpenCL C++", true) 221 .Case("HCC", true) 222 .Case("HIP", true) 223 .Case("OpenMP", true) 224 .Case("Assembler", true) 225 .Default(false); 226 })) 227 return false; 228 if (!verifyEntry( 229 KernelMap, ".language_version", false, [this](msgpack::Node &Node) { 230 return verifyArray( 231 Node, 232 [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2); 233 })) 234 return false; 235 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::Node &Node) { 236 return verifyArray(Node, [this](msgpack::Node &Node) { 237 return verifyKernelArgs(Node); 238 }); 239 })) 240 return false; 241 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false, 242 [this](msgpack::Node &Node) { 243 return verifyArray(Node, 244 [this](msgpack::Node &Node) { 245 return verifyInteger(Node); 246 }, 247 3); 248 })) 249 return false; 250 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false, 251 [this](msgpack::Node &Node) { 252 return verifyArray(Node, 253 [this](msgpack::Node &Node) { 254 return verifyInteger(Node); 255 }, 256 3); 257 })) 258 return false; 259 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false, 260 msgpack::ScalarNode::SK_String)) 261 return false; 262 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false, 263 msgpack::ScalarNode::SK_String)) 264 return false; 265 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true)) 266 return false; 267 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true)) 268 return false; 269 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true)) 270 return false; 271 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true)) 272 return false; 273 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true)) 274 return false; 275 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true)) 276 return false; 277 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true)) 278 return false; 279 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true)) 280 return false; 281 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false)) 282 return false; 283 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false)) 284 return false; 285 286 return true; 287 } 288 289 bool MetadataVerifier::verify(msgpack::Node &HSAMetadataRoot) { 290 auto RootMapPtr = dyn_cast<msgpack::MapNode>(&HSAMetadataRoot); 291 if (!RootMapPtr) 292 return false; 293 auto &RootMap = *RootMapPtr; 294 295 if (!verifyEntry( 296 RootMap, "amdhsa.version", true, [this](msgpack::Node &Node) { 297 return verifyArray( 298 Node, 299 [this](msgpack::Node &Node) { return verifyInteger(Node); }, 2); 300 })) 301 return false; 302 if (!verifyEntry( 303 RootMap, "amdhsa.printf", false, [this](msgpack::Node &Node) { 304 return verifyArray(Node, [this](msgpack::Node &Node) { 305 return verifyScalar(Node, msgpack::ScalarNode::SK_String); 306 }); 307 })) 308 return false; 309 if (!verifyEntry(RootMap, "amdhsa.kernels", true, 310 [this](msgpack::Node &Node) { 311 return verifyArray(Node, [this](msgpack::Node &Node) { 312 return verifyKernel(Node); 313 }); 314 })) 315 return false; 316 317 return true; 318 } 319 320 } // end namespace V3 321 } // end namespace HSAMD 322 } // end namespace AMDGPU 323 } // end namespace llvm 324