1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "parent_def.h"
18 
19 #include "fields/all_fields.h"
20 #include "util.h"
21 
ParentDef(std::string name,FieldList fields)22 ParentDef::ParentDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
ParentDef(std::string name,FieldList fields,ParentDef * parent)23 ParentDef::ParentDef(std::string name, FieldList fields, ParentDef* parent)
24     : TypeDef(name), fields_(fields), parent_(parent) {}
25 
AddParentConstraint(std::string field_name,std::variant<int64_t,std::string> value)26 void ParentDef::AddParentConstraint(std::string field_name, std::variant<int64_t, std::string> value) {
27   // NOTE: This could end up being very slow if there are a lot of constraints.
28   const auto& parent_params = parent_->GetParamList();
29   const auto& constrained_field = parent_params.GetField(field_name);
30   if (constrained_field == nullptr) {
31     ERROR() << "Attempting to constrain field " << field_name << " in parent " << parent_->name_
32             << ", but no such field exists.";
33   }
34 
35   if (constrained_field->GetFieldType() == ScalarField::kFieldType) {
36     if (!std::holds_alternative<int64_t>(value)) {
37       ERROR(constrained_field) << "Attempting to constrain a scalar field to an enum value in " << parent_->name_;
38     }
39   } else if (constrained_field->GetFieldType() == EnumField::kFieldType) {
40     if (!std::holds_alternative<std::string>(value)) {
41       ERROR(constrained_field) << "Attempting to constrain an enum field to a scalar value in " << parent_->name_;
42     }
43     const auto& enum_def = static_cast<EnumField*>(constrained_field)->GetEnumDef();
44     if (!enum_def.HasEntry(std::get<std::string>(value))) {
45       ERROR(constrained_field) << "No matching enumeration \"" << std::get<std::string>(value)
46                                << "\" for constraint on enum in parent " << parent_->name_ << ".";
47     }
48 
49     // For enums, we have to qualify the value using the enum type name.
50     value = enum_def.GetTypeName() + "::" + std::get<std::string>(value);
51   } else {
52     ERROR(constrained_field) << "Field in parent " << parent_->name_ << " is not viable for constraining.";
53   }
54 
55   parent_constraints_.insert(std::pair(field_name, value));
56 }
57 
58 // Assign all size fields to their corresponding variable length fields.
59 // Will crash if
60 //  - there aren't any fields that don't match up to a field.
61 //  - the size field points to a fixed size field.
62 //  - if the size field comes after the variable length field.
AssignSizeFields()63 void ParentDef::AssignSizeFields() {
64   for (const auto& field : fields_) {
65     DEBUG() << "field name: " << field->GetName();
66 
67     if (field->GetFieldType() != SizeField::kFieldType && field->GetFieldType() != CountField::kFieldType) {
68       continue;
69     }
70 
71     const SizeField* size_field = static_cast<SizeField*>(field);
72     // Check to see if a corresponding field can be found.
73     const auto& var_len_field = fields_.GetField(size_field->GetSizedFieldName());
74     if (var_len_field == nullptr) {
75       ERROR(field) << "Could not find corresponding field for size/count field.";
76     }
77 
78     // Do the ordering check to ensure the size field comes before the
79     // variable length field.
80     for (auto it = fields_.begin(); *it != size_field; it++) {
81       DEBUG() << "field name: " << (*it)->GetName();
82       if (*it == var_len_field) {
83         ERROR(var_len_field, size_field) << "Size/count field must come before the variable length field it describes.";
84       }
85     }
86 
87     if (var_len_field->GetFieldType() == PayloadField::kFieldType) {
88       const auto& payload_field = static_cast<PayloadField*>(var_len_field);
89       payload_field->SetSizeField(size_field);
90       continue;
91     }
92 
93     if (var_len_field->GetFieldType() == BodyField::kFieldType) {
94       const auto& body_field = static_cast<BodyField*>(var_len_field);
95       body_field->SetSizeField(size_field);
96       continue;
97     }
98 
99     if (var_len_field->GetFieldType() == VectorField::kFieldType) {
100       const auto& vector_field = static_cast<VectorField*>(var_len_field);
101       vector_field->SetSizeField(size_field);
102       continue;
103     }
104 
105     // If we've reached this point then the field wasn't a variable length field.
106     // Check to see if the field is a variable length field
107     ERROR(field, size_field) << "Can not use size/count in reference to a fixed size field.\n";
108   }
109 }
110 
SetEndianness(bool is_little_endian)111 void ParentDef::SetEndianness(bool is_little_endian) {
112   is_little_endian_ = is_little_endian;
113 }
114 
115 // Get the size. You scan specify without_payload in order to exclude payload fields as children will be overriding it.
GetSize(bool without_payload) const116 Size ParentDef::GetSize(bool without_payload) const {
117   auto size = Size(0);
118 
119   for (const auto& field : fields_) {
120     if (without_payload &&
121         (field->GetFieldType() == PayloadField::kFieldType || field->GetFieldType() == BodyField::kFieldType)) {
122       continue;
123     }
124 
125     // The offset to the field must be passed in as an argument for dynamically sized custom fields.
126     if (field->GetFieldType() == CustomField::kFieldType && field->GetSize().has_dynamic()) {
127       std::stringstream custom_field_size;
128 
129       // Custom fields are special as their size field takes an argument.
130       custom_field_size << field->GetSize().dynamic_string() << "(begin()";
131 
132       // Check if we can determine offset from begin(), otherwise error because by this point,
133       // the size of the custom field is unknown and can't be subtracted from end() to get the
134       // offset.
135       auto offset = GetOffsetForField(field->GetName(), false);
136       if (offset.empty()) {
137         ERROR(field) << "Custom Field offset can not be determined from begin().";
138       }
139 
140       if (offset.bits() % 8 != 0) {
141         ERROR(field) << "Custom fields must be byte aligned.";
142       }
143       if (offset.has_bits()) custom_field_size << " + " << offset.bits() / 8;
144       if (offset.has_dynamic()) custom_field_size << " + " << offset.dynamic_string();
145       custom_field_size << ")";
146 
147       size += custom_field_size.str();
148       continue;
149     }
150 
151     size += field->GetSize();
152   }
153 
154   if (parent_ != nullptr) {
155     size += parent_->GetSize(true);
156   }
157 
158   return size;
159 }
160 
161 // Get the offset until the field is reached, if there is no field
162 // returns an empty Size. from_end requests the offset to the field
163 // starting from the end() iterator. If there is a field with an unknown
164 // size along the traversal, then an empty size is returned.
GetOffsetForField(std::string field_name,bool from_end) const165 Size ParentDef::GetOffsetForField(std::string field_name, bool from_end) const {
166   // Check first if the field exists.
167   if (fields_.GetField(field_name) == nullptr) {
168     ERROR() << "Can't find a field offset for nonexistent field named: " << field_name << " in " << name_;
169   }
170 
171   // We have to use a generic lambda to conditionally change iteration direction
172   // due to iterator and reverse_iterator being different types.
173   auto size_lambda = [&](auto from, auto to) -> Size {
174     auto size = Size(0);
175     for (auto it = from; it != to; it++) {
176       // We've reached the field, end the loop.
177       if ((*it)->GetName() == field_name) break;
178       const auto& field = *it;
179       // If there is a field with an unknown size before the field, return an empty Size.
180       if (field->GetSize().empty()) {
181         return Size();
182       }
183       if (field->GetFieldType() != PaddingField::kFieldType || !from_end) {
184         size += field->GetSize();
185       }
186     }
187     return size;
188   };
189 
190   // Change iteration direction based on from_end.
191   auto size = Size();
192   if (from_end)
193     size = size_lambda(fields_.rbegin(), fields_.rend());
194   else
195     size = size_lambda(fields_.begin(), fields_.end());
196   if (size.empty()) return size;
197 
198   // We need the offset until a payload or body field.
199   if (parent_ != nullptr) {
200     if (parent_->fields_.HasPayload()) {
201       auto parent_payload_offset = parent_->GetOffsetForField("payload", from_end);
202       if (parent_payload_offset.empty()) {
203         ERROR() << "Empty offset for payload in " << parent_->name_ << " finding the offset for field: " << field_name;
204       }
205       size += parent_payload_offset;
206     } else {
207       auto parent_body_offset = parent_->GetOffsetForField("body", from_end);
208       if (parent_body_offset.empty()) {
209         ERROR() << "Empty offset for body in " << parent_->name_ << " finding the offset for field: " << field_name;
210       }
211       size += parent_body_offset;
212     }
213   }
214 
215   return size;
216 }
217 
GetParamList() const218 FieldList ParentDef::GetParamList() const {
219   FieldList params;
220 
221   std::set<std::string> param_types = {
222       ScalarField::kFieldType,
223       EnumField::kFieldType,
224       ArrayField::kFieldType,
225       VectorField::kFieldType,
226       CustomField::kFieldType,
227       StructField::kFieldType,
228       VariableLengthStructField::kFieldType,
229       PayloadField::kFieldType,
230   };
231 
232   if (parent_ != nullptr) {
233     auto parent_params = parent_->GetParamList().GetFieldsWithTypes(param_types);
234 
235     // Do not include constrained fields in the params
236     for (const auto& field : parent_params) {
237       if (parent_constraints_.find(field->GetName()) == parent_constraints_.end()) {
238         params.AppendField(field);
239       }
240     }
241   }
242   // Add our parameters.
243   return params.Merge(fields_.GetFieldsWithTypes(param_types));
244 }
245 
GenMembers(std::ostream & s) const246 void ParentDef::GenMembers(std::ostream& s) const {
247   // Add the parameter list.
248   for (int i = 0; i < fields_.size(); i++) {
249     if (fields_[i]->GenBuilderMember(s)) {
250       s << "_{};";
251     }
252   }
253 }
254 
GenSize(std::ostream & s) const255 void ParentDef::GenSize(std::ostream& s) const {
256   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
257   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
258 
259   s << "protected:";
260   s << "size_t BitsOfHeader() const {";
261   s << "return 0";
262 
263   if (parent_ != nullptr) {
264     if (parent_->GetDefinitionType() == Type::PACKET) {
265       s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
266     } else {
267       s << " + " << parent_->name_ << "::BitsOfHeader() ";
268     }
269   }
270 
271   for (const auto& field : header_fields) {
272     s << " + " << field->GetBuilderSize();
273   }
274   s << ";";
275 
276   s << "}\n\n";
277 
278   s << "size_t BitsOfFooter() const {";
279   s << "return 0";
280   for (const auto& field : footer_fields) {
281     s << " + " << field->GetBuilderSize();
282   }
283 
284   if (parent_ != nullptr) {
285     if (parent_->GetDefinitionType() == Type::PACKET) {
286       s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
287     } else {
288       s << " + " << parent_->name_ << "::BitsOfFooter() ";
289     }
290   }
291   s << ";";
292   s << "}\n\n";
293 
294   if (fields_.HasPayload()) {
295     s << "size_t GetPayloadSize() const {";
296     s << "if (payload_ != nullptr) {return payload_->size();}";
297     s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
298     s << ";}\n\n";
299   }
300 
301   Size padded_size;
302   for (const auto& field : header_fields) {
303     if (field->GetFieldType() == PaddingField::kFieldType) {
304       if (!padded_size.empty()) {
305         ERROR() << "Only one padding field is allowed.  Second field: " << field->GetName();
306       }
307       padded_size = field->GetSize();
308     }
309   }
310 
311   s << "public:";
312   s << "virtual size_t size() const override {";
313   if (!padded_size.empty()) {
314     s << "return " << padded_size.bytes() << ";}";
315     s << "size_t unpadded_size() const {";
316   }
317   s << "return (BitsOfHeader() / 8)";
318   if (fields_.HasPayload()) {
319     s << "+ payload_->size()";
320   }
321   if (fields_.HasBody()) {
322     for (const auto& field : header_fields) {
323       if (field->GetFieldType() == SizeField::kFieldType) {
324         const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
325         if (field_name == "body") {
326           s << "+ body_size_extracted_";
327         }
328       }
329     }
330   }
331   s << " + (BitsOfFooter() / 8);";
332   s << "}\n";
333 }
334 
GenSerialize(std::ostream & s) const335 void ParentDef::GenSerialize(std::ostream& s) const {
336   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
337   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
338 
339   s << "protected:";
340   s << "void SerializeHeader(BitInserter&";
341   if (parent_ != nullptr || header_fields.size() != 0) {
342     s << " i ";
343   }
344   s << ") const {";
345 
346   if (parent_ != nullptr) {
347     if (parent_->GetDefinitionType() == Type::PACKET) {
348       s << parent_->name_ << "Builder::SerializeHeader(i);";
349     } else {
350       s << parent_->name_ << "::SerializeHeader(i);";
351     }
352   }
353 
354   for (const auto& field : header_fields) {
355     if (field->GetFieldType() == SizeField::kFieldType) {
356       const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
357       const auto& sized_field = fields_.GetField(field_name);
358       if (sized_field == nullptr) {
359         ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
360       }
361       if (sized_field->GetFieldType() == PayloadField::kFieldType) {
362         s << "size_t payload_bytes = GetPayloadSize();";
363         std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
364         if (modifier != "") {
365           s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
366           s << "payload_bytes = payload_bytes + (" << modifier << ") / 8;";
367         }
368         s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
369         s << "insert(static_cast<" << field->GetDataType() << ">(payload_bytes), i," << field->GetSize().bits() << ");";
370       } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
371         s << field->GetName() << "_extracted_ = 0;";
372         s << "size_t local_size = " << name_ << "::size();";
373 
374         s << "ASSERT((size() - local_size) < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
375         s << "insert(static_cast<" << field->GetDataType() << ">(size() - local_size), i," << field->GetSize().bits()
376           << ");";
377       } else {
378         if (sized_field->GetFieldType() != VectorField::kFieldType) {
379           ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
380         }
381         const auto& vector_name = field_name + "_";
382         const VectorField* vector = (VectorField*)sized_field;
383         s << "size_t " << vector_name + "bytes = 0;";
384         if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
385           s << "for (auto elem : " << vector_name << ") {";
386           s << vector_name + "bytes += elem.size(); }";
387         } else {
388           s << vector_name + "bytes = ";
389           s << vector_name << ".size() * ((" << vector->element_size_ << ") / 8);";
390         }
391         std::string modifier = vector->GetSizeModifier();
392         if (modifier != "") {
393           s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
394           s << vector_name << "bytes = ";
395           s << vector_name << "bytes + (" << modifier << ") / 8;";
396         }
397         s << "ASSERT(" << vector_name + "bytes < (1 << " << field->GetSize().bits() << "));";
398         s << "insert(" << vector_name << "bytes, i, ";
399         s << field->GetSize().bits() << ");";
400       }
401     } else if (field->GetFieldType() == ChecksumStartField::kFieldType) {
402       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
403       const auto& started_field = fields_.GetField(field_name);
404       if (started_field == nullptr) {
405         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
406                      << ")";
407       }
408       s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetDataType() << ">();";
409       s << "shared_checksum_ptr->Initialize();";
410       s << "i.RegisterObserver(packet::ByteObserver(";
411       s << "[shared_checksum_ptr](uint8_t byte){ shared_checksum_ptr->AddByte(byte);},";
412       s << "[shared_checksum_ptr](){ return static_cast<uint64_t>(shared_checksum_ptr->GetChecksum());}));";
413     } else if (field->GetFieldType() == PaddingField::kFieldType) {
414       s << "ASSERT(unpadded_size() <= " << field->GetSize().bytes() << ");";
415       s << "size_t padding_bytes = ";
416       s << field->GetSize().bytes() << " - unpadded_size();";
417       s << "for (size_t padding = 0; padding < padding_bytes; padding++) {i.insert_byte(0);}";
418     } else if (field->GetFieldType() == CountField::kFieldType) {
419       const auto& vector_name = ((SizeField*)field)->GetSizedFieldName() + "_";
420       s << "insert(" << vector_name << ".size(), i, " << field->GetSize().bits() << ");";
421     } else {
422       field->GenInserter(s);
423     }
424   }
425   s << "}\n\n";
426 
427   s << "void SerializeFooter(BitInserter&";
428   if (parent_ != nullptr || footer_fields.size() != 0) {
429     s << " i ";
430   }
431   s << ") const {";
432 
433   for (const auto& field : footer_fields) {
434     field->GenInserter(s);
435   }
436   if (parent_ != nullptr) {
437     if (parent_->GetDefinitionType() == Type::PACKET) {
438       s << parent_->name_ << "Builder::SerializeFooter(i);";
439     } else {
440       s << parent_->name_ << "::SerializeFooter(i);";
441     }
442   }
443   s << "}\n\n";
444 
445   s << "public:";
446   s << "virtual void Serialize(BitInserter& i) const override {";
447   s << "SerializeHeader(i);";
448   if (fields_.HasPayload()) {
449     s << "payload_->Serialize(i);";
450   }
451   s << "SerializeFooter(i);";
452 
453   s << "}\n";
454 }
455 
GenInstanceOf(std::ostream & s) const456 void ParentDef::GenInstanceOf(std::ostream& s) const {
457   if (parent_ != nullptr && parent_constraints_.size() > 0) {
458     s << "static bool IsInstance(const " << parent_->name_ << "& parent) {";
459     // Get the list of parent params.
460     FieldList parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
461         PayloadField::kFieldType,
462         BodyField::kFieldType,
463     });
464 
465     // Check if constrained parent fields are set to their correct values.
466     for (int i = 0; i < parent_params.size(); i++) {
467       const auto& field = parent_params[i];
468       const auto& constraint = parent_constraints_.find(field->GetName());
469       if (constraint != parent_constraints_.end()) {
470         s << "if (parent." << field->GetName() << "_ != ";
471         if (field->GetFieldType() == ScalarField::kFieldType) {
472           s << std::get<int64_t>(constraint->second) << ")";
473           s << "{ return false;}";
474         } else if (field->GetFieldType() == EnumField::kFieldType) {
475           s << std::get<std::string>(constraint->second) << ")";
476           s << "{ return false;}";
477         } else {
478           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
479         }
480       }
481     }
482     s << "return true;}";
483   }
484 }
485