1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.util.proto;
18 
19 import java.io.IOException;
20 import java.io.InputStream;
21 import java.nio.charset.StandardCharsets;
22 import java.util.ArrayList;
23 
24 /**
 * Class to read from a protobuf stream.
26  *
27  * Each read method takes an ID code from the protoc generated classes
 * and returns the value of the field. To read a nested object, call #start
29  * and then #end when you are done.
30  *
31  * The ID codes have type information embedded into them, so if you call
32  * the incorrect function you will get an IllegalArgumentException.
33  *
34  * nextField will return the field number of the next field, which can be
35  * matched to the protoc generated ID code and used to determine how to
36  * read the next field.
37  *
38  * It is STRONGLY RECOMMENDED to read from the ProtoInputStream with a switch
39  * statement wrapped in a while loop. Additionally, it is worth logging or
 * storing unexpected fields or ones that do not match the expected wire type.
41  *
42  * ex:
43  * void parseFromProto(ProtoInputStream stream) {
44  *     while(stream.nextField() != ProtoInputStream.NO_MORE_FIELDS) {
45  *         try {
46  *             switch (stream.getFieldNumber()) {
47  *                 case (int) DummyProto.NAME:
48  *                     mName = stream.readString(DummyProto.NAME);
49  *                     break;
50  *                 case (int) DummyProto.VALUE:
51  *                     mValue = stream.readInt(DummyProto.VALUE);
52  *                     break;
53  *                 default:
54  *                     LOG(TAG, "Unhandled field in proto!\n"
55  *                              + ProtoUtils.currentFieldToString(stream));
56  *             }
57  *         } catch (WireTypeMismatchException wtme) {
58  *             LOG(TAG, "Wire Type mismatch in proto!\n" + ProtoUtils.currentFieldToString(stream));
59  *         }
60  *     }
61  * }
62  *
63  * @hide
64  */
65 public final class ProtoInputStream extends ProtoStream {
66 
67     public static final int NO_MORE_FIELDS = -1;
68 
69     /**
70      * Our stream.  If there is one.
71      */
72     private InputStream mStream;
73 
74     /**
75      * The field number of the current field. Will be equal to NO_MORE_FIELDS if end of message is
76      * reached
77      */
78     private int mFieldNumber;
79 
80     /**
81      * The wire type of the current field
82      */
83     private int mWireType;
84 
85     private static final byte STATE_STARTED_FIELD_READ = 1 << 0;
86     private static final byte STATE_READING_PACKED = 1 << 1;
87     private static final byte STATE_FIELD_MISS = 2 << 1;
88 
89     /**
90      * Tracks some boolean states for the proto input stream
91      * bit 0: Started Field Read, true - tag has been read, ready to read field data.
92      * false - field data has been read, reading to start next field.
93      * bit 1: Reading Packed Field, true - currently reading values from a packed field
94      * false - not reading from packed field.
95      */
96     private byte mState = 0;
97 
98     /**
99      * Keeps track of the currently read nested Objects, for end object checking and debug
100      */
101     private ArrayList<Long> mExpectedObjectTokenStack = null;
102 
103     /**
104      * Current nesting depth of start calls.
105      */
106     private int mDepth = -1;
107 
108     /**
109      * Buffer for the to be read data. If mStream is not null, it will be constantly refilled from
110      * the stream.
111      */
112     private byte[] mBuffer;
113 
114     private static final int DEFAULT_BUFFER_SIZE = 8192;
115 
116     /**
117      * Size of the buffer if reading from a stream.
118      */
119     private final int mBufferSize;
120 
121     /**
122      * The number of bytes that have been skipped or dropped from the buffer.
123      */
124     private int mDiscardedBytes = 0;
125 
126     /**
127      * Current offset in the buffer
128      * mOffset + mDiscardedBytes = current offset in proto binary
129      */
130     private int mOffset = 0;
131 
132     /**
133      * Note the offset of the last byte in the buffer. Usually will equal the size of the buffer.
134      * mEnd + mDiscardedBytes = the last known byte offset + 1
135      */
136     private int mEnd = 0;
137 
138     /**
139      * Packed repeated fields are not read in one go. mPackedEnd keeps track of where the packed
140      * field ends in the proto binary if current field is packed.
141      */
142     private int mPackedEnd = 0;
143 
/**
 * Construct a ProtoInputStream that reads a proto from an InputStream, buffering up to
 * {@code bufferSize} bytes from the stream at a time.
 *
 * @param stream from which the proto is read
 * @param bufferSize number of bytes to buffer; values of 0 or less select the default size
 */
public ProtoInputStream(InputStream stream, int bufferSize) {
    mStream = stream;
    mBufferSize = bufferSize > 0 ? bufferSize : DEFAULT_BUFFER_SIZE;
    mBuffer = new byte[mBufferSize];
}
159 
/**
 * Construct a ProtoInputStream that reads a proto from an InputStream using the default
 * buffer size.
 *
 * @param stream from which the proto is read
 */
public ProtoInputStream(InputStream stream) {
    this(stream, DEFAULT_BUFFER_SIZE);
}
168 
169     /**
170      * Construct a ProtoInputStream to read a proto directly from a byte array
171      *
172      * @param buffer - the byte array to be parsed
173      */
ProtoInputStream(byte[] buffer)174     public ProtoInputStream(byte[] buffer) {
175         mBufferSize = buffer.length;
176         mEnd = buffer.length;
177         mBuffer = buffer;
178         mStream = null;
179     }
180 
181     /**
182      * Get the field number of the current field.
183      */
getFieldNumber()184     public int getFieldNumber() {
185         return mFieldNumber;
186     }
187 
188     /**
189      * Get the wire type of the current field.
190      *
191      * @return an int that matches one of the ProtoStream WIRE_TYPE_ constants
192      */
getWireType()193     public int getWireType() {
194         if ((mState & STATE_READING_PACKED) == STATE_READING_PACKED) {
195             // mWireType got overwritten when STATE_READING_PACKED was set. Send length delimited
196             // constant instead
197             return WIRE_TYPE_LENGTH_DELIMITED;
198         }
199         return mWireType;
200     }
201 
202     /**
203      * Get the current offset in the proto binary.
204      */
getOffset()205     public int getOffset() {
206         return mOffset + mDiscardedBytes;
207     }
208 
209     /**
210      * Reads the tag of the next field from the stream. If previous field value was not read, its
211      * data will be skipped over.
212      *
213      * @return the field number of the next field
214      * @throws IOException if an I/O error occurs
215      */
nextField()216     public int nextField() throws IOException {
217 
218         if ((mState & STATE_FIELD_MISS) == STATE_FIELD_MISS) {
219             // Data from the last nextField was not used, reuse the info
220             mState &= ~STATE_FIELD_MISS;
221             return mFieldNumber;
222         }
223         if ((mState & STATE_STARTED_FIELD_READ) == STATE_STARTED_FIELD_READ) {
224             // Field data was not read, skip to the next field
225             skip();
226             mState &= ~STATE_STARTED_FIELD_READ;
227         }
228         if ((mState & STATE_READING_PACKED) == STATE_READING_PACKED) {
229             if (getOffset() < mPackedEnd) {
230                 // In the middle of a packed field, return the same tag until last packed value
231                 // has been read
232                 mState |= STATE_STARTED_FIELD_READ;
233                 return mFieldNumber;
234             } else if (getOffset() == mPackedEnd) {
235                 // Reached the end of the packed field
236                 mState &= ~STATE_READING_PACKED;
237             } else {
238                 throw new ProtoParseException(
239                         "Unexpectedly reached end of packed field at offset 0x"
240                                 + Integer.toHexString(mPackedEnd)
241                                 + dumpDebugData());
242             }
243         }
244 
245         if ((mDepth >= 0) && (getOffset() == getOffsetFromToken(
246                 mExpectedObjectTokenStack.get(mDepth)))) {
247             // reached end of a embedded message
248             mFieldNumber = NO_MORE_FIELDS;
249         } else {
250             readTag();
251         }
252         return mFieldNumber;
253     }
254 
255     /**
256      * Attempt to guess the next field. If there is a match, the field data will be ready to read.
257      * If there is no match, nextField will need to be called to get the field number
258      *
259      * @return true if fieldId matches the next field, false if not
260      */
isNextField(long fieldId)261     public boolean isNextField(long fieldId) throws IOException {
262         if (nextField() == (int) fieldId) {
263             return true;
264         }
265         // Note to reuse the info from the nextField call in the next call.
266         mState |= STATE_FIELD_MISS;
267         return false;
268     }
269 
270     /**
271      * Read a single double.
272      * Will throw if the current wire type is not fixed64
273      *
274      * @param fieldId - must match the current field number and field type
275      */
readDouble(long fieldId)276     public double readDouble(long fieldId) throws IOException {
277         assertFreshData();
278         assertFieldNumber(fieldId);
279         checkPacked(fieldId);
280 
281         double value;
282         switch ((int) ((fieldId & FIELD_TYPE_MASK)
283                 >>> FIELD_TYPE_SHIFT)) {
284             case (int) (FIELD_TYPE_DOUBLE >>> FIELD_TYPE_SHIFT):
285                 assertWireType(WIRE_TYPE_FIXED64);
286                 value = Double.longBitsToDouble(readFixed64());
287                 break;
288             default:
289                 throw new IllegalArgumentException(
290                         "Requested field id (" + getFieldIdString(fieldId)
291                                 + ") cannot be read as a double"
292                                 + dumpDebugData());
293         }
294         // Successfully read the field
295         mState &= ~STATE_STARTED_FIELD_READ;
296         return value;
297     }
298 
299     /**
300      * Read a single float.
301      * Will throw if the current wire type is not fixed32
302      *
303      * @param fieldId - must match the current field number and field type
304      */
readFloat(long fieldId)305     public float readFloat(long fieldId) throws IOException {
306         assertFreshData();
307         assertFieldNumber(fieldId);
308         checkPacked(fieldId);
309 
310         float value;
311         switch ((int) ((fieldId & FIELD_TYPE_MASK)
312                 >>> FIELD_TYPE_SHIFT)) {
313             case (int) (FIELD_TYPE_FLOAT >>> FIELD_TYPE_SHIFT):
314                 assertWireType(WIRE_TYPE_FIXED32);
315                 value = Float.intBitsToFloat(readFixed32());
316                 break;
317             default:
318                 throw new IllegalArgumentException(
319                         "Requested field id (" + getFieldIdString(fieldId) + ") is not a float"
320                                 + dumpDebugData());
321         }
322         // Successfully read the field
323         mState &= ~STATE_STARTED_FIELD_READ;
324         return value;
325     }
326 
327     /**
328      * Read a single 32bit or varint proto type field as an int.
329      * Will throw if the current wire type is not varint or fixed32
330      *
331      * @param fieldId - must match the current field number and field type
332      */
readInt(long fieldId)333     public int readInt(long fieldId) throws IOException {
334         assertFreshData();
335         assertFieldNumber(fieldId);
336         checkPacked(fieldId);
337 
338         int value;
339         switch ((int) ((fieldId & FIELD_TYPE_MASK)
340                 >>> FIELD_TYPE_SHIFT)) {
341             case (int) (FIELD_TYPE_FIXED32 >>> FIELD_TYPE_SHIFT):
342             case (int) (FIELD_TYPE_SFIXED32 >>> FIELD_TYPE_SHIFT):
343                 assertWireType(WIRE_TYPE_FIXED32);
344                 value = readFixed32();
345                 break;
346             case (int) (FIELD_TYPE_SINT32 >>> FIELD_TYPE_SHIFT):
347                 assertWireType(WIRE_TYPE_VARINT);
348                 value = decodeZigZag32((int) readVarint());
349                 break;
350             case (int) (FIELD_TYPE_INT32 >>> FIELD_TYPE_SHIFT):
351             case (int) (FIELD_TYPE_UINT32 >>> FIELD_TYPE_SHIFT):
352             case (int) (FIELD_TYPE_ENUM >>> FIELD_TYPE_SHIFT):
353                 assertWireType(WIRE_TYPE_VARINT);
354                 value = (int) readVarint();
355                 break;
356             default:
357                 throw new IllegalArgumentException(
358                         "Requested field id (" + getFieldIdString(fieldId) + ") is not an int"
359                                 + dumpDebugData());
360         }
361         // Successfully read the field
362         mState &= ~STATE_STARTED_FIELD_READ;
363         return value;
364     }
365 
366     /**
367      * Read a single 64bit or varint proto type field as an long.
368      *
369      * @param fieldId - must match the current field number
370      */
readLong(long fieldId)371     public long readLong(long fieldId) throws IOException {
372         assertFreshData();
373         assertFieldNumber(fieldId);
374         checkPacked(fieldId);
375 
376         long value;
377         switch ((int) ((fieldId & FIELD_TYPE_MASK)
378                 >>> FIELD_TYPE_SHIFT)) {
379             case (int) (FIELD_TYPE_FIXED64 >>> FIELD_TYPE_SHIFT):
380             case (int) (FIELD_TYPE_SFIXED64 >>> FIELD_TYPE_SHIFT):
381                 assertWireType(WIRE_TYPE_FIXED64);
382                 value = readFixed64();
383                 break;
384             case (int) (FIELD_TYPE_SINT64 >>> FIELD_TYPE_SHIFT):
385                 assertWireType(WIRE_TYPE_VARINT);
386                 value = decodeZigZag64(readVarint());
387                 break;
388             case (int) (FIELD_TYPE_INT64 >>> FIELD_TYPE_SHIFT):
389             case (int) (FIELD_TYPE_UINT64 >>> FIELD_TYPE_SHIFT):
390                 assertWireType(WIRE_TYPE_VARINT);
391                 value = readVarint();
392                 break;
393             default:
394                 throw new IllegalArgumentException(
395                         "Requested field id (" + getFieldIdString(fieldId) + ") is not an long"
396                                 + dumpDebugData());
397         }
398         // Successfully read the field
399         mState &= ~STATE_STARTED_FIELD_READ;
400         return value;
401     }
402 
403     /**
404      * Read a single 32bit or varint proto type field as an boolean.
405      *
406      * @param fieldId - must match the current field number
407      */
readBoolean(long fieldId)408     public boolean readBoolean(long fieldId) throws IOException {
409         assertFreshData();
410         assertFieldNumber(fieldId);
411         checkPacked(fieldId);
412 
413         boolean value;
414         switch ((int) ((fieldId & FIELD_TYPE_MASK)
415                 >>> FIELD_TYPE_SHIFT)) {
416             case (int) (FIELD_TYPE_BOOL >>> FIELD_TYPE_SHIFT):
417                 assertWireType(WIRE_TYPE_VARINT);
418                 value = readVarint() != 0;
419                 break;
420             default:
421                 throw new IllegalArgumentException(
422                         "Requested field id (" + getFieldIdString(fieldId) + ") is not an boolean"
423                                 + dumpDebugData());
424         }
425         // Successfully read the field
426         mState &= ~STATE_STARTED_FIELD_READ;
427         return value;
428     }
429 
430     /**
431      * Read a string field
432      *
433      * @param fieldId - must match the current field number
434      */
readString(long fieldId)435     public String readString(long fieldId) throws IOException {
436         assertFreshData();
437         assertFieldNumber(fieldId);
438 
439         String value;
440         switch ((int) ((fieldId & FIELD_TYPE_MASK) >>> FIELD_TYPE_SHIFT)) {
441             case (int) (FIELD_TYPE_STRING >>> FIELD_TYPE_SHIFT):
442                 assertWireType(WIRE_TYPE_LENGTH_DELIMITED);
443                 int len = (int) readVarint();
444                 value = readRawString(len);
445                 break;
446             default:
447                 throw new IllegalArgumentException(
448                         "Requested field id(" + getFieldIdString(fieldId)
449                                 + ") is not an string"
450                                 + dumpDebugData());
451         }
452         // Successfully read the field
453         mState &= ~STATE_STARTED_FIELD_READ;
454         return value;
455     }
456 
457     /**
458      * Read a bytes field
459      *
460      * @param fieldId - must match the current field number
461      */
readBytes(long fieldId)462     public byte[] readBytes(long fieldId) throws IOException {
463         assertFreshData();
464         assertFieldNumber(fieldId);
465 
466         byte[] value;
467         switch ((int) ((fieldId & FIELD_TYPE_MASK) >>> FIELD_TYPE_SHIFT)) {
468             case (int) (FIELD_TYPE_MESSAGE >>> FIELD_TYPE_SHIFT):
469             case (int) (FIELD_TYPE_BYTES >>> FIELD_TYPE_SHIFT):
470                 assertWireType(WIRE_TYPE_LENGTH_DELIMITED);
471                 int len = (int) readVarint();
472                 value = readRawBytes(len);
473                 break;
474             default:
475                 throw new IllegalArgumentException(
476                         "Requested field type (" + getFieldIdString(fieldId)
477                                 + ") cannot be read as raw bytes"
478                                 + dumpDebugData());
479         }
480         // Successfully read the field
481         mState &= ~STATE_STARTED_FIELD_READ;
482         return value;
483     }
484 
485     /**
486      * Start the read of an embedded Object
487      *
488      * @param fieldId - must match the current field number
489      * @return a token. The token must be handed back when finished reading embedded Object
490      */
start(long fieldId)491     public long start(long fieldId) throws IOException {
492         assertFreshData();
493         assertFieldNumber(fieldId);
494         assertWireType(WIRE_TYPE_LENGTH_DELIMITED);
495 
496         int messageSize = (int) readVarint();
497 
498         if (mExpectedObjectTokenStack == null) {
499             mExpectedObjectTokenStack = new ArrayList<>();
500         }
501         if (++mDepth == mExpectedObjectTokenStack.size()) {
502             // Create a token to keep track of nested Object and extend the object stack
503             mExpectedObjectTokenStack.add(makeToken(0,
504                     (fieldId & FIELD_COUNT_REPEATED) == FIELD_COUNT_REPEATED, mDepth,
505                     (int) fieldId, getOffset() + messageSize));
506 
507         } else {
508             // Create a token to keep track of nested Object
509             mExpectedObjectTokenStack.set(mDepth, makeToken(0,
510                     (fieldId & FIELD_COUNT_REPEATED) == FIELD_COUNT_REPEATED, mDepth,
511                     (int) fieldId, getOffset() + messageSize));
512         }
513 
514         // Validation check
515         if (mDepth > 0
516                 && getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth))
517                 > getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth - 1))) {
518             throw new ProtoParseException("Embedded Object ("
519                     + token2String(mExpectedObjectTokenStack.get(mDepth))
520                     + ") ends after of parent Objects's ("
521                     + token2String(mExpectedObjectTokenStack.get(mDepth - 1))
522                     + ") end"
523                     + dumpDebugData());
524         }
525         mState &= ~STATE_STARTED_FIELD_READ;
526         return mExpectedObjectTokenStack.get(mDepth);
527     }
528 
529     /**
530      * Note the end of a nested object. Must be called to continue streaming the rest of the proto.
531      * end can be called mid object parse. The offset will be moved to the next field outside the
532      * object.
533      *
534      * @param token - token
535      */
end(long token)536     public void end(long token) {
537         // Make sure user is keeping track of their embedded messages
538         if (mExpectedObjectTokenStack.get(mDepth) != token) {
539             throw new ProtoParseException(
540                     "end token " + token + " does not match current message token "
541                             + mExpectedObjectTokenStack.get(mDepth)
542                             + dumpDebugData());
543         }
544         if (getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth)) > getOffset()) {
545             // Did not read all of the message, skip to the end
546             incOffset(getOffsetFromToken(mExpectedObjectTokenStack.get(mDepth)) - getOffset());
547         }
548         mDepth--;
549         mState &= ~STATE_STARTED_FIELD_READ;
550     }
551 
552     /**
553      * Read the tag at the start of the next field and collect field number and wire type.
554      * Will set mFieldNumber to NO_MORE_FIELDS if end of buffer/stream reached.
555      */
readTag()556     private void readTag() throws IOException {
557         fillBuffer();
558         if (mOffset >= mEnd) {
559             // reached end of the stream
560             mFieldNumber = NO_MORE_FIELDS;
561             return;
562         }
563         int tag = (int) readVarint();
564         mFieldNumber = tag >>> FIELD_ID_SHIFT;
565         mWireType = tag & WIRE_TYPE_MASK;
566         mState |= STATE_STARTED_FIELD_READ;
567     }
568 
569     /**
570      * Decode a 32 bit ZigZag encoded signed int.
571      *
572      * @param n - int to decode
573      * @return the decoded signed int
574      */
decodeZigZag32(final int n)575     public int decodeZigZag32(final int n) {
576         return (n >>> 1) ^ -(n & 1);
577     }
578 
579     /**
580      * Decode a 64 bit ZigZag encoded signed long.
581      *
582      * @param n - long to decode
583      * @return the decoded signed long
584      */
decodeZigZag64(final long n)585     public long decodeZigZag64(final long n) {
586         return (n >>> 1) ^ -(n & 1);
587     }
588 
589     /**
590      * Read a varint from the buffer
591      *
592      * @return the varint as a long
593      */
readVarint()594     private long readVarint() throws IOException {
595         long value = 0;
596         int shift = 0;
597         while (true) {
598             fillBuffer();
599             // Limit how much bookkeeping is done by checking how far away the end of the buffer is
600             // and directly accessing buffer up until the end.
601             final int fragment = mEnd - mOffset;
602             for (int i = 0; i < fragment; i++) {
603                 byte b = mBuffer[(mOffset + i)];
604                 value |= (b & 0x7FL) << shift;
605                 if ((b & 0x80) == 0) {
606                     incOffset(i + 1);
607                     return value;
608                 }
609                 shift += 7;
610                 if (shift > 63) {
611                     throw new ProtoParseException(
612                             "Varint is too large at offset 0x"
613                                     + Integer.toHexString(getOffset() + i)
614                                     + dumpDebugData());
615                 }
616             }
617             // Hit the end of the buffer, do some incrementing and checking, then continue
618             incOffset(fragment);
619         }
620     }
621 
622     /**
623      * Read a fixed 32 bit int from the buffer
624      *
625      * @return the fixed32 as a int
626      */
readFixed32()627     private int readFixed32() throws IOException {
628         // check for fast path, which is likely with a reasonable buffer size
629         if (mOffset + 4 <= mEnd) {
630             // don't bother filling buffer since we know the end is plenty far away
631             incOffset(4);
632             return (mBuffer[mOffset - 4] & 0xFF)
633                     | ((mBuffer[mOffset - 3] & 0xFF) << 8)
634                     | ((mBuffer[mOffset - 2] & 0xFF) << 16)
635                     | ((mBuffer[mOffset - 1] & 0xFF) << 24);
636         }
637 
638         // the Fixed32 crosses the edge of a chunk, read the Fixed32 in multiple fragments.
639         // There will be two fragment reads except when the chunk size is 2 or less.
640         int value = 0;
641         int shift = 0;
642         int bytesLeft = 4;
643         while (bytesLeft > 0) {
644             fillBuffer();
645             // Find the number of bytes available until the end of the chunk or Fixed32
646             int fragment = (mEnd - mOffset) < bytesLeft ? (mEnd - mOffset) : bytesLeft;
647             incOffset(fragment);
648             bytesLeft -= fragment;
649             while (fragment > 0) {
650                 value |= ((mBuffer[mOffset - fragment] & 0xFF) << shift);
651                 fragment--;
652                 shift += 8;
653             }
654         }
655         return value;
656     }
657 
658     /**
659      * Read a fixed 64 bit long from the buffer
660      *
661      * @return the fixed64 as a long
662      */
readFixed64()663     private long readFixed64() throws IOException {
664         // check for fast path, which is likely with a reasonable buffer size
665         if (mOffset + 8 <= mEnd) {
666             // don't bother filling buffer since we know the end is plenty far away
667             incOffset(8);
668             return (mBuffer[mOffset - 8] & 0xFFL)
669                     | ((mBuffer[mOffset - 7] & 0xFFL) << 8)
670                     | ((mBuffer[mOffset - 6] & 0xFFL) << 16)
671                     | ((mBuffer[mOffset - 5] & 0xFFL) << 24)
672                     | ((mBuffer[mOffset - 4] & 0xFFL) << 32)
673                     | ((mBuffer[mOffset - 3] & 0xFFL) << 40)
674                     | ((mBuffer[mOffset - 2] & 0xFFL) << 48)
675                     | ((mBuffer[mOffset - 1] & 0xFFL) << 56);
676         }
677 
678         // the Fixed64 crosses the edge of a chunk, read the Fixed64 in multiple fragments.
679         // There will be two fragment reads except when the chunk size is 6 or less.
680         long value = 0;
681         int shift = 0;
682         int bytesLeft = 8;
683         while (bytesLeft > 0) {
684             fillBuffer();
685             // Find the number of bytes available until the end of the chunk or Fixed64
686             int fragment = (mEnd - mOffset) < bytesLeft ? (mEnd - mOffset) : bytesLeft;
687             incOffset(fragment);
688             bytesLeft -= fragment;
689             while (fragment > 0) {
690                 value |= ((mBuffer[(mOffset - fragment)] & 0xFFL) << shift);
691                 fragment--;
692                 shift += 8;
693             }
694         }
695         return value;
696     }
697 
698     /**
699      * Read raw bytes from the buffer
700      *
701      * @param n - number of bytes to read
702      * @return a byte array with raw bytes
703      */
readRawBytes(int n)704     private byte[] readRawBytes(int n) throws IOException {
705         byte[] buffer = new byte[n];
706         int pos = 0;
707         while (mOffset + n - pos > mEnd) {
708             int fragment = mEnd - mOffset;
709             if (fragment > 0) {
710                 System.arraycopy(mBuffer, mOffset, buffer, pos, fragment);
711                 incOffset(fragment);
712                 pos += fragment;
713             }
714             fillBuffer();
715             if (mOffset >= mEnd) {
716                 throw new ProtoParseException(
717                         "Unexpectedly reached end of the InputStream at offset 0x"
718                                 + Integer.toHexString(mEnd)
719                                 + dumpDebugData());
720             }
721         }
722         System.arraycopy(mBuffer, mOffset, buffer, pos, n - pos);
723         incOffset(n - pos);
724         return buffer;
725     }
726 
727     /**
728      * Read raw string from the buffer
729      *
730      * @param n - number of bytes to read
731      * @return a string
732      */
readRawString(int n)733     private String readRawString(int n) throws IOException {
734         fillBuffer();
735         if (mOffset + n <= mEnd) {
736             // fast path read. String is well within the current buffer
737             String value = new String(mBuffer, mOffset, n, StandardCharsets.UTF_8);
738             incOffset(n);
739             return value;
740         } else if (n <= mBufferSize) {
741             // String extends past buffer, but can be encapsulated in a buffer. Copy the first chunk
742             // of the string to the start of the buffer and then fill the rest of the buffer from
743             // the stream.
744             final int stringHead = mEnd - mOffset;
745             System.arraycopy(mBuffer, mOffset, mBuffer, 0, stringHead);
746             mEnd = stringHead + mStream.read(mBuffer, stringHead, n - stringHead);
747 
748             mDiscardedBytes += mOffset;
749             mOffset = 0;
750 
751             String value = new String(mBuffer, mOffset, n, StandardCharsets.UTF_8);
752             incOffset(n);
753             return value;
754         }
755         // Otherwise, the string is too large to use the buffer. Create the string from a
756         // separate byte array.
757         return new String(readRawBytes(n), 0, n, StandardCharsets.UTF_8);
758     }
759 
760     /**
761      * Fill the buffer with a chunk from the stream if need be.
762      * Will skip chunks until mOffset is reached
763      */
fillBuffer()764     private void fillBuffer() throws IOException {
765         if (mOffset >= mEnd && mStream != null) {
766             mOffset -= mEnd;
767             mDiscardedBytes += mEnd;
768             if (mOffset >= mBufferSize) {
769                 int skipped = (int) mStream.skip((mOffset / mBufferSize) * mBufferSize);
770                 mDiscardedBytes += skipped;
771                 mOffset -= skipped;
772             }
773             mEnd = mStream.read(mBuffer);
774         }
775     }
776 
777     /**
778      * Skips the rest of current field and moves to the start of the next field. This should only be
779      * called while state is STATE_STARTED_FIELD_READ
780      */
skip()781     public void skip() throws IOException {
782         if ((mState & STATE_READING_PACKED) == STATE_READING_PACKED) {
783             incOffset(mPackedEnd - getOffset());
784         } else {
785             switch (mWireType) {
786                 case WIRE_TYPE_VARINT:
787                     byte b;
788                     do {
789                         fillBuffer();
790                         b = mBuffer[mOffset];
791                         incOffset(1);
792                     } while ((b & 0x80) != 0);
793                     break;
794                 case WIRE_TYPE_FIXED64:
795                     incOffset(8);
796                     break;
797                 case WIRE_TYPE_LENGTH_DELIMITED:
798                     fillBuffer();
799                     int length = (int) readVarint();
800                     incOffset(length);
801                     break;
802                 /*
803             case WIRE_TYPE_START_GROUP:
804                 // Not implemented
805                 break;
806             case WIRE_TYPE_END_GROUP:
807                 // Not implemented
808                 break;
809                 */
810                 case WIRE_TYPE_FIXED32:
811                     incOffset(4);
812                     break;
813                 default:
814                     throw new ProtoParseException(
815                             "Unexpected wire type: " + mWireType + " at offset 0x"
816                                     + Integer.toHexString(mOffset)
817                                     + dumpDebugData());
818             }
819         }
820         mState &= ~STATE_STARTED_FIELD_READ;
821     }
822 
823     /**
824      * Increment the offset and handle all the relevant bookkeeping
825      * Refilling the buffer when its end is reached will be handled elsewhere (ideally just before
826      * a read, to avoid unnecessary reads from stream)
827      *
828      * @param n - number of bytes to increment
829      */
incOffset(int n)830     private void incOffset(int n) {
831         mOffset += n;
832 
833         if (mDepth >= 0 && getOffset() > getOffsetFromToken(
834                 mExpectedObjectTokenStack.get(mDepth))) {
835             throw new ProtoParseException("Unexpectedly reached end of embedded object.  "
836                     + token2String(mExpectedObjectTokenStack.get(mDepth))
837                     + dumpDebugData());
838         }
839     }
840 
841     /**
842      * Check the current wire type to determine if current numeric field is packed. If it is packed,
843      * set up to deal with the field
844      * This should only be called for primitive numeric field types.
845      *
846      * @param fieldId - used to determine what the packed wire type is.
847      */
checkPacked(long fieldId)848     private void checkPacked(long fieldId) throws IOException {
849         if (mWireType == WIRE_TYPE_LENGTH_DELIMITED) {
850             // Primitive Field is length delimited, must be a packed field.
851             final int length = (int) readVarint();
852             mPackedEnd = getOffset() + length;
853             mState |= STATE_READING_PACKED;
854 
855             // Fake the wire type, based on the field type
856             switch ((int) ((fieldId & FIELD_TYPE_MASK)
857                     >>> FIELD_TYPE_SHIFT)) {
858                 case (int) (FIELD_TYPE_FLOAT >>> FIELD_TYPE_SHIFT):
859                 case (int) (FIELD_TYPE_FIXED32 >>> FIELD_TYPE_SHIFT):
860                 case (int) (FIELD_TYPE_SFIXED32 >>> FIELD_TYPE_SHIFT):
861                     if (length % 4 != 0) {
862                         throw new IllegalArgumentException(
863                                 "Requested field id (" + getFieldIdString(fieldId)
864                                         + ") packed length " + length
865                                         + " is not aligned for fixed32"
866                                         + dumpDebugData());
867                     }
868                     mWireType = WIRE_TYPE_FIXED32;
869                     break;
870                 case (int) (FIELD_TYPE_DOUBLE >>> FIELD_TYPE_SHIFT):
871                 case (int) (FIELD_TYPE_FIXED64 >>> FIELD_TYPE_SHIFT):
872                 case (int) (FIELD_TYPE_SFIXED64 >>> FIELD_TYPE_SHIFT):
873                     if (length % 8 != 0) {
874                         throw new IllegalArgumentException(
875                                 "Requested field id (" + getFieldIdString(fieldId)
876                                         + ") packed length " + length
877                                         + " is not aligned for fixed64"
878                                         + dumpDebugData());
879                     }
880                     mWireType = WIRE_TYPE_FIXED64;
881                     break;
882                 case (int) (FIELD_TYPE_SINT32 >>> FIELD_TYPE_SHIFT):
883                 case (int) (FIELD_TYPE_INT32 >>> FIELD_TYPE_SHIFT):
884                 case (int) (FIELD_TYPE_UINT32 >>> FIELD_TYPE_SHIFT):
885                 case (int) (FIELD_TYPE_SINT64 >>> FIELD_TYPE_SHIFT):
886                 case (int) (FIELD_TYPE_INT64 >>> FIELD_TYPE_SHIFT):
887                 case (int) (FIELD_TYPE_UINT64 >>> FIELD_TYPE_SHIFT):
888                 case (int) (FIELD_TYPE_ENUM >>> FIELD_TYPE_SHIFT):
889                 case (int) (FIELD_TYPE_BOOL >>> FIELD_TYPE_SHIFT):
890                     mWireType = WIRE_TYPE_VARINT;
891                     break;
892                 default:
893                     throw new IllegalArgumentException(
894                             "Requested field id (" + getFieldIdString(fieldId)
895                                     + ") is not a packable field"
896                                     + dumpDebugData());
897             }
898         }
899     }
900 
901 
902     /**
903      * Check a field id constant against current field number
904      *
905      * @param fieldId - throws if fieldId does not match mFieldNumber
906      */
assertFieldNumber(long fieldId)907     private void assertFieldNumber(long fieldId) {
908         if ((int) fieldId != mFieldNumber) {
909             throw new IllegalArgumentException("Requested field id (" + getFieldIdString(fieldId)
910                     + ") does not match current field number (0x" + Integer.toHexString(
911                     mFieldNumber)
912                     + ") at offset 0x" + Integer.toHexString(getOffset())
913                     + dumpDebugData());
914         }
915     }
916 
917 
918     /**
919      * Check a wire type against current wire type.
920      *
921      * @param wireType - throws if wireType does not match mWireType.
922      */
assertWireType(int wireType)923     private void assertWireType(int wireType) {
924         if (wireType != mWireType) {
925             throw new WireTypeMismatchException(
926                     "Current wire type " + getWireTypeString(mWireType)
927                             + " does not match expected wire type " + getWireTypeString(wireType)
928                             + " at offset 0x" + Integer.toHexString(getOffset())
929                             + dumpDebugData());
930         }
931     }
932 
933     /**
934      * Check if there is data ready to be read.
935      */
assertFreshData()936     private void assertFreshData() {
937         if ((mState & STATE_STARTED_FIELD_READ) != STATE_STARTED_FIELD_READ) {
938             throw new ProtoParseException(
939                     "Attempting to read already read field at offset 0x" + Integer.toHexString(
940                             getOffset()) + dumpDebugData());
941         }
942     }
943 
944     /**
945      * Dump debugging data about the buffer.
946      */
dumpDebugData()947     public String dumpDebugData() {
948         StringBuilder sb = new StringBuilder();
949 
950         sb.append("\nmFieldNumber : 0x" + Integer.toHexString(mFieldNumber));
951         sb.append("\nmWireType : 0x" + Integer.toHexString(mWireType));
952         sb.append("\nmState : 0x" + Integer.toHexString(mState));
953         sb.append("\nmDiscardedBytes : 0x" + Integer.toHexString(mDiscardedBytes));
954         sb.append("\nmOffset : 0x" + Integer.toHexString(mOffset));
955         sb.append("\nmExpectedObjectTokenStack : ");
956         if (mExpectedObjectTokenStack == null) {
957             sb.append("null");
958         } else {
959             sb.append(mExpectedObjectTokenStack);
960         }
961         sb.append("\nmDepth : 0x" + Integer.toHexString(mDepth));
962         sb.append("\nmBuffer : ");
963         if (mBuffer == null) {
964             sb.append("null");
965         } else {
966             sb.append(mBuffer);
967         }
968         sb.append("\nmBufferSize : 0x" + Integer.toHexString(mBufferSize));
969         sb.append("\nmEnd : 0x" + Integer.toHexString(mEnd));
970 
971         return sb.toString();
972     }
973 }
974