1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "parent_def.h"
18
19 #include "fields/all_fields.h"
20 #include "util.h"
21
ParentDef(std::string name,FieldList fields)22 ParentDef::ParentDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
ParentDef(std::string name,FieldList fields,ParentDef * parent)23 ParentDef::ParentDef(std::string name, FieldList fields, ParentDef* parent)
24 : TypeDef(name), fields_(fields), parent_(parent) {}
25
AddParentConstraint(std::string field_name,std::variant<int64_t,std::string> value)26 void ParentDef::AddParentConstraint(std::string field_name, std::variant<int64_t, std::string> value) {
27 // NOTE: This could end up being very slow if there are a lot of constraints.
28 const auto& parent_params = parent_->GetParamList();
29 const auto& constrained_field = parent_params.GetField(field_name);
30 if (constrained_field == nullptr) {
31 ERROR() << "Attempting to constrain field " << field_name << " in parent " << parent_->name_
32 << ", but no such field exists.";
33 }
34
35 if (constrained_field->GetFieldType() == ScalarField::kFieldType) {
36 if (!std::holds_alternative<int64_t>(value)) {
37 ERROR(constrained_field) << "Attempting to constrain a scalar field to an enum value in " << parent_->name_;
38 }
39 } else if (constrained_field->GetFieldType() == EnumField::kFieldType) {
40 if (!std::holds_alternative<std::string>(value)) {
41 ERROR(constrained_field) << "Attempting to constrain an enum field to a scalar value in " << parent_->name_;
42 }
43 const auto& enum_def = static_cast<EnumField*>(constrained_field)->GetEnumDef();
44 if (!enum_def.HasEntry(std::get<std::string>(value))) {
45 ERROR(constrained_field) << "No matching enumeration \"" << std::get<std::string>(value)
46 << "\" for constraint on enum in parent " << parent_->name_ << ".";
47 }
48
49 // For enums, we have to qualify the value using the enum type name.
50 value = enum_def.GetTypeName() + "::" + std::get<std::string>(value);
51 } else {
52 ERROR(constrained_field) << "Field in parent " << parent_->name_ << " is not viable for constraining.";
53 }
54
55 parent_constraints_.insert(std::pair(field_name, value));
56 }
57
58 // Assign all size fields to their corresponding variable length fields.
59 // Will crash if
60 // - there aren't any fields that don't match up to a field.
61 // - the size field points to a fixed size field.
62 // - if the size field comes after the variable length field.
AssignSizeFields()63 void ParentDef::AssignSizeFields() {
64 for (const auto& field : fields_) {
65 DEBUG() << "field name: " << field->GetName();
66
67 if (field->GetFieldType() != SizeField::kFieldType && field->GetFieldType() != CountField::kFieldType) {
68 continue;
69 }
70
71 const SizeField* size_field = static_cast<SizeField*>(field);
72 // Check to see if a corresponding field can be found.
73 const auto& var_len_field = fields_.GetField(size_field->GetSizedFieldName());
74 if (var_len_field == nullptr) {
75 ERROR(field) << "Could not find corresponding field for size/count field.";
76 }
77
78 // Do the ordering check to ensure the size field comes before the
79 // variable length field.
80 for (auto it = fields_.begin(); *it != size_field; it++) {
81 DEBUG() << "field name: " << (*it)->GetName();
82 if (*it == var_len_field) {
83 ERROR(var_len_field, size_field) << "Size/count field must come before the variable length field it describes.";
84 }
85 }
86
87 if (var_len_field->GetFieldType() == PayloadField::kFieldType) {
88 const auto& payload_field = static_cast<PayloadField*>(var_len_field);
89 payload_field->SetSizeField(size_field);
90 continue;
91 }
92
93 if (var_len_field->GetFieldType() == BodyField::kFieldType) {
94 const auto& body_field = static_cast<BodyField*>(var_len_field);
95 body_field->SetSizeField(size_field);
96 continue;
97 }
98
99 if (var_len_field->GetFieldType() == VectorField::kFieldType) {
100 const auto& vector_field = static_cast<VectorField*>(var_len_field);
101 vector_field->SetSizeField(size_field);
102 continue;
103 }
104
105 // If we've reached this point then the field wasn't a variable length field.
106 // Check to see if the field is a variable length field
107 ERROR(field, size_field) << "Can not use size/count in reference to a fixed size field.\n";
108 }
109 }
110
SetEndianness(bool is_little_endian)111 void ParentDef::SetEndianness(bool is_little_endian) {
112 is_little_endian_ = is_little_endian;
113 }
114
115 // Get the size. You scan specify without_payload in order to exclude payload fields as children will be overriding it.
GetSize(bool without_payload) const116 Size ParentDef::GetSize(bool without_payload) const {
117 auto size = Size(0);
118
119 for (const auto& field : fields_) {
120 if (without_payload &&
121 (field->GetFieldType() == PayloadField::kFieldType || field->GetFieldType() == BodyField::kFieldType)) {
122 continue;
123 }
124
125 // The offset to the field must be passed in as an argument for dynamically sized custom fields.
126 if (field->GetFieldType() == CustomField::kFieldType && field->GetSize().has_dynamic()) {
127 std::stringstream custom_field_size;
128
129 // Custom fields are special as their size field takes an argument.
130 custom_field_size << field->GetSize().dynamic_string() << "(begin()";
131
132 // Check if we can determine offset from begin(), otherwise error because by this point,
133 // the size of the custom field is unknown and can't be subtracted from end() to get the
134 // offset.
135 auto offset = GetOffsetForField(field->GetName(), false);
136 if (offset.empty()) {
137 ERROR(field) << "Custom Field offset can not be determined from begin().";
138 }
139
140 if (offset.bits() % 8 != 0) {
141 ERROR(field) << "Custom fields must be byte aligned.";
142 }
143 if (offset.has_bits()) custom_field_size << " + " << offset.bits() / 8;
144 if (offset.has_dynamic()) custom_field_size << " + " << offset.dynamic_string();
145 custom_field_size << ")";
146
147 size += custom_field_size.str();
148 continue;
149 }
150
151 size += field->GetSize();
152 }
153
154 if (parent_ != nullptr) {
155 size += parent_->GetSize(true);
156 }
157
158 return size;
159 }
160
161 // Get the offset until the field is reached, if there is no field
162 // returns an empty Size. from_end requests the offset to the field
163 // starting from the end() iterator. If there is a field with an unknown
164 // size along the traversal, then an empty size is returned.
GetOffsetForField(std::string field_name,bool from_end) const165 Size ParentDef::GetOffsetForField(std::string field_name, bool from_end) const {
166 // Check first if the field exists.
167 if (fields_.GetField(field_name) == nullptr) {
168 ERROR() << "Can't find a field offset for nonexistent field named: " << field_name << " in " << name_;
169 }
170
171 // We have to use a generic lambda to conditionally change iteration direction
172 // due to iterator and reverse_iterator being different types.
173 auto size_lambda = [&](auto from, auto to) -> Size {
174 auto size = Size(0);
175 for (auto it = from; it != to; it++) {
176 // We've reached the field, end the loop.
177 if ((*it)->GetName() == field_name) break;
178 const auto& field = *it;
179 // If there is a field with an unknown size before the field, return an empty Size.
180 if (field->GetSize().empty()) {
181 return Size();
182 }
183 if (field->GetFieldType() != PaddingField::kFieldType || !from_end) {
184 size += field->GetSize();
185 }
186 }
187 return size;
188 };
189
190 // Change iteration direction based on from_end.
191 auto size = Size();
192 if (from_end)
193 size = size_lambda(fields_.rbegin(), fields_.rend());
194 else
195 size = size_lambda(fields_.begin(), fields_.end());
196 if (size.empty()) return size;
197
198 // We need the offset until a payload or body field.
199 if (parent_ != nullptr) {
200 if (parent_->fields_.HasPayload()) {
201 auto parent_payload_offset = parent_->GetOffsetForField("payload", from_end);
202 if (parent_payload_offset.empty()) {
203 ERROR() << "Empty offset for payload in " << parent_->name_ << " finding the offset for field: " << field_name;
204 }
205 size += parent_payload_offset;
206 } else {
207 auto parent_body_offset = parent_->GetOffsetForField("body", from_end);
208 if (parent_body_offset.empty()) {
209 ERROR() << "Empty offset for body in " << parent_->name_ << " finding the offset for field: " << field_name;
210 }
211 size += parent_body_offset;
212 }
213 }
214
215 return size;
216 }
217
GetParamList() const218 FieldList ParentDef::GetParamList() const {
219 FieldList params;
220
221 std::set<std::string> param_types = {
222 ScalarField::kFieldType,
223 EnumField::kFieldType,
224 ArrayField::kFieldType,
225 VectorField::kFieldType,
226 CustomField::kFieldType,
227 StructField::kFieldType,
228 VariableLengthStructField::kFieldType,
229 PayloadField::kFieldType,
230 };
231
232 if (parent_ != nullptr) {
233 auto parent_params = parent_->GetParamList().GetFieldsWithTypes(param_types);
234
235 // Do not include constrained fields in the params
236 for (const auto& field : parent_params) {
237 if (parent_constraints_.find(field->GetName()) == parent_constraints_.end()) {
238 params.AppendField(field);
239 }
240 }
241 }
242 // Add our parameters.
243 return params.Merge(fields_.GetFieldsWithTypes(param_types));
244 }
245
GenMembers(std::ostream & s) const246 void ParentDef::GenMembers(std::ostream& s) const {
247 // Add the parameter list.
248 for (int i = 0; i < fields_.size(); i++) {
249 if (fields_[i]->GenBuilderMember(s)) {
250 s << "_{};";
251 }
252 }
253 }
254
GenSize(std::ostream & s) const255 void ParentDef::GenSize(std::ostream& s) const {
256 auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
257 auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
258
259 s << "protected:";
260 s << "size_t BitsOfHeader() const {";
261 s << "return 0";
262
263 if (parent_ != nullptr) {
264 if (parent_->GetDefinitionType() == Type::PACKET) {
265 s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
266 } else {
267 s << " + " << parent_->name_ << "::BitsOfHeader() ";
268 }
269 }
270
271 for (const auto& field : header_fields) {
272 s << " + " << field->GetBuilderSize();
273 }
274 s << ";";
275
276 s << "}\n\n";
277
278 s << "size_t BitsOfFooter() const {";
279 s << "return 0";
280 for (const auto& field : footer_fields) {
281 s << " + " << field->GetBuilderSize();
282 }
283
284 if (parent_ != nullptr) {
285 if (parent_->GetDefinitionType() == Type::PACKET) {
286 s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
287 } else {
288 s << " + " << parent_->name_ << "::BitsOfFooter() ";
289 }
290 }
291 s << ";";
292 s << "}\n\n";
293
294 if (fields_.HasPayload()) {
295 s << "size_t GetPayloadSize() const {";
296 s << "if (payload_ != nullptr) {return payload_->size();}";
297 s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
298 s << ";}\n\n";
299 }
300
301 Size padded_size;
302 for (const auto& field : header_fields) {
303 if (field->GetFieldType() == PaddingField::kFieldType) {
304 if (!padded_size.empty()) {
305 ERROR() << "Only one padding field is allowed. Second field: " << field->GetName();
306 }
307 padded_size = field->GetSize();
308 }
309 }
310
311 s << "public:";
312 s << "virtual size_t size() const override {";
313 if (!padded_size.empty()) {
314 s << "return " << padded_size.bytes() << ";}";
315 s << "size_t unpadded_size() const {";
316 }
317 s << "return (BitsOfHeader() / 8)";
318 if (fields_.HasPayload()) {
319 s << "+ payload_->size()";
320 }
321 if (fields_.HasBody()) {
322 for (const auto& field : header_fields) {
323 if (field->GetFieldType() == SizeField::kFieldType) {
324 const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
325 if (field_name == "body") {
326 s << "+ body_size_extracted_";
327 }
328 }
329 }
330 }
331 s << " + (BitsOfFooter() / 8);";
332 s << "}\n";
333 }
334
GenSerialize(std::ostream & s) const335 void ParentDef::GenSerialize(std::ostream& s) const {
336 auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
337 auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
338
339 s << "protected:";
340 s << "void SerializeHeader(BitInserter&";
341 if (parent_ != nullptr || header_fields.size() != 0) {
342 s << " i ";
343 }
344 s << ") const {";
345
346 if (parent_ != nullptr) {
347 if (parent_->GetDefinitionType() == Type::PACKET) {
348 s << parent_->name_ << "Builder::SerializeHeader(i);";
349 } else {
350 s << parent_->name_ << "::SerializeHeader(i);";
351 }
352 }
353
354 for (const auto& field : header_fields) {
355 if (field->GetFieldType() == SizeField::kFieldType) {
356 const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
357 const auto& sized_field = fields_.GetField(field_name);
358 if (sized_field == nullptr) {
359 ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
360 }
361 if (sized_field->GetFieldType() == PayloadField::kFieldType) {
362 s << "size_t payload_bytes = GetPayloadSize();";
363 std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
364 if (modifier != "") {
365 s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
366 s << "payload_bytes = payload_bytes + (" << modifier << ") / 8;";
367 }
368 s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
369 s << "insert(static_cast<" << field->GetDataType() << ">(payload_bytes), i," << field->GetSize().bits() << ");";
370 } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
371 s << field->GetName() << "_extracted_ = 0;";
372 s << "size_t local_size = " << name_ << "::size();";
373
374 s << "ASSERT((size() - local_size) < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
375 s << "insert(static_cast<" << field->GetDataType() << ">(size() - local_size), i," << field->GetSize().bits()
376 << ");";
377 } else {
378 if (sized_field->GetFieldType() != VectorField::kFieldType) {
379 ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
380 }
381 const auto& vector_name = field_name + "_";
382 const VectorField* vector = (VectorField*)sized_field;
383 s << "size_t " << vector_name + "bytes = 0;";
384 if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
385 s << "for (auto elem : " << vector_name << ") {";
386 s << vector_name + "bytes += elem.size(); }";
387 } else {
388 s << vector_name + "bytes = ";
389 s << vector_name << ".size() * ((" << vector->element_size_ << ") / 8);";
390 }
391 std::string modifier = vector->GetSizeModifier();
392 if (modifier != "") {
393 s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
394 s << vector_name << "bytes = ";
395 s << vector_name << "bytes + (" << modifier << ") / 8;";
396 }
397 s << "ASSERT(" << vector_name + "bytes < (1 << " << field->GetSize().bits() << "));";
398 s << "insert(" << vector_name << "bytes, i, ";
399 s << field->GetSize().bits() << ");";
400 }
401 } else if (field->GetFieldType() == ChecksumStartField::kFieldType) {
402 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
403 const auto& started_field = fields_.GetField(field_name);
404 if (started_field == nullptr) {
405 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
406 << ")";
407 }
408 s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetDataType() << ">();";
409 s << "shared_checksum_ptr->Initialize();";
410 s << "i.RegisterObserver(packet::ByteObserver(";
411 s << "[shared_checksum_ptr](uint8_t byte){ shared_checksum_ptr->AddByte(byte);},";
412 s << "[shared_checksum_ptr](){ return static_cast<uint64_t>(shared_checksum_ptr->GetChecksum());}));";
413 } else if (field->GetFieldType() == PaddingField::kFieldType) {
414 s << "ASSERT(unpadded_size() <= " << field->GetSize().bytes() << ");";
415 s << "size_t padding_bytes = ";
416 s << field->GetSize().bytes() << " - unpadded_size();";
417 s << "for (size_t padding = 0; padding < padding_bytes; padding++) {i.insert_byte(0);}";
418 } else if (field->GetFieldType() == CountField::kFieldType) {
419 const auto& vector_name = ((SizeField*)field)->GetSizedFieldName() + "_";
420 s << "insert(" << vector_name << ".size(), i, " << field->GetSize().bits() << ");";
421 } else {
422 field->GenInserter(s);
423 }
424 }
425 s << "}\n\n";
426
427 s << "void SerializeFooter(BitInserter&";
428 if (parent_ != nullptr || footer_fields.size() != 0) {
429 s << " i ";
430 }
431 s << ") const {";
432
433 for (const auto& field : footer_fields) {
434 field->GenInserter(s);
435 }
436 if (parent_ != nullptr) {
437 if (parent_->GetDefinitionType() == Type::PACKET) {
438 s << parent_->name_ << "Builder::SerializeFooter(i);";
439 } else {
440 s << parent_->name_ << "::SerializeFooter(i);";
441 }
442 }
443 s << "}\n\n";
444
445 s << "public:";
446 s << "virtual void Serialize(BitInserter& i) const override {";
447 s << "SerializeHeader(i);";
448 if (fields_.HasPayload()) {
449 s << "payload_->Serialize(i);";
450 }
451 s << "SerializeFooter(i);";
452
453 s << "}\n";
454 }
455
GenInstanceOf(std::ostream & s) const456 void ParentDef::GenInstanceOf(std::ostream& s) const {
457 if (parent_ != nullptr && parent_constraints_.size() > 0) {
458 s << "static bool IsInstance(const " << parent_->name_ << "& parent) {";
459 // Get the list of parent params.
460 FieldList parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
461 PayloadField::kFieldType,
462 BodyField::kFieldType,
463 });
464
465 // Check if constrained parent fields are set to their correct values.
466 for (int i = 0; i < parent_params.size(); i++) {
467 const auto& field = parent_params[i];
468 const auto& constraint = parent_constraints_.find(field->GetName());
469 if (constraint != parent_constraints_.end()) {
470 s << "if (parent." << field->GetName() << "_ != ";
471 if (field->GetFieldType() == ScalarField::kFieldType) {
472 s << std::get<int64_t>(constraint->second) << ")";
473 s << "{ return false;}";
474 } else if (field->GetFieldType() == EnumField::kFieldType) {
475 s << std::get<std::string>(constraint->second) << ")";
476 s << "{ return false;}";
477 } else {
478 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
479 }
480 }
481 }
482 s << "return true;}";
483 }
484 }
485