1 /*
2  * Copyright 2018 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package trebuchet.extractors
18 
19 import trebuchet.importers.ImportFeedback
20 import trebuchet.io.*
21 import trebuchet.util.indexOf
22 import java.util.zip.DataFormatException
23 import java.util.zip.Inflater
24 import kotlin.sequences.iterator
25 
/** Marker that may precede the compressed payload. */
private const val TRACE = "TRACE:"

/**
 * Returns the offset in [buffer] at which the payload begins.
 *
 * If `buffer.indexOf(TRACE, 100)` finds the [TRACE] marker, the payload starts immediately after
 * it; otherwise it starts at offset 0. Any run of `\n` / `\r` bytes at that position is skipped.
 * NOTE(review): the `100` presumably bounds how far into the buffer the marker is searched for —
 * confirm against `trebuchet.util.indexOf`.
 */
private fun findStart(buffer: GenericByteBuffer): Long {
    val lineFeed = '\n'.toByte()
    val carriageReturn = '\r'.toByte()
    fun isLineBreakAt(index: Long): Boolean {
        val b = buffer[index]
        return b == lineFeed || b == carriageReturn
    }

    val markerAt = buffer.indexOf(TRACE, 100)
    var pos = when (markerAt) {
        -1L -> 0L
        else -> markerAt + TRACE.length
    }
    while (pos < buffer.length && isLineBreakAt(pos)) {
        pos++
    }
    return pos
}
41 
/**
 * [BufferProducer] that lazily inflates the zlib-compressed payload of a [StreamingReader].
 *
 * The payload start is located with [findStart]. Each call to [next] inflates until one non-empty
 * output slice is available, or returns null when there is no more output. Decompression stops at
 * the end of the zlib stream (any trailing bytes are ignored), when the input is exhausted, or after
 * [close]. A stream that requires a preset dictionary is reported to [feedback] and ends the output.
 * A [java.util.zip.DataFormatException] thrown by [Inflater.inflate] on corrupt input propagates
 * out of [next].
 */
private class DeflateProducer(stream: StreamingReader, val feedback: ImportFeedback)
        : BufferProducer {

    private val source = stream.source
    private val inflater = Inflater()
    private var closed = false

    private val sourceIterator = iterator {
        try {
            // close() may already have ended the inflater before iteration started.
            if (closed) return@iterator
            // NOTE(review): presumably ensures the head of the stream is buffered so findStart can
            // search for the marker — confirm against StreamingReader.loadIndex.
            stream.loadIndex(stream.startIndex + 1024)
            val offset = findStart(stream)
            val buffIter = stream.iter(offset)
            // Running estimate of output bytes per consumed input byte, used to size output arrays.
            var avgCompressFactor = 5.0
            while (!inflater.finished() && buffIter.hasNext()) {
                val nextBuffer = buffIter.next()
                inflater.setInput(nextBuffer.buffer, nextBuffer.startIndex, nextBuffer.length)
                do {
                    val remaining = inflater.remaining
                    // Clamp so a collapsed estimate can't yield an empty array (no progress) and a
                    // blown-up one can't saturate to Int.MAX_VALUE (out of memory).
                    val estSize = (remaining * avgCompressFactor * 1.2).toInt()
                            .coerceIn(MIN_CHUNK_SIZE, MAX_CHUNK_SIZE)
                    val array = ByteArray(estSize)
                    val len = inflater.inflate(array)
                    if (inflater.needsDictionary()) {
                        feedback.reportImportException(IllegalStateException(
                                "inflater needs dictionary, which isn't supported"))
                        return@iterator
                    }
                    val consumed = remaining - inflater.remaining
                    // Output can be produced without consuming input; skip the update rather than
                    // divide by zero and poison the average with Infinity/NaN.
                    if (consumed > 0) {
                        avgCompressFactor = (avgCompressFactor * 9 + len.toDouble() / consumed) / 10
                    }
                    if (len > 0) {
                        yield(array.asSlice(len))
                        // close() may have been called while suspended; the inflater is then ended.
                        if (closed) return@iterator
                    }
                    // Continue while input remains, or while the output array came back full (the
                    // inflater may still hold pending output). Stop at end of stream, otherwise
                    // trailing bytes would make inflate() return 0 forever.
                } while (!inflater.finished() && (!inflater.needsInput() || len == array.size))
            }
        } finally {
            // Release native zlib state once iteration completes or returns early.
            // close() also calls end(); calling it on an already-ended Inflater is harmless.
            inflater.end()
        }
    }

    /** Returns the next inflated slice, or null when decompression has finished. */
    override fun next(): DataSlice? {
        return if (sourceIterator.hasNext()) sourceIterator.next() else null
    }

    /** Stops decompression, closes the underlying source and ends the inflater. */
    override fun close() {
        closed = true
        source.close()
        inflater.end()
    }

    private companion object {
        /** Smallest output array allocated per inflate() call. */
        const val MIN_CHUNK_SIZE = 4 * 1024

        /** Largest output array allocated per inflate() call. */
        const val MAX_CHUNK_SIZE = 4 * 1024 * 1024
    }
}
86 
/**
 * [Extractor] for zlib-compressed input, optionally preceded by a `TRACE:` marker and line breaks
 * (see [findStart]). Decompressed data is delivered to the caller through a [DeflateProducer].
 */
class ZlibExtractor(val feedback: ImportFeedback) : Extractor {

    /** Wraps [stream] in a decompressing [BufferProducer] and hands it to [processSubStream]. */
    override fun extract(stream: StreamingReader, processSubStream: (BufferProducer) -> Unit) {
        processSubStream(DeflateProducer(stream, feedback))
    }

    /** Detects zlib input by trial-inflating a small prefix of the payload. */
    object Factory : ExtractorFactory {
        /** Maximum number of payload bytes fed to the probe inflater. */
        private const val SIZE_TO_CHECK = 200

        /** Size of the scratch output array used by the probe inflate. */
        private const val PROBE_OUTPUT_SIZE = 1024

        /**
         * Returns a [ZlibExtractor] if inflating up to [SIZE_TO_CHECK] bytes after the payload start
         * produces any output, or null if the data is too short or is not valid zlib data.
         */
        override fun extractorFor(buffer: GenericByteBuffer, feedback: ImportFeedback): Extractor? {
            val start = findStart(buffer)
            // Clamp in Long space before narrowing, so buffers of 2 GiB or more can't overflow
            // into a wrong (possibly negative) length.
            val toRead = minOf(buffer.length - start, SIZE_TO_CHECK.toLong()).toInt()
            // A zlib stream needs at least a 2 byte header + 4 byte checksum, so with 6 bytes or
            // fewer this either isn't zlib or there isn't enough data to inflate anything anyway.
            if (toRead <= 6) {
                return null
            }
            val probe = Inflater()
            return try {
                val sample = ByteArray(toRead) { buffer[start + it] }
                probe.setInput(sample)
                val inflated = probe.inflate(ByteArray(PROBE_OUTPUT_SIZE))
                if (inflated > 0) ZlibExtractor(feedback) else null
            } catch (ex: DataFormatException) {
                // Must not be zlib/deflate format.
                null
            } finally {
                // Single release point for the probe's native zlib state.
                probe.end()
            }
        }
    }
}