1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "packet_def.h"
18 
19 #include <list>
20 #include <set>
21 
22 #include "fields/all_fields.h"
23 #include "util.h"
24 
PacketDef(std::string name,FieldList fields)25 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)26 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
27 
GetNewField(const std::string &,ParseLocation) const28 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
29   return nullptr;  // Packets can't be fields
30 }
31 
GenParserDefinition(std::ostream & s) const32 void PacketDef::GenParserDefinition(std::ostream& s) const {
33   s << "class " << name_ << "View";
34   if (parent_ != nullptr) {
35     s << " : public " << parent_->name_ << "View {";
36   } else {
37     s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
38   }
39   s << " public:";
40 
41   // Specialize function
42   if (parent_ != nullptr) {
43     s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
44     s << "{ return " << name_ << "View(std::move(parent)); }";
45   } else {
46     s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
47     s << "{ return " << name_ << "View(std::move(packet)); }";
48   }
49 
50   GenTestingParserFromBytes(s);
51 
52   std::set<std::string> fixed_types = {
53       FixedScalarField::kFieldType,
54       FixedEnumField::kFieldType,
55   };
56 
57   // Print all of the public fields which are all the fields minus the fixed fields.
58   const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
59   bool has_fixed_fields = public_fields.size() != fields_.size();
60   for (const auto& field : public_fields) {
61     GenParserFieldGetter(s, field);
62     s << "\n";
63   }
64   GenValidator(s);
65   s << "\n";
66 
67   s << " public:";
68   GenParserToString(s);
69   s << "\n";
70 
71   s << " protected:\n";
72   // Constructor from a View
73   if (parent_ != nullptr) {
74     s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
75     s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
76   } else {
77     s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
78     s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
79   }
80 
81   // Print the private fields which are the fixed fields.
82   if (has_fixed_fields) {
83     const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
84     s << " private:\n";
85     for (const auto& field : private_fields) {
86       GenParserFieldGetter(s, field);
87       s << "\n";
88     }
89   }
90   s << "};\n";
91 }
92 
GenTestingParserFromBytes(std::ostream & s) const93 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
94   s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
95 
96   s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
97   s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
98   s << "return " << name_ << "View::Create(";
99   auto ancestor_ptr = parent_;
100   size_t parent_parens = 0;
101   while (ancestor_ptr != nullptr) {
102     s << ancestor_ptr->name_ << "View::Create(";
103     parent_parens++;
104     ancestor_ptr = ancestor_ptr->parent_;
105   }
106   s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
107   for (size_t i = 0; i < parent_parens; i++) {
108     s << ")";
109   }
110   s << ");";
111   s << "}";
112 
113   s << "\n#endif\n";
114 }
115 
GenParserDefinitionPybind11(std::ostream & s) const116 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
117   s << "py::class_<" << name_ << "View";
118   if (parent_ != nullptr) {
119     s << ", " << parent_->name_ << "View";
120   } else {
121     s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
122   }
123   s << ">(m, \"" << name_ << "View\")";
124   if (parent_ != nullptr) {
125     s << ".def(py::init([](" << parent_->name_ << "View parent) {";
126   } else {
127     s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
128   }
129   s << "auto view =" << name_ << "View::Create(std::move(parent));";
130   s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
131   s << "return view; }))";
132 
133   s << ".def(py::init(&" << name_ << "View::Create))";
134   std::set<std::string> protected_field_types = {
135       FixedScalarField::kFieldType,
136       FixedEnumField::kFieldType,
137       SizeField::kFieldType,
138       CountField::kFieldType,
139   };
140   const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
141   for (const auto& field : public_fields) {
142     auto getter_func_name = field->GetGetterFunctionName();
143     if (getter_func_name.empty()) {
144       continue;
145     }
146     s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
147   }
148   s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
149   s << ";\n";
150 }
151 
GenParserFieldGetter(std::ostream & s,const PacketField * field) const152 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
153   // Start field offset
154   auto start_field_offset = GetOffsetForField(field->GetName(), false);
155   auto end_field_offset = GetOffsetForField(field->GetName(), true);
156 
157   if (start_field_offset.empty() && end_field_offset.empty()) {
158     ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
159                  << "no method exists to determine field location from begin() or end().\n";
160   }
161 
162   field->GenGetter(s, start_field_offset, end_field_offset);
163 }
164 
GetDefinitionType() const165 TypeDef::Type PacketDef::GetDefinitionType() const {
166   return TypeDef::Type::PACKET;
167 }
168 
GenValidator(std::ostream & s) const169 void PacketDef::GenValidator(std::ostream& s) const {
170   // Get the static offset for all of our fields.
171   int bits_size = 0;
172   for (const auto& field : fields_) {
173     if (field->GetFieldType() != PaddingField::kFieldType) {
174       bits_size += field->GetSize().bits();
175     }
176   }
177 
178   // Write the function declaration.
179   s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
180   s << "if (was_validated_) { return true; } ";
181   s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
182   s << "}";
183 
184   s << "protected:";
185   s << "virtual bool IsValid_() const {";
186 
187   // Offset by the parents known size. We know that any dynamic fields can
188   // already be called since the parent must have already been validated by
189   // this point.
190   auto parent_size = Size(0);
191   if (parent_ != nullptr) {
192     parent_size = parent_->GetSize(true);
193   }
194 
195   s << "auto it = begin() + (" << parent_size << ") / 8;";
196 
197   // Check if you can extract the static fields.
198   // At this point you know you can use the size getters without crashing
199   // as long as they follow the instruction that size fields cant come before
200   // their corrisponding variable length field.
201   s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
202   s << "if (it > end()) return false;";
203 
204   // For any variable length fields, use their size check.
205   for (const auto& field : fields_) {
206     if (field->GetFieldType() == ChecksumStartField::kFieldType) {
207       auto offset = GetOffsetForField(field->GetName(), false);
208       if (!offset.empty()) {
209         s << "size_t sum_index = (" << offset << ") / 8;";
210       } else {
211         offset = GetOffsetForField(field->GetName(), true);
212         if (offset.empty()) {
213           ERROR(field) << "Checksum Start Field offset can not be determined.";
214         }
215         s << "size_t sum_index = size() - (" << offset << ") / 8;";
216       }
217 
218       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
219       const auto& started_field = fields_.GetField(field_name);
220       if (started_field == nullptr) {
221         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
222                      << ")";
223       }
224       auto end_offset = GetOffsetForField(started_field->GetName(), false);
225       if (!end_offset.empty()) {
226         s << "size_t end_sum_index = (" << end_offset << ") / 8;";
227       } else {
228         end_offset = GetOffsetForField(started_field->GetName(), true);
229         if (end_offset.empty()) {
230           ERROR(started_field) << "Checksum Field end_offset can not be determined.";
231         }
232         s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
233       }
234       if (is_little_endian_) {
235         s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
236       } else {
237         s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
238       }
239       s << started_field->GetDataType() << " checksum;";
240       s << "checksum.Initialize();";
241       s << "for (uint8_t byte : checksum_view) { ";
242       s << "checksum.AddByte(byte);}";
243       s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
244         << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
245 
246       continue;
247     }
248 
249     auto field_size = field->GetSize();
250     // Fixed size fields have already been handled.
251     if (!field_size.has_dynamic()) {
252       continue;
253     }
254 
255     // Custom fields with dynamic size must have the offset for the field passed in as well
256     // as the end iterator so that they may ensure that they don't try to read past the end.
257     // Custom fields with fixed sizes will be handled in the static offset checking.
258     if (field->GetFieldType() == CustomField::kFieldType) {
259       // Check if we can determine offset from begin(), otherwise error because by this point,
260       // the size of the custom field is unknown and can't be subtracted from end() to get the
261       // offset.
262       auto offset = GetOffsetForField(field->GetName(), false);
263       if (offset.empty()) {
264         ERROR(field) << "Custom Field offset can not be determined from begin().";
265       }
266 
267       if (offset.bits() % 8 != 0) {
268         ERROR(field) << "Custom fields must be byte aligned.";
269       }
270 
271       // Custom fields are special as their size field takes an argument.
272       const auto& custom_size_var = field->GetName() + "_size";
273       s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
274       s << "(begin() + (" << offset << ") / 8);";
275 
276       s << "if (!" << custom_size_var << ".has_value()) { return false; }";
277       s << "it += *" << custom_size_var << ";";
278       s << "if (it > end()) return false;";
279       continue;
280     } else {
281       s << "it += (" << field_size.dynamic_string() << ") / 8;";
282       s << "if (it > end()) return false;";
283     }
284   }
285 
286   // Validate constraints after validating the size
287   if (parent_constraints_.size() > 0 && parent_ == nullptr) {
288     ERROR() << "Can't have a constraint on a NULL parent";
289   }
290 
291   for (const auto& constraint : parent_constraints_) {
292     s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
293     const auto& field = parent_->GetParamList().GetField(constraint.first);
294     if (field->GetFieldType() == ScalarField::kFieldType) {
295       s << std::get<int64_t>(constraint.second);
296     } else {
297       s << std::get<std::string>(constraint.second);
298     }
299     s << ") return false;";
300   }
301 
302   // Validate the packets fields last
303   for (const auto& field : fields_) {
304     field->GenValidator(s);
305     s << "\n";
306   }
307 
308   s << "return true;";
309   s << "}\n";
310   if (parent_ == nullptr) {
311     s << "bool was_validated_{false};\n";
312   }
313 }
314 
GenParserToString(std::ostream & s) const315 void PacketDef::GenParserToString(std::ostream& s) const {
316   s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
317   s << "std::stringstream ss;";
318   s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
319 
320   if (fields_.size() > 0) {
321     s << "ss << \"\" ";
322     bool firstfield = true;
323     for (const auto& field : fields_) {
324       if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
325           field->GetFieldType() == ChecksumStartField::kFieldType)
326         continue;
327 
328       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
329 
330       field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
331 
332       if (firstfield) {
333         firstfield = false;
334       }
335     }
336     s << ";";
337   }
338 
339   s << "ss << \" }\";";
340   s << "return ss.str();";
341   s << "}\n";
342 }
343 
GenBuilderDefinition(std::ostream & s) const344 void PacketDef::GenBuilderDefinition(std::ostream& s) const {
345   s << "class " << name_ << "Builder";
346   if (parent_ != nullptr) {
347     s << " : public " << parent_->name_ << "Builder";
348   } else {
349     if (is_little_endian_) {
350       s << " : public PacketBuilder<kLittleEndian>";
351     } else {
352       s << " : public PacketBuilder<!kLittleEndian>";
353     }
354   }
355   s << " {";
356   s << " public:";
357   s << "  virtual ~" << name_ << "Builder()" << (parent_ != nullptr ? " override" : "") << " = default;";
358 
359   if (!fields_.HasBody()) {
360     GenBuilderCreate(s);
361     s << "\n";
362 
363     GenTestingFromView(s);
364     s << "\n";
365   }
366 
367   GenSerialize(s);
368   s << "\n";
369 
370   GenSize(s);
371   s << "\n";
372 
373   s << " protected:\n";
374   GenBuilderConstructor(s);
375   s << "\n";
376 
377   GenBuilderParameterChecker(s);
378   s << "\n";
379 
380   GenMembers(s);
381   s << "};\n";
382 
383   GenTestDefine(s);
384   s << "\n";
385 
386   GenFuzzTestDefine(s);
387   s << "\n";
388 }
389 
GenTestingFromView(std::ostream & s) const390 void PacketDef::GenTestingFromView(std::ostream& s) const {
391   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
392 
393   s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
394   s << "return " << name_ << "Builder::Create(";
395   FieldList params = GetParamList().GetFieldsWithoutTypes({
396       BodyField::kFieldType,
397   });
398   for (int i = 0; i < params.size(); i++) {
399     params[i]->GenBuilderParameterFromView(s);
400     if (i != params.size() - 1) {
401       s << ", ";
402     }
403   }
404   s << ");";
405   s << "}";
406 
407   s << "\n#endif\n";
408 }
409 
GenBuilderDefinitionPybind11(std::ostream & s) const410 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
411   s << "py::class_<" << name_ << "Builder";
412   if (parent_ != nullptr) {
413     s << ", " << parent_->name_ << "Builder";
414   } else {
415     if (is_little_endian_) {
416       s << ", PacketBuilder<kLittleEndian>";
417     } else {
418       s << ", PacketBuilder<!kLittleEndian>";
419     }
420   }
421   s << ", std::shared_ptr<" << name_ << "Builder>";
422   s << ">(m, \"" << name_ << "Builder\")";
423   if (!fields_.HasBody()) {
424     GenBuilderCreatePybind11(s);
425   }
426   s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
427   s << "std::vector<uint8_t> bytes;";
428   s << "BitInserter bi(bytes);";
429   s << "builder.Serialize(bi);";
430   s << "return bytes;})";
431   s << ";\n";
432 }
433 
GenTestDefine(std::ostream & s) const434 void PacketDef::GenTestDefine(std::ostream& s) const {
435   s << "#ifdef PACKET_TESTING\n";
436   s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
437   s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
438   s << "public: ";
439   s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
440   s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
441   s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
442   s << "for (size_t i = 0; i < view.size(); i++) { LOG_DEBUG(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
443   s << "ASSERT_TRUE(view.IsValid());";
444   s << "auto packet = " << name_ << "Builder::FromView(view);";
445   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
446   s << "packet_bytes->reserve(packet->size());";
447   s << "BitInserter it(*packet_bytes);";
448   s << "packet->Serialize(it);";
449   s << "ASSERT_EQ(*packet_bytes, captured_packet);";
450   s << "}";
451   s << "};";
452   s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
453   s << "CompareBytes(GetParam());";
454   s << "}";
455   s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
456   s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
457   s << "\n#endif";
458 }
459 
GenFuzzTestDefine(std::ostream & s) const460 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
461   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
462   s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
463   s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
464   s << "auto vec = std::vector<uint8_t>(data, data + size);";
465   s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
466   s << "if (!view.IsValid()) { return; }";
467   s << "auto packet = " << name_ << "Builder::FromView(view);";
468   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
469   s << "packet_bytes->reserve(packet->size());";
470   s << "BitInserter it(*packet_bytes);";
471   s << "packet->Serialize(it);";
472   s << "}";
473   s << "\n#endif\n";
474   s << "#ifdef PACKET_FUZZ_TESTING\n";
475   s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
476   s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
477   s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
478   s << "public: ";
479   s << "explicit " << name_
480     << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
481   s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
482   s << "}}; ";
483   s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
484   s << "\n#endif";
485 }
486 
GetParametersToValidate() const487 FieldList PacketDef::GetParametersToValidate() const {
488   FieldList params_to_validate;
489   for (const auto& field : GetParamList()) {
490     if (field->HasParameterValidator()) {
491       params_to_validate.AppendField(field);
492     }
493   }
494   return params_to_validate;
495 }
496 
GenBuilderCreate(std::ostream & s) const497 void PacketDef::GenBuilderCreate(std::ostream& s) const {
498   s << "static std::unique_ptr<" << name_ << "Builder> Create(";
499 
500   auto params = GetParamList();
501   for (int i = 0; i < params.size(); i++) {
502     params[i]->GenBuilderParameter(s);
503     if (i != params.size() - 1) {
504       s << ", ";
505     }
506   }
507   s << ") {";
508 
509   // Call the constructor
510   s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
511 
512   params = params.GetFieldsWithoutTypes({
513       PayloadField::kFieldType,
514       BodyField::kFieldType,
515   });
516   // Add the parameters.
517   for (int i = 0; i < params.size(); i++) {
518     if (params[i]->BuilderParameterMustBeMoved()) {
519       s << "std::move(" << params[i]->GetName() << ")";
520     } else {
521       s << params[i]->GetName();
522     }
523     if (i != params.size() - 1) {
524       s << ", ";
525     }
526   }
527 
528   s << "));";
529   if (fields_.HasPayload()) {
530     s << "builder->payload_ = std::move(payload);";
531   }
532   s << "return builder;";
533   s << "}\n";
534 }
535 
GenBuilderCreatePybind11(std::ostream & s) const536 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
537   s << ".def(py::init([](";
538   auto params = GetParamList();
539   std::vector<std::string> constructor_args;
540   int i = 1;
541   for (const auto& param : params) {
542     i++;
543     std::stringstream ss;
544     auto param_type = param->GetBuilderParameterType();
545     if (param_type.empty()) {
546       continue;
547     }
548     // Use shared_ptr instead of unique_ptr for the Python interface
549     if (param->BuilderParameterMustBeMoved()) {
550       param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
551     }
552     ss << param_type << " " << param->GetName();
553     constructor_args.push_back(ss.str());
554   }
555   s << util::StringJoin(",", constructor_args) << "){";
556 
557   // Deal with move only args
558   for (const auto& param : params) {
559     std::stringstream ss;
560     auto param_type = param->GetBuilderParameterType();
561     if (param_type.empty()) {
562       continue;
563     }
564     if (!param->BuilderParameterMustBeMoved()) {
565       continue;
566     }
567     auto move_only_param_name = param->GetName() + "_move_only";
568     s << param_type << " " << move_only_param_name << ";";
569     if (param->IsContainerField()) {
570       // Assume single layer container and copy it
571       auto struct_type = param->GetElementField()->GetDataType();
572       struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
573       struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
574       s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
575       // Serialize each struct
576       s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
577       s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
578       s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
579       s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
580       // Parse it again
581       s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
582       s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
583       s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
584       // Push it into a new container
585       if (param->GetFieldType() == VectorField::kFieldType) {
586         s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
587       } else if (param->GetFieldType() == ArrayField::kFieldType) {
588         s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
589       } else {
590         ERROR() << param << " is not supported by Pybind11";
591       }
592       s << "}";
593     } else {
594       // Serialize the parameter and pass the bytes in a RawBuilder
595       s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
596       s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
597       s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
598       s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
599       s << move_only_param_name << " = ";
600       s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
601     }
602   }
603   s << "return " << name_ << "Builder::Create(";
604   std::vector<std::string> builder_vars;
605   for (const auto& param : params) {
606     std::stringstream ss;
607     auto param_type = param->GetBuilderParameterType();
608     if (param_type.empty()) {
609       continue;
610     }
611     auto param_name = param->GetName();
612     if (param->BuilderParameterMustBeMoved()) {
613       ss << "std::move(" << param_name << "_move_only)";
614     } else {
615       ss << param_name;
616     }
617     builder_vars.push_back(ss.str());
618   }
619   s << util::StringJoin(",", builder_vars) << ");}";
620   s << "))";
621 }
622 
GenBuilderParameterChecker(std::ostream & s) const623 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
624   FieldList params_to_validate = GetParametersToValidate();
625 
626   // Skip writing this function if there is nothing to validate.
627   if (params_to_validate.size() == 0) {
628     return;
629   }
630 
631   // Generate function arguments.
632   s << "void CheckParameterValues(";
633   for (int i = 0; i < params_to_validate.size(); i++) {
634     params_to_validate[i]->GenBuilderParameter(s);
635     if (i != params_to_validate.size() - 1) {
636       s << ", ";
637     }
638   }
639   s << ") {";
640 
641   // Check the parameters.
642   for (const auto& field : params_to_validate) {
643     field->GenParameterValidator(s);
644   }
645   s << "}\n";
646 }
647 
GenBuilderConstructor(std::ostream & s) const648 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
649   s << "explicit " << name_ << "Builder(";
650 
651   // Generate the constructor parameters.
652   auto params = GetParamList().GetFieldsWithoutTypes({
653       PayloadField::kFieldType,
654       BodyField::kFieldType,
655   });
656   for (int i = 0; i < params.size(); i++) {
657     params[i]->GenBuilderParameter(s);
658     if (i != params.size() - 1) {
659       s << ", ";
660     }
661   }
662   if (params.size() > 0 || parent_constraints_.size() > 0) {
663     s << ") :";
664   } else {
665     s << ")";
666   }
667 
668   // Get the list of parent params to call the parent constructor with.
669   FieldList parent_params;
670   if (parent_ != nullptr) {
671     // Pass parameters to the parent constructor
672     s << parent_->name_ << "Builder(";
673     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
674         PayloadField::kFieldType,
675         BodyField::kFieldType,
676     });
677 
678     // Go through all the fields and replace constrained fields with fixed values
679     // when calling the parent constructor.
680     for (int i = 0; i < parent_params.size(); i++) {
681       const auto& field = parent_params[i];
682       const auto& constraint = parent_constraints_.find(field->GetName());
683       if (constraint != parent_constraints_.end()) {
684         if (field->GetFieldType() == ScalarField::kFieldType) {
685           s << std::get<int64_t>(constraint->second);
686         } else if (field->GetFieldType() == EnumField::kFieldType) {
687           s << std::get<std::string>(constraint->second);
688         } else {
689           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
690         }
691 
692         s << "/* " << field->GetName() << "_ */";
693       } else {
694         s << field->GetName();
695       }
696 
697       if (i != parent_params.size() - 1) {
698         s << ", ";
699       }
700     }
701     s << ") ";
702   }
703 
704   // Build a list of parameters that excludes all parent parameters.
705   FieldList saved_params;
706   for (const auto& field : params) {
707     if (parent_params.GetField(field->GetName()) == nullptr) {
708       saved_params.AppendField(field);
709     }
710   }
711   if (parent_ != nullptr && saved_params.size() > 0) {
712     s << ",";
713   }
714   for (int i = 0; i < saved_params.size(); i++) {
715     const auto& saved_param_name = saved_params[i]->GetName();
716     if (saved_params[i]->BuilderParameterMustBeMoved()) {
717       s << saved_param_name << "_(std::move(" << saved_param_name << "))";
718     } else {
719       s << saved_param_name << "_(" << saved_param_name << ")";
720     }
721     if (i != saved_params.size() - 1) {
722       s << ",";
723     }
724   }
725   s << " {";
726 
727   FieldList params_to_validate = GetParametersToValidate();
728 
729   if (params_to_validate.size() > 0) {
730     s << "CheckParameterValues(";
731     for (int i = 0; i < params_to_validate.size(); i++) {
732       s << params_to_validate[i]->GetName() << "_";
733       if (i != params_to_validate.size() - 1) {
734         s << ", ";
735       }
736     }
737     s << ");";
738   }
739 
740   s << "}\n";
741 }
742