1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.apksig.internal.asn1.ber;
18 
19 import java.io.ByteArrayOutputStream;
20 import java.io.IOException;
21 import java.io.InputStream;
22 import java.nio.ByteBuffer;
23 
24 /**
25  * {@link BerDataValueReader} which reads from an {@link InputStream} returning BER-encoded data
26  * values. See {@code X.690} for the encoding.
27  */
28 public class InputStreamBerDataValueReader implements BerDataValueReader {
29     private final InputStream mIn;
30 
InputStreamBerDataValueReader(InputStream in)31     public InputStreamBerDataValueReader(InputStream in) {
32         if (in == null) {
33             throw new NullPointerException("in == null");
34         }
35         mIn = in;
36     }
37 
38     @Override
readDataValue()39     public BerDataValue readDataValue() throws BerDataValueFormatException {
40         return readDataValue(mIn);
41     }
42 
43     /**
44      * Returns the next data value or {@code null} if end of input has been reached.
45      *
46      * @throws BerDataValueFormatException if the value being read is malformed.
47      */
48     @SuppressWarnings("resource")
readDataValue(InputStream input)49     private static BerDataValue readDataValue(InputStream input)
50             throws BerDataValueFormatException {
51         RecordingInputStream in = new RecordingInputStream(input);
52 
53         try {
54             int firstIdentifierByte = in.read();
55             if (firstIdentifierByte == -1) {
56                 // End of input
57                 return null;
58             }
59             int tagNumber = readTagNumber(in, firstIdentifierByte);
60 
61             int firstLengthByte = in.read();
62             if (firstLengthByte == -1) {
63                 throw new BerDataValueFormatException("Missing length");
64             }
65 
66             boolean constructed = BerEncoding.isConstructed((byte) firstIdentifierByte);
67             int contentsLength;
68             int contentsOffsetInDataValue;
69             if ((firstLengthByte & 0x80) == 0) {
70                 // short form length
71                 contentsLength = readShortFormLength(firstLengthByte);
72                 contentsOffsetInDataValue = in.getReadByteCount();
73                 skipDefiniteLengthContents(in, contentsLength);
74             } else if ((firstLengthByte & 0xff) != 0x80) {
75                 // long form length
76                 contentsLength = readLongFormLength(in, firstLengthByte);
77                 contentsOffsetInDataValue = in.getReadByteCount();
78                 skipDefiniteLengthContents(in, contentsLength);
79             } else {
80                 // indefinite length
81                 contentsOffsetInDataValue = in.getReadByteCount();
82                 contentsLength =
83                         constructed
84                                 ? skipConstructedIndefiniteLengthContents(in)
85                                 : skipPrimitiveIndefiniteLengthContents(in);
86             }
87 
88             byte[] encoded = in.getReadBytes();
89             ByteBuffer encodedContents =
90                     ByteBuffer.wrap(encoded, contentsOffsetInDataValue, contentsLength);
91             return new BerDataValue(
92                     ByteBuffer.wrap(encoded),
93                     encodedContents,
94                     BerEncoding.getTagClass((byte) firstIdentifierByte),
95                     constructed,
96                     tagNumber);
97         } catch (IOException e) {
98             throw new BerDataValueFormatException("Failed to read data value", e);
99         }
100     }
101 
readTagNumber(InputStream in, int firstIdentifierByte)102     private static int readTagNumber(InputStream in, int firstIdentifierByte)
103             throws IOException, BerDataValueFormatException {
104         int tagNumber = BerEncoding.getTagNumber((byte) firstIdentifierByte);
105         if (tagNumber == 0x1f) {
106             // high-tag-number form
107             return readHighTagNumber(in);
108         } else {
109             // low-tag-number form
110             return tagNumber;
111         }
112     }
113 
readHighTagNumber(InputStream in)114     private static int readHighTagNumber(InputStream in)
115             throws IOException, BerDataValueFormatException {
116         // Base-128 big-endian form, where each byte has the highest bit set, except for the last
117         // byte where the highest bit is not set
118         int b;
119         int result = 0;
120         do {
121             b = in.read();
122             if (b == -1) {
123                 throw new BerDataValueFormatException("Truncated tag number");
124             }
125             if (result > Integer.MAX_VALUE >>> 7) {
126                 throw new BerDataValueFormatException("Tag number too large");
127             }
128             result <<= 7;
129             result |= b & 0x7f;
130         } while ((b & 0x80) != 0);
131         return result;
132     }
133 
readShortFormLength(int firstLengthByte)134     private static int readShortFormLength(int firstLengthByte) {
135         return firstLengthByte & 0x7f;
136     }
137 
readLongFormLength(InputStream in, int firstLengthByte)138     private static int readLongFormLength(InputStream in, int firstLengthByte)
139             throws IOException, BerDataValueFormatException {
140         // The low 7 bits of the first byte represent the number of bytes (following the first
141         // byte) in which the length is in big-endian base-256 form
142         int byteCount = firstLengthByte & 0x7f;
143         if (byteCount > 4) {
144             throw new BerDataValueFormatException("Length too large: " + byteCount + " bytes");
145         }
146         int result = 0;
147         for (int i = 0; i < byteCount; i++) {
148             int b = in.read();
149             if (b == -1) {
150                 throw new BerDataValueFormatException("Truncated length");
151             }
152             if (result > Integer.MAX_VALUE >>> 8) {
153                 throw new BerDataValueFormatException("Length too large");
154             }
155             result <<= 8;
156             result |= b & 0xff;
157         }
158         return result;
159     }
160 
skipDefiniteLengthContents(InputStream in, int len)161     private static void skipDefiniteLengthContents(InputStream in, int len)
162             throws IOException, BerDataValueFormatException {
163         long bytesRead = 0;
164         while (len > 0) {
165             int skipped = (int) in.skip(len);
166             if (skipped <= 0) {
167                 throw new BerDataValueFormatException(
168                         "Truncated definite-length contents: " + bytesRead + " bytes read"
169                                 + ", " + len + " missing");
170             }
171             len -= skipped;
172             bytesRead += skipped;
173         }
174     }
175 
skipPrimitiveIndefiniteLengthContents(InputStream in)176     private static int skipPrimitiveIndefiniteLengthContents(InputStream in)
177             throws IOException, BerDataValueFormatException {
178         // Contents are terminated by 0x00 0x00
179         boolean prevZeroByte = false;
180         int bytesRead = 0;
181         while (true) {
182             int b = in.read();
183             if (b == -1) {
184                 throw new BerDataValueFormatException(
185                         "Truncated indefinite-length contents: " + bytesRead + " bytes read");
186             }
187             bytesRead++;
188             if (bytesRead < 0) {
189                 throw new BerDataValueFormatException("Indefinite-length contents too long");
190             }
191             if (b == 0) {
192                 if (prevZeroByte) {
193                     // End of contents reached -- we've read the value and its terminator 0x00 0x00
194                     return bytesRead - 2;
195                 }
196                 prevZeroByte = true;
197                 continue;
198             } else {
199                 prevZeroByte = false;
200             }
201         }
202     }
203 
skipConstructedIndefiniteLengthContents(RecordingInputStream in)204     private static int skipConstructedIndefiniteLengthContents(RecordingInputStream in)
205             throws BerDataValueFormatException {
206         // Contents are terminated by 0x00 0x00. However, this data value is constructed, meaning it
207         // can contain data values which are indefinite length encoded as well. As a result, we
208         // must parse the direct children of this data value to correctly skip over the contents of
209         // this data value.
210         int readByteCountBefore = in.getReadByteCount();
211         while (true) {
212             // We can't easily peek for the 0x00 0x00 terminator using the provided InputStream.
213             // Thus, we use the fact that 0x00 0x00 parses as a data value whose encoded form we
214             // then check below to see whether it's 0x00 0x00.
215             BerDataValue dataValue = readDataValue(in);
216             if (dataValue == null) {
217                 throw new BerDataValueFormatException(
218                         "Truncated indefinite-length contents: "
219                                 + (in.getReadByteCount() - readByteCountBefore) + " bytes read");
220             }
221             if (in.getReadByteCount() <= 0) {
222                 throw new BerDataValueFormatException("Indefinite-length contents too long");
223             }
224             ByteBuffer encoded = dataValue.getEncoded();
225             if ((encoded.remaining() == 2) && (encoded.get(0) == 0) && (encoded.get(1) == 0)) {
226                 // 0x00 0x00 encountered
227                 return in.getReadByteCount() - readByteCountBefore - 2;
228             }
229         }
230     }
231 
232     private static class RecordingInputStream extends InputStream {
233         private final InputStream mIn;
234         private final ByteArrayOutputStream mBuf;
235 
RecordingInputStream(InputStream in)236         private RecordingInputStream(InputStream in) {
237             mIn = in;
238             mBuf = new ByteArrayOutputStream();
239         }
240 
getReadBytes()241         public byte[] getReadBytes() {
242             return mBuf.toByteArray();
243         }
244 
getReadByteCount()245         public int getReadByteCount() {
246             return mBuf.size();
247         }
248 
249         @Override
read()250         public int read() throws IOException {
251             int b = mIn.read();
252             if (b != -1) {
253                 mBuf.write(b);
254             }
255             return b;
256         }
257 
258         @Override
read(byte[] b)259         public int read(byte[] b) throws IOException {
260             int len = mIn.read(b);
261             if (len > 0) {
262                 mBuf.write(b, 0, len);
263             }
264             return len;
265         }
266 
267         @Override
read(byte[] b, int off, int len)268         public int read(byte[] b, int off, int len) throws IOException {
269             len = mIn.read(b, off, len);
270             if (len > 0) {
271                 mBuf.write(b, off, len);
272             }
273             return len;
274         }
275 
276         @Override
skip(long n)277         public long skip(long n) throws IOException {
278             if (n <= 0) {
279                 return mIn.skip(n);
280             }
281 
282             byte[] buf = new byte[4096];
283             int len = mIn.read(buf, 0, (int) Math.min(buf.length, n));
284             if (len > 0) {
285                 mBuf.write(buf, 0, len);
286             }
287             return (len < 0) ? 0 : len;
288         }
289 
290         @Override
available()291         public int available() throws IOException {
292             return super.available();
293         }
294 
295         @Override
close()296         public void close() throws IOException {
297             super.close();
298         }
299 
300         @Override
mark(int readlimit)301         public synchronized void mark(int readlimit) {}
302 
303         @Override
reset()304         public synchronized void reset() throws IOException {
305             throw new IOException("mark/reset not supported");
306         }
307 
308         @Override
markSupported()309         public boolean markSupported() {
310             return false;
311         }
312     }
313 }
314