framedec.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "bytes"
  7. "encoding/hex"
  8. "errors"
  9. "hash"
  10. "io"
  11. "sync"
  12. "github.com/klauspost/compress/zstd/internal/xxhash"
  13. )
// frameDec holds the parsed header fields and the decoding state of a
// single zstd frame.
type frameDec struct {
	// o holds the decoder options supplied to newFrameDec.
	o decoderOptions
	// crc accumulates the checksum of the decoded output.
	// It is created by reset the first time a frame declares a checksum
	// and is reset for every such frame.
	crc hash.Hash64
	// offset counts the bytes emitted by startDecoder.
	// It is only advanced when debug is enabled.
	offset int64

	// WindowSize is the window size of the current frame, as computed by reset.
	WindowSize uint64

	// maxWindowSize is the maximum window size to support.
	// should never be bigger than max-int.
	maxWindowSize uint64

	// In order queue of blocks being decoded.
	decoding chan *blockDec

	// Frame history passed between blocks
	history history

	// rawInput is the input the frame is read from; it is set by reset.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// FrameContentSize is the Frame_Content_Size read from the header.
	// It is 0 when the header carries no such field.
	FrameContentSize uint64
	// frameDone is signalled (Done) by startDecoder when it returns.
	frameDone sync.WaitGroup

	// DictionaryID is the Dictionary_ID from the header.
	// It is nil when the field is absent or zero.
	DictionaryID *uint32
	// HasCheckSum reports whether the frame header announces a content checksum.
	HasCheckSum bool
	// SingleSegment reports whether the Single_Segment flag is set in the header.
	SingleSegment bool

	// asyncRunningMu guards asyncRunning.
	asyncRunningMu sync.Mutex
	// asyncRunning indicates whether the async routine processes input on 'decoding'.
	asyncRunning bool
}
  38. const (
  39. // The minimum Window_Size is 1 KB.
  40. MinWindowSize = 1 << 10
  41. MaxWindowSize = 1 << 29
  42. )
  43. var (
  44. frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}
  45. skippableFrameMagic = []byte{0x2a, 0x4d, 0x18}
  46. )
  47. func newFrameDec(o decoderOptions) *frameDec {
  48. d := frameDec{
  49. o: o,
  50. maxWindowSize: MaxWindowSize,
  51. }
  52. if d.maxWindowSize > o.maxDecodedSize {
  53. d.maxWindowSize = o.maxDecodedSize
  54. }
  55. return &d
  56. }
  57. // reset will read the frame header and prepare for block decoding.
  58. // If nothing can be read from the input, io.EOF will be returned.
  59. // Any other error indicated that the stream contained data, but
  60. // there was a problem.
  61. func (d *frameDec) reset(br byteBuffer) error {
  62. d.HasCheckSum = false
  63. d.WindowSize = 0
  64. var b []byte
  65. for {
  66. var err error
  67. b, err = br.readSmall(4)
  68. switch err {
  69. case io.EOF, io.ErrUnexpectedEOF:
  70. return io.EOF
  71. default:
  72. return err
  73. case nil:
  74. }
  75. if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
  76. if debug {
  77. println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
  78. }
  79. // Break if not skippable frame.
  80. break
  81. }
  82. // Read size to skip
  83. b, err = br.readSmall(4)
  84. if err != nil {
  85. println("Reading Frame Size", err)
  86. return err
  87. }
  88. n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  89. println("Skipping frame with", n, "bytes.")
  90. err = br.skipN(int(n))
  91. if err != nil {
  92. if debug {
  93. println("Reading discarded frame", err)
  94. }
  95. return err
  96. }
  97. }
  98. if !bytes.Equal(b, frameMagic) {
  99. println("Got magic numbers: ", b, "want:", frameMagic)
  100. return ErrMagicMismatch
  101. }
  102. // Read Frame_Header_Descriptor
  103. fhd, err := br.readByte()
  104. if err != nil {
  105. println("Reading Frame_Header_Descriptor", err)
  106. return err
  107. }
  108. d.SingleSegment = fhd&(1<<5) != 0
  109. if fhd&(1<<3) != 0 {
  110. return errors.New("reserved bit set on frame header")
  111. }
  112. // Read Window_Descriptor
  113. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
  114. d.WindowSize = 0
  115. if !d.SingleSegment {
  116. wd, err := br.readByte()
  117. if err != nil {
  118. println("Reading Window_Descriptor", err)
  119. return err
  120. }
  121. printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
  122. windowLog := 10 + (wd >> 3)
  123. windowBase := uint64(1) << windowLog
  124. windowAdd := (windowBase / 8) * uint64(wd&0x7)
  125. d.WindowSize = windowBase + windowAdd
  126. }
  127. // Read Dictionary_ID
  128. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
  129. d.DictionaryID = nil
  130. if size := fhd & 3; size != 0 {
  131. if size == 3 {
  132. size = 4
  133. }
  134. b, err = br.readSmall(int(size))
  135. if err != nil {
  136. println("Reading Dictionary_ID", err)
  137. return err
  138. }
  139. var id uint32
  140. switch size {
  141. case 1:
  142. id = uint32(b[0])
  143. case 2:
  144. id = uint32(b[0]) | (uint32(b[1]) << 8)
  145. case 4:
  146. id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  147. }
  148. if debug {
  149. println("Dict size", size, "ID:", id)
  150. }
  151. if id > 0 {
  152. // ID 0 means "sorry, no dictionary anyway".
  153. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
  154. d.DictionaryID = &id
  155. }
  156. }
  157. // Read Frame_Content_Size
  158. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
  159. var fcsSize int
  160. v := fhd >> 6
  161. switch v {
  162. case 0:
  163. if d.SingleSegment {
  164. fcsSize = 1
  165. }
  166. default:
  167. fcsSize = 1 << v
  168. }
  169. d.FrameContentSize = 0
  170. if fcsSize > 0 {
  171. b, err = br.readSmall(fcsSize)
  172. if err != nil {
  173. println("Reading Frame content", err)
  174. return err
  175. }
  176. switch fcsSize {
  177. case 1:
  178. d.FrameContentSize = uint64(b[0])
  179. case 2:
  180. // When FCS_Field_Size is 2, the offset of 256 is added.
  181. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
  182. case 4:
  183. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
  184. case 8:
  185. d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  186. d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
  187. d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
  188. }
  189. if debug {
  190. println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
  191. }
  192. }
  193. // Move this to shared.
  194. d.HasCheckSum = fhd&(1<<2) != 0
  195. if d.HasCheckSum {
  196. if d.crc == nil {
  197. d.crc = xxhash.New()
  198. }
  199. d.crc.Reset()
  200. }
  201. if d.WindowSize == 0 && d.SingleSegment {
  202. // We may not need window in this case.
  203. d.WindowSize = d.FrameContentSize
  204. if d.WindowSize < MinWindowSize {
  205. d.WindowSize = MinWindowSize
  206. }
  207. }
  208. if d.WindowSize > d.maxWindowSize {
  209. printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
  210. return ErrWindowSizeExceeded
  211. }
  212. // The minimum Window_Size is 1 KB.
  213. if d.WindowSize < MinWindowSize {
  214. println("got window size: ", d.WindowSize)
  215. return ErrWindowSizeTooSmall
  216. }
  217. d.history.windowSize = int(d.WindowSize)
  218. if d.o.lowMem && d.history.windowSize < maxBlockSize {
  219. d.history.maxSize = d.history.windowSize * 2
  220. } else {
  221. d.history.maxSize = d.history.windowSize + maxBlockSize
  222. }
  223. // history contains input - maybe we do something
  224. d.rawInput = br
  225. return nil
  226. }
  227. // next will start decoding the next block from stream.
  228. func (d *frameDec) next(block *blockDec) error {
  229. if debug {
  230. printf("decoding new block %p:%p", block, block.data)
  231. }
  232. err := block.reset(d.rawInput, d.WindowSize)
  233. if err != nil {
  234. println("block error:", err)
  235. // Signal the frame decoder we have a problem.
  236. d.sendErr(block, err)
  237. return err
  238. }
  239. block.input <- struct{}{}
  240. if debug {
  241. println("next block:", block)
  242. }
  243. d.asyncRunningMu.Lock()
  244. defer d.asyncRunningMu.Unlock()
  245. if !d.asyncRunning {
  246. return nil
  247. }
  248. if block.Last {
  249. // We indicate the frame is done by sending io.EOF
  250. d.decoding <- block
  251. return io.EOF
  252. }
  253. d.decoding <- block
  254. return nil
  255. }
  256. // sendEOF will queue an error block on the frame.
  257. // This will cause the frame decoder to return when it encounters the block.
  258. // Returns true if the decoder was added.
  259. func (d *frameDec) sendErr(block *blockDec, err error) bool {
  260. d.asyncRunningMu.Lock()
  261. defer d.asyncRunningMu.Unlock()
  262. if !d.asyncRunning {
  263. return false
  264. }
  265. println("sending error", err.Error())
  266. block.sendErr(err)
  267. d.decoding <- block
  268. return true
  269. }
  270. // checkCRC will check the checksum if the frame has one.
  271. // Will return ErrCRCMismatch if crc check failed, otherwise nil.
  272. func (d *frameDec) checkCRC() error {
  273. if !d.HasCheckSum {
  274. return nil
  275. }
  276. var tmp [4]byte
  277. got := d.crc.Sum64()
  278. // Flip to match file order.
  279. tmp[0] = byte(got >> 0)
  280. tmp[1] = byte(got >> 8)
  281. tmp[2] = byte(got >> 16)
  282. tmp[3] = byte(got >> 24)
  283. // We can overwrite upper tmp now
  284. want, err := d.rawInput.readSmall(4)
  285. if err != nil {
  286. println("CRC missing?", err)
  287. return err
  288. }
  289. if !bytes.Equal(tmp[:], want) {
  290. if debug {
  291. println("CRC Check Failed:", tmp[:], "!=", want)
  292. }
  293. return ErrCRCMismatch
  294. }
  295. if debug {
  296. println("CRC ok", tmp[:])
  297. }
  298. return nil
  299. }
  300. func (d *frameDec) initAsync() {
  301. if !d.o.lowMem && !d.SingleSegment {
  302. // set max extra size history to 10MB.
  303. d.history.maxSize = d.history.windowSize + maxBlockSize*5
  304. }
  305. // re-alloc if more than one extra block size.
  306. if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
  307. d.history.b = make([]byte, 0, d.history.maxSize)
  308. }
  309. if cap(d.history.b) < d.history.maxSize {
  310. d.history.b = make([]byte, 0, d.history.maxSize)
  311. }
  312. if cap(d.decoding) < d.o.concurrent {
  313. d.decoding = make(chan *blockDec, d.o.concurrent)
  314. }
  315. if debug {
  316. h := d.history
  317. printf("history init. len: %d, cap: %d", len(h.b), cap(h.b))
  318. }
  319. d.asyncRunningMu.Lock()
  320. d.asyncRunning = true
  321. d.asyncRunningMu.Unlock()
  322. }
// startDecoder will start decoding blocks and send the results on output.
// The decoder will stop as soon as an error occurs or at end of frame.
// Blocks are taken from d.decoding in order; each block is handed the
// frame history and its result is forwarded to output.
// When the function returns, asyncRunning is cleared, any blocks still
// queued on d.decoding are drained and d.frameDone.Done() is called.
func (d *frameDec) startDecoder(output chan decodeOutput) {
	// written counts the decoded bytes of this frame.
	written := int64(0)

	defer func() {
		d.asyncRunningMu.Lock()
		d.asyncRunning = false
		d.asyncRunningMu.Unlock()

		// Drain the currently decoding.
		// The history is flagged as errored before it is handed to the
		// remaining blocks.
		d.history.error = true
	flushdone:
		for {
			select {
			case b := <-d.decoding:
				b.history <- &d.history
				output <- <-b.result
			default:
				// Queue is empty, nothing more to flush.
				break flushdone
			}
		}
		println("frame decoder done, signalling done")
		d.frameDone.Done()
	}()
	// Get decoder for first block.
	block := <-d.decoding
	block.history <- &d.history
	for {
		var next *blockDec
		// Get result
		r := <-block.result
		if r.err != nil {
			println("Result contained error", r.err)
			output <- r
			return
		}
		if debug {
			println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
			d.offset += int64(len(r.b))
		}
		if !block.Last {
			// Send history to next block
			// This is a non-blocking attempt; if no block is queued yet
			// we retry after the current result has been sent.
			select {
			case next = <-d.decoding:
				if debug {
					println("Sending ", len(d.history.b), "bytes as history")
				}
				next.history <- &d.history
			default:
				// Wait until we have sent the block, so
				// other decoders can potentially get the decoder.
				next = nil
			}
		}

		// Add checksum, async to decoding.
		if d.HasCheckSum {
			n, err := d.crc.Write(r.b)
			if err != nil {
				r.err = err
				if n != len(r.b) {
					r.err = io.ErrShortWrite
				}
				output <- r
				return
			}
		}
		written += int64(len(r.b))
		// A single segment frame must not produce more than its declared size.
		if d.SingleSegment && uint64(written) > d.FrameContentSize {
			println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize)
			r.err = ErrFrameSizeExceeded
			output <- r
			return
		}
		if block.Last {
			// End of frame: the checksum result travels with the last output.
			r.err = d.checkCRC()
			output <- r
			return
		}
		output <- r
		if next == nil {
			// There was no decoder available, we wait for one now that we have sent to the writer.
			if debug {
				println("Sending ", len(d.history.b), " bytes as history")
			}
			next = <-d.decoding
			next.history <- &d.history
		}
		block = next
	}
}
  414. // runDecoder will create a sync decoder that will decode a block of data.
  415. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
  416. saved := d.history.b
  417. // We use the history for output to avoid copying it.
  418. d.history.b = dst
  419. // Store input length, so we only check new data.
  420. crcStart := len(dst)
  421. var err error
  422. for {
  423. err = dec.reset(d.rawInput, d.WindowSize)
  424. if err != nil {
  425. break
  426. }
  427. if debug {
  428. println("next block:", dec)
  429. }
  430. err = dec.decodeBuf(&d.history)
  431. if err != nil || dec.Last {
  432. break
  433. }
  434. if uint64(len(d.history.b)) > d.o.maxDecodedSize {
  435. err = ErrDecoderSizeExceeded
  436. break
  437. }
  438. if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
  439. println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize)
  440. err = ErrFrameSizeExceeded
  441. break
  442. }
  443. }
  444. dst = d.history.b
  445. if err == nil {
  446. if d.HasCheckSum {
  447. var n int
  448. n, err = d.crc.Write(dst[crcStart:])
  449. if err == nil {
  450. if n != len(dst)-crcStart {
  451. err = io.ErrShortWrite
  452. } else {
  453. err = d.checkCRC()
  454. }
  455. }
  456. }
  457. }
  458. d.history.b = saved
  459. return dst, err
  460. }