1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "packet_def.h"

#include <list>
#include <set>
#include <utility>

#include "fields/all_fields.h"
#include "util.h"
24
PacketDef(std::string name,FieldList fields)25 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)26 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
27
GetNewField(const std::string &,ParseLocation) const28 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
29 return nullptr; // Packets can't be fields
30 }
31
GenParserDefinition(std::ostream & s) const32 void PacketDef::GenParserDefinition(std::ostream& s) const {
33 s << "class " << name_ << "View";
34 if (parent_ != nullptr) {
35 s << " : public " << parent_->name_ << "View {";
36 } else {
37 s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
38 }
39 s << " public:";
40
41 // Specialize function
42 if (parent_ != nullptr) {
43 s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
44 s << "{ return " << name_ << "View(std::move(parent)); }";
45 } else {
46 s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
47 s << "{ return " << name_ << "View(std::move(packet)); }";
48 }
49
50 GenTestingParserFromBytes(s);
51
52 std::set<std::string> fixed_types = {
53 FixedScalarField::kFieldType,
54 FixedEnumField::kFieldType,
55 };
56
57 // Print all of the public fields which are all the fields minus the fixed fields.
58 const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
59 bool has_fixed_fields = public_fields.size() != fields_.size();
60 for (const auto& field : public_fields) {
61 GenParserFieldGetter(s, field);
62 s << "\n";
63 }
64 GenValidator(s);
65 s << "\n";
66
67 s << " public:";
68 GenParserToString(s);
69 s << "\n";
70
71 s << " protected:\n";
72 // Constructor from a View
73 if (parent_ != nullptr) {
74 s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
75 s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
76 } else {
77 s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
78 s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
79 }
80
81 // Print the private fields which are the fixed fields.
82 if (has_fixed_fields) {
83 const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
84 s << " private:\n";
85 for (const auto& field : private_fields) {
86 GenParserFieldGetter(s, field);
87 s << "\n";
88 }
89 }
90 s << "};\n";
91 }
92
GenTestingParserFromBytes(std::ostream & s) const93 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
94 s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
95
96 s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
97 s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
98 s << "return " << name_ << "View::Create(";
99 auto ancestor_ptr = parent_;
100 size_t parent_parens = 0;
101 while (ancestor_ptr != nullptr) {
102 s << ancestor_ptr->name_ << "View::Create(";
103 parent_parens++;
104 ancestor_ptr = ancestor_ptr->parent_;
105 }
106 s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
107 for (size_t i = 0; i < parent_parens; i++) {
108 s << ")";
109 }
110 s << ");";
111 s << "}";
112
113 s << "\n#endif\n";
114 }
115
GenParserDefinitionPybind11(std::ostream & s) const116 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
117 s << "py::class_<" << name_ << "View";
118 if (parent_ != nullptr) {
119 s << ", " << parent_->name_ << "View";
120 } else {
121 s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
122 }
123 s << ">(m, \"" << name_ << "View\")";
124 if (parent_ != nullptr) {
125 s << ".def(py::init([](" << parent_->name_ << "View parent) {";
126 } else {
127 s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
128 }
129 s << "auto view =" << name_ << "View::Create(std::move(parent));";
130 s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
131 s << "return view; }))";
132
133 s << ".def(py::init(&" << name_ << "View::Create))";
134 std::set<std::string> protected_field_types = {
135 FixedScalarField::kFieldType,
136 FixedEnumField::kFieldType,
137 SizeField::kFieldType,
138 CountField::kFieldType,
139 };
140 const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
141 for (const auto& field : public_fields) {
142 auto getter_func_name = field->GetGetterFunctionName();
143 if (getter_func_name.empty()) {
144 continue;
145 }
146 s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
147 }
148 s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
149 s << ";\n";
150 }
151
GenParserFieldGetter(std::ostream & s,const PacketField * field) const152 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
153 // Start field offset
154 auto start_field_offset = GetOffsetForField(field->GetName(), false);
155 auto end_field_offset = GetOffsetForField(field->GetName(), true);
156
157 if (start_field_offset.empty() && end_field_offset.empty()) {
158 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
159 << "no method exists to determine field location from begin() or end().\n";
160 }
161
162 field->GenGetter(s, start_field_offset, end_field_offset);
163 }
164
GetDefinitionType() const165 TypeDef::Type PacketDef::GetDefinitionType() const {
166 return TypeDef::Type::PACKET;
167 }
168
GenValidator(std::ostream & s) const169 void PacketDef::GenValidator(std::ostream& s) const {
170 // Get the static offset for all of our fields.
171 int bits_size = 0;
172 for (const auto& field : fields_) {
173 if (field->GetFieldType() != PaddingField::kFieldType) {
174 bits_size += field->GetSize().bits();
175 }
176 }
177
178 // Write the function declaration.
179 s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
180 s << "if (was_validated_) { return true; } ";
181 s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
182 s << "}";
183
184 s << "protected:";
185 s << "virtual bool IsValid_() const {";
186
187 // Offset by the parents known size. We know that any dynamic fields can
188 // already be called since the parent must have already been validated by
189 // this point.
190 auto parent_size = Size(0);
191 if (parent_ != nullptr) {
192 parent_size = parent_->GetSize(true);
193 }
194
195 s << "auto it = begin() + (" << parent_size << ") / 8;";
196
197 // Check if you can extract the static fields.
198 // At this point you know you can use the size getters without crashing
199 // as long as they follow the instruction that size fields cant come before
200 // their corrisponding variable length field.
201 s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
202 s << "if (it > end()) return false;";
203
204 // For any variable length fields, use their size check.
205 for (const auto& field : fields_) {
206 if (field->GetFieldType() == ChecksumStartField::kFieldType) {
207 auto offset = GetOffsetForField(field->GetName(), false);
208 if (!offset.empty()) {
209 s << "size_t sum_index = (" << offset << ") / 8;";
210 } else {
211 offset = GetOffsetForField(field->GetName(), true);
212 if (offset.empty()) {
213 ERROR(field) << "Checksum Start Field offset can not be determined.";
214 }
215 s << "size_t sum_index = size() - (" << offset << ") / 8;";
216 }
217
218 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
219 const auto& started_field = fields_.GetField(field_name);
220 if (started_field == nullptr) {
221 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
222 << ")";
223 }
224 auto end_offset = GetOffsetForField(started_field->GetName(), false);
225 if (!end_offset.empty()) {
226 s << "size_t end_sum_index = (" << end_offset << ") / 8;";
227 } else {
228 end_offset = GetOffsetForField(started_field->GetName(), true);
229 if (end_offset.empty()) {
230 ERROR(started_field) << "Checksum Field end_offset can not be determined.";
231 }
232 s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
233 }
234 if (is_little_endian_) {
235 s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
236 } else {
237 s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
238 }
239 s << started_field->GetDataType() << " checksum;";
240 s << "checksum.Initialize();";
241 s << "for (uint8_t byte : checksum_view) { ";
242 s << "checksum.AddByte(byte);}";
243 s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
244 << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
245
246 continue;
247 }
248
249 auto field_size = field->GetSize();
250 // Fixed size fields have already been handled.
251 if (!field_size.has_dynamic()) {
252 continue;
253 }
254
255 // Custom fields with dynamic size must have the offset for the field passed in as well
256 // as the end iterator so that they may ensure that they don't try to read past the end.
257 // Custom fields with fixed sizes will be handled in the static offset checking.
258 if (field->GetFieldType() == CustomField::kFieldType) {
259 // Check if we can determine offset from begin(), otherwise error because by this point,
260 // the size of the custom field is unknown and can't be subtracted from end() to get the
261 // offset.
262 auto offset = GetOffsetForField(field->GetName(), false);
263 if (offset.empty()) {
264 ERROR(field) << "Custom Field offset can not be determined from begin().";
265 }
266
267 if (offset.bits() % 8 != 0) {
268 ERROR(field) << "Custom fields must be byte aligned.";
269 }
270
271 // Custom fields are special as their size field takes an argument.
272 const auto& custom_size_var = field->GetName() + "_size";
273 s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
274 s << "(begin() + (" << offset << ") / 8);";
275
276 s << "if (!" << custom_size_var << ".has_value()) { return false; }";
277 s << "it += *" << custom_size_var << ";";
278 s << "if (it > end()) return false;";
279 continue;
280 } else {
281 s << "it += (" << field_size.dynamic_string() << ") / 8;";
282 s << "if (it > end()) return false;";
283 }
284 }
285
286 // Validate constraints after validating the size
287 if (parent_constraints_.size() > 0 && parent_ == nullptr) {
288 ERROR() << "Can't have a constraint on a NULL parent";
289 }
290
291 for (const auto& constraint : parent_constraints_) {
292 s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
293 const auto& field = parent_->GetParamList().GetField(constraint.first);
294 if (field->GetFieldType() == ScalarField::kFieldType) {
295 s << std::get<int64_t>(constraint.second);
296 } else {
297 s << std::get<std::string>(constraint.second);
298 }
299 s << ") return false;";
300 }
301
302 // Validate the packets fields last
303 for (const auto& field : fields_) {
304 field->GenValidator(s);
305 s << "\n";
306 }
307
308 s << "return true;";
309 s << "}\n";
310 if (parent_ == nullptr) {
311 s << "bool was_validated_{false};\n";
312 }
313 }
314
GenParserToString(std::ostream & s) const315 void PacketDef::GenParserToString(std::ostream& s) const {
316 s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
317 s << "std::stringstream ss;";
318 s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
319
320 if (fields_.size() > 0) {
321 s << "ss << \"\" ";
322 bool firstfield = true;
323 for (const auto& field : fields_) {
324 if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
325 field->GetFieldType() == ChecksumStartField::kFieldType)
326 continue;
327
328 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
329
330 field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
331
332 if (firstfield) {
333 firstfield = false;
334 }
335 }
336 s << ";";
337 }
338
339 s << "ss << \" }\";";
340 s << "return ss.str();";
341 s << "}\n";
342 }
343
GenBuilderDefinition(std::ostream & s) const344 void PacketDef::GenBuilderDefinition(std::ostream& s) const {
345 s << "class " << name_ << "Builder";
346 if (parent_ != nullptr) {
347 s << " : public " << parent_->name_ << "Builder";
348 } else {
349 if (is_little_endian_) {
350 s << " : public PacketBuilder<kLittleEndian>";
351 } else {
352 s << " : public PacketBuilder<!kLittleEndian>";
353 }
354 }
355 s << " {";
356 s << " public:";
357 s << " virtual ~" << name_ << "Builder()" << (parent_ != nullptr ? " override" : "") << " = default;";
358
359 if (!fields_.HasBody()) {
360 GenBuilderCreate(s);
361 s << "\n";
362
363 GenTestingFromView(s);
364 s << "\n";
365 }
366
367 GenSerialize(s);
368 s << "\n";
369
370 GenSize(s);
371 s << "\n";
372
373 s << " protected:\n";
374 GenBuilderConstructor(s);
375 s << "\n";
376
377 GenBuilderParameterChecker(s);
378 s << "\n";
379
380 GenMembers(s);
381 s << "};\n";
382
383 GenTestDefine(s);
384 s << "\n";
385
386 GenFuzzTestDefine(s);
387 s << "\n";
388 }
389
GenTestingFromView(std::ostream & s) const390 void PacketDef::GenTestingFromView(std::ostream& s) const {
391 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
392
393 s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
394 s << "return " << name_ << "Builder::Create(";
395 FieldList params = GetParamList().GetFieldsWithoutTypes({
396 BodyField::kFieldType,
397 });
398 for (int i = 0; i < params.size(); i++) {
399 params[i]->GenBuilderParameterFromView(s);
400 if (i != params.size() - 1) {
401 s << ", ";
402 }
403 }
404 s << ");";
405 s << "}";
406
407 s << "\n#endif\n";
408 }
409
GenBuilderDefinitionPybind11(std::ostream & s) const410 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
411 s << "py::class_<" << name_ << "Builder";
412 if (parent_ != nullptr) {
413 s << ", " << parent_->name_ << "Builder";
414 } else {
415 if (is_little_endian_) {
416 s << ", PacketBuilder<kLittleEndian>";
417 } else {
418 s << ", PacketBuilder<!kLittleEndian>";
419 }
420 }
421 s << ", std::shared_ptr<" << name_ << "Builder>";
422 s << ">(m, \"" << name_ << "Builder\")";
423 if (!fields_.HasBody()) {
424 GenBuilderCreatePybind11(s);
425 }
426 s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
427 s << "std::vector<uint8_t> bytes;";
428 s << "BitInserter bi(bytes);";
429 s << "builder.Serialize(bi);";
430 s << "return bytes;})";
431 s << ";\n";
432 }
433
GenTestDefine(std::ostream & s) const434 void PacketDef::GenTestDefine(std::ostream& s) const {
435 s << "#ifdef PACKET_TESTING\n";
436 s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
437 s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
438 s << "public: ";
439 s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
440 s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
441 s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
442 s << "for (size_t i = 0; i < view.size(); i++) { LOG_DEBUG(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
443 s << "ASSERT_TRUE(view.IsValid());";
444 s << "auto packet = " << name_ << "Builder::FromView(view);";
445 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
446 s << "packet_bytes->reserve(packet->size());";
447 s << "BitInserter it(*packet_bytes);";
448 s << "packet->Serialize(it);";
449 s << "ASSERT_EQ(*packet_bytes, captured_packet);";
450 s << "}";
451 s << "};";
452 s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
453 s << "CompareBytes(GetParam());";
454 s << "}";
455 s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
456 s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
457 s << "\n#endif";
458 }
459
GenFuzzTestDefine(std::ostream & s) const460 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
461 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
462 s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
463 s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
464 s << "auto vec = std::vector<uint8_t>(data, data + size);";
465 s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
466 s << "if (!view.IsValid()) { return; }";
467 s << "auto packet = " << name_ << "Builder::FromView(view);";
468 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
469 s << "packet_bytes->reserve(packet->size());";
470 s << "BitInserter it(*packet_bytes);";
471 s << "packet->Serialize(it);";
472 s << "}";
473 s << "\n#endif\n";
474 s << "#ifdef PACKET_FUZZ_TESTING\n";
475 s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
476 s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
477 s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
478 s << "public: ";
479 s << "explicit " << name_
480 << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
481 s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
482 s << "}}; ";
483 s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
484 s << "\n#endif";
485 }
486
GetParametersToValidate() const487 FieldList PacketDef::GetParametersToValidate() const {
488 FieldList params_to_validate;
489 for (const auto& field : GetParamList()) {
490 if (field->HasParameterValidator()) {
491 params_to_validate.AppendField(field);
492 }
493 }
494 return params_to_validate;
495 }
496
GenBuilderCreate(std::ostream & s) const497 void PacketDef::GenBuilderCreate(std::ostream& s) const {
498 s << "static std::unique_ptr<" << name_ << "Builder> Create(";
499
500 auto params = GetParamList();
501 for (int i = 0; i < params.size(); i++) {
502 params[i]->GenBuilderParameter(s);
503 if (i != params.size() - 1) {
504 s << ", ";
505 }
506 }
507 s << ") {";
508
509 // Call the constructor
510 s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
511
512 params = params.GetFieldsWithoutTypes({
513 PayloadField::kFieldType,
514 BodyField::kFieldType,
515 });
516 // Add the parameters.
517 for (int i = 0; i < params.size(); i++) {
518 if (params[i]->BuilderParameterMustBeMoved()) {
519 s << "std::move(" << params[i]->GetName() << ")";
520 } else {
521 s << params[i]->GetName();
522 }
523 if (i != params.size() - 1) {
524 s << ", ";
525 }
526 }
527
528 s << "));";
529 if (fields_.HasPayload()) {
530 s << "builder->payload_ = std::move(payload);";
531 }
532 s << "return builder;";
533 s << "}\n";
534 }
535
GenBuilderCreatePybind11(std::ostream & s) const536 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
537 s << ".def(py::init([](";
538 auto params = GetParamList();
539 std::vector<std::string> constructor_args;
540 int i = 1;
541 for (const auto& param : params) {
542 i++;
543 std::stringstream ss;
544 auto param_type = param->GetBuilderParameterType();
545 if (param_type.empty()) {
546 continue;
547 }
548 // Use shared_ptr instead of unique_ptr for the Python interface
549 if (param->BuilderParameterMustBeMoved()) {
550 param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
551 }
552 ss << param_type << " " << param->GetName();
553 constructor_args.push_back(ss.str());
554 }
555 s << util::StringJoin(",", constructor_args) << "){";
556
557 // Deal with move only args
558 for (const auto& param : params) {
559 std::stringstream ss;
560 auto param_type = param->GetBuilderParameterType();
561 if (param_type.empty()) {
562 continue;
563 }
564 if (!param->BuilderParameterMustBeMoved()) {
565 continue;
566 }
567 auto move_only_param_name = param->GetName() + "_move_only";
568 s << param_type << " " << move_only_param_name << ";";
569 if (param->IsContainerField()) {
570 // Assume single layer container and copy it
571 auto struct_type = param->GetElementField()->GetDataType();
572 struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
573 struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
574 s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
575 // Serialize each struct
576 s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
577 s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
578 s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
579 s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
580 // Parse it again
581 s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
582 s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
583 s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
584 // Push it into a new container
585 if (param->GetFieldType() == VectorField::kFieldType) {
586 s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
587 } else if (param->GetFieldType() == ArrayField::kFieldType) {
588 s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
589 } else {
590 ERROR() << param << " is not supported by Pybind11";
591 }
592 s << "}";
593 } else {
594 // Serialize the parameter and pass the bytes in a RawBuilder
595 s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
596 s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
597 s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
598 s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
599 s << move_only_param_name << " = ";
600 s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
601 }
602 }
603 s << "return " << name_ << "Builder::Create(";
604 std::vector<std::string> builder_vars;
605 for (const auto& param : params) {
606 std::stringstream ss;
607 auto param_type = param->GetBuilderParameterType();
608 if (param_type.empty()) {
609 continue;
610 }
611 auto param_name = param->GetName();
612 if (param->BuilderParameterMustBeMoved()) {
613 ss << "std::move(" << param_name << "_move_only)";
614 } else {
615 ss << param_name;
616 }
617 builder_vars.push_back(ss.str());
618 }
619 s << util::StringJoin(",", builder_vars) << ");}";
620 s << "))";
621 }
622
GenBuilderParameterChecker(std::ostream & s) const623 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
624 FieldList params_to_validate = GetParametersToValidate();
625
626 // Skip writing this function if there is nothing to validate.
627 if (params_to_validate.size() == 0) {
628 return;
629 }
630
631 // Generate function arguments.
632 s << "void CheckParameterValues(";
633 for (int i = 0; i < params_to_validate.size(); i++) {
634 params_to_validate[i]->GenBuilderParameter(s);
635 if (i != params_to_validate.size() - 1) {
636 s << ", ";
637 }
638 }
639 s << ") {";
640
641 // Check the parameters.
642 for (const auto& field : params_to_validate) {
643 field->GenParameterValidator(s);
644 }
645 s << "}\n";
646 }
647
GenBuilderConstructor(std::ostream & s) const648 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
649 s << "explicit " << name_ << "Builder(";
650
651 // Generate the constructor parameters.
652 auto params = GetParamList().GetFieldsWithoutTypes({
653 PayloadField::kFieldType,
654 BodyField::kFieldType,
655 });
656 for (int i = 0; i < params.size(); i++) {
657 params[i]->GenBuilderParameter(s);
658 if (i != params.size() - 1) {
659 s << ", ";
660 }
661 }
662 if (params.size() > 0 || parent_constraints_.size() > 0) {
663 s << ") :";
664 } else {
665 s << ")";
666 }
667
668 // Get the list of parent params to call the parent constructor with.
669 FieldList parent_params;
670 if (parent_ != nullptr) {
671 // Pass parameters to the parent constructor
672 s << parent_->name_ << "Builder(";
673 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
674 PayloadField::kFieldType,
675 BodyField::kFieldType,
676 });
677
678 // Go through all the fields and replace constrained fields with fixed values
679 // when calling the parent constructor.
680 for (int i = 0; i < parent_params.size(); i++) {
681 const auto& field = parent_params[i];
682 const auto& constraint = parent_constraints_.find(field->GetName());
683 if (constraint != parent_constraints_.end()) {
684 if (field->GetFieldType() == ScalarField::kFieldType) {
685 s << std::get<int64_t>(constraint->second);
686 } else if (field->GetFieldType() == EnumField::kFieldType) {
687 s << std::get<std::string>(constraint->second);
688 } else {
689 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
690 }
691
692 s << "/* " << field->GetName() << "_ */";
693 } else {
694 s << field->GetName();
695 }
696
697 if (i != parent_params.size() - 1) {
698 s << ", ";
699 }
700 }
701 s << ") ";
702 }
703
704 // Build a list of parameters that excludes all parent parameters.
705 FieldList saved_params;
706 for (const auto& field : params) {
707 if (parent_params.GetField(field->GetName()) == nullptr) {
708 saved_params.AppendField(field);
709 }
710 }
711 if (parent_ != nullptr && saved_params.size() > 0) {
712 s << ",";
713 }
714 for (int i = 0; i < saved_params.size(); i++) {
715 const auto& saved_param_name = saved_params[i]->GetName();
716 if (saved_params[i]->BuilderParameterMustBeMoved()) {
717 s << saved_param_name << "_(std::move(" << saved_param_name << "))";
718 } else {
719 s << saved_param_name << "_(" << saved_param_name << ")";
720 }
721 if (i != saved_params.size() - 1) {
722 s << ",";
723 }
724 }
725 s << " {";
726
727 FieldList params_to_validate = GetParametersToValidate();
728
729 if (params_to_validate.size() > 0) {
730 s << "CheckParameterValues(";
731 for (int i = 0; i < params_to_validate.size(); i++) {
732 s << params_to_validate[i]->GetName() << "_";
733 if (i != params_to_validate.size() - 1) {
734 s << ", ";
735 }
736 }
737 s << ");";
738 }
739
740 s << "}\n";
741 }
742