encoder.go

// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"crypto/rand"
	"fmt"
	"io"
	rdebug "runtime/debug"
	"sync"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)

// Encoder provides encoding to Zstandard.
// An Encoder can be used either for compressing a stream via the
// io.WriteCloser interface it supports, or for multiple independent
// compression tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	o        encoderOptions
	encoders chan encoder
	state    encoderState
	init     sync.Once
}

type encoder interface {
	Encode(blk *blockEnc, src []byte)
	EncodeNoHist(blk *blockEnc, src []byte)
	Block() *blockEnc
	CRC() *xxhash.Digest
	AppendCRC([]byte) []byte
	WindowSize(size int) int32
	UseBlock(*blockEnc)
	Reset(d *dict, singleBlock bool)
}

type encoderState struct {
	w                io.Writer
	filling          []byte
	current          []byte
	previous         []byte
	encoder          encoder
	writing          *blockEnc
	err              error
	writeErr         error
	nWritten         int64
	headerWritten    bool
	eofWritten       bool
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}

// NewWriter will create a new Zstandard encoder.
// If the encoder will be used for encoding blocks, a nil writer can be used.
func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
	initPredefined()
	var e Encoder
	e.o.setDefault()
	for _, o := range opts {
		err := o(&e.o)
		if err != nil {
			return nil, err
		}
	}
	if w != nil {
		e.Reset(w)
	}
	return &e, nil
}
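
// A minimal streaming sketch, assuming the package is imported as
// "github.com/klauspost/compress/zstd" and that "io" is imported; the
// function name compressStream and the out/in parameters are illustrative:
//
//	func compressStream(out io.Writer, in io.Reader) error {
//		enc, err := zstd.NewWriter(out)
//		if err != nil {
//			return err
//		}
//		if _, err := io.Copy(enc, in); err != nil {
//			enc.Close()
//			return err
//		}
//		return enc.Close() // flushes the final block and writes the CRC, if enabled
//	}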

func (e *Encoder) initialize() {
	if e.o.concurrent == 0 {
		e.o.setDefault()
	}
	e.encoders = make(chan encoder, e.o.concurrent)
	for i := 0; i < e.o.concurrent; i++ {
		enc := e.o.encoder()
		e.encoders <- enc
	}
}

// Reset will re-initialize the writer and new writes will encode to the supplied writer
// as a new, independent stream.
func (e *Encoder) Reset(w io.Writer) {
	s := &e.state
	s.wg.Wait()
	s.wWg.Wait()
	if cap(s.filling) == 0 {
		s.filling = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.current) == 0 {
		s.current = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.previous) == 0 {
		s.previous = make([]byte, 0, e.o.blockSize)
	}
	if s.encoder == nil {
		s.encoder = e.o.encoder()
	}
	if s.writing == nil {
		s.writing = &blockEnc{lowMem: e.o.lowMem}
		s.writing.init()
	}
	s.writing.initNewEncode()
	s.filling = s.filling[:0]
	s.current = s.current[:0]
	s.previous = s.previous[:0]
	s.encoder.Reset(e.o.dict, false)
	s.headerWritten = false
	s.eofWritten = false
	s.fullFrameWritten = false
	s.w = w
	s.err = nil
	s.nWritten = 0
	s.writeErr = nil
}
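
// A sketch of reusing one Encoder for several independent streams via Reset,
// assuming "io" and the zstd package are imported; compressMany and the
// outs/ins parameters are hypothetical names:
//
//	func compressMany(outs []io.Writer, ins []io.Reader) error {
//		enc, err := zstd.NewWriter(nil)
//		if err != nil {
//			return err
//		}
//		for i, out := range outs {
//			enc.Reset(out) // start a new, independent frame on out
//			if _, err := io.Copy(enc, ins[i]); err != nil {
//				return err
//			}
//			if err := enc.Close(); err != nil { // flushes; the Encoder stays reusable
//				return err
//			}
//		}
//		return nil
//	}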

// Write data to the encoder.
// Input data will be buffered and, as the buffer fills up,
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
	s := &e.state
	for len(p) > 0 {
		if len(p)+len(s.filling) < e.o.blockSize {
			if e.o.crc {
				_, _ = s.encoder.CRC().Write(p)
			}
			s.filling = append(s.filling, p...)
			return n + len(p), nil
		}
		add := p
		if len(p)+len(s.filling) > e.o.blockSize {
			add = add[:e.o.blockSize-len(s.filling)]
		}
		if e.o.crc {
			_, _ = s.encoder.CRC().Write(add)
		}
		s.filling = append(s.filling, add...)
		p = p[len(add):]
		n += len(add)
		if len(s.filling) < e.o.blockSize {
			return n, nil
		}
		err := e.nextBlock(false)
		if err != nil {
			return n, err
		}
		if debugAsserts && len(s.filling) > 0 {
			panic(len(s.filling))
		}
	}
	return n, nil
}
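
// A small sketch of writing in chunks; writeChunks and the chunks parameter
// are illustrative names, and the zstd package import is assumed. Writes
// smaller than the internal block size are only buffered, so data is not
// guaranteed to reach the underlying writer until Flush or Close:
//
//	func writeChunks(out io.Writer, chunks [][]byte) error {
//		enc, err := zstd.NewWriter(out)
//		if err != nil {
//			return err
//		}
//		for _, c := range chunks {
//			if _, err := enc.Write(c); err != nil {
//				return err
//			}
//		}
//		return enc.Close() // flush remaining buffered data and the CRC, if enabled
//	}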

// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding, it will be returned.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   0,
			WindowSize:    uint32(s.encoder.WindowSize(0)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.wg.Add(1)
	go func(src []byte) {
		if debug {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have to deal with changing them.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}

// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read is also returned.
//
// The Copy function uses ReaderFrom if available.
func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
	if debug {
		println("Using ReadFrom")
	}
	// Flush any current writes.
	if len(e.state.filling) > 0 {
		if err := e.nextBlock(false); err != nil {
			return 0, err
		}
	}
	e.state.filling = e.state.filling[:e.o.blockSize]
	src := e.state.filling
	for {
		n2, err := r.Read(src)
		if e.o.crc {
			_, _ = e.state.encoder.CRC().Write(src[:n2])
		}
		// src is now the unfilled part...
		src = src[n2:]
		n += int64(n2)
		switch err {
		case io.EOF:
			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
			if debug {
				println("ReadFrom: got EOF final block:", len(e.state.filling))
			}
			return n, nil
		case nil:
		default:
			if debug {
				println("ReadFrom: got error:", err)
			}
			e.state.err = err
			return n, err
		}
		if len(src) > 0 {
			if debug {
				println("ReadFrom: got space left in source:", len(src))
			}
			continue
		}
		err = e.nextBlock(false)
		if err != nil {
			return n, err
		}
		e.state.filling = e.state.filling[:e.o.blockSize]
		src = e.state.filling
	}
}
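
// A sketch of pulling data straight from a reader; io.Copy(enc, in) takes
// this ReadFrom path automatically when available, but it can also be called
// directly. compressFrom and its parameters are illustrative names:
//
//	func compressFrom(out io.Writer, in io.Reader) error {
//		enc, err := zstd.NewWriter(out)
//		if err != nil {
//			return err
//		}
//		if _, err := enc.ReadFrom(in); err != nil {
//			return err
//		}
//		return enc.Close()
//	}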

// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
func (e *Encoder) Flush() error {
	s := &e.state
	if len(s.filling) > 0 {
		err := e.nextBlock(false)
		if err != nil {
			return err
		}
	}
	s.wg.Wait()
	s.wWg.Wait()
	if s.err != nil {
		return s.err
	}
	return s.writeErr
}
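
// A sketch of flushing mid-stream, e.g. after each message on a long-lived
// connection; sendMessage and msg are assumed names. Each Flush forces a
// block boundary, so frequent flushes trade compression ratio for latency:
//
//	func sendMessage(enc *zstd.Encoder, msg []byte) error {
//		if _, err := enc.Write(msg); err != nil {
//			return err
//		}
//		return enc.Flush() // push the queued data to the underlying writer now
//	}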

// Close will flush the final output and close the stream.
// The function will block until everything has been written.
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
	s := &e.state
	if s.encoder == nil {
		return nil
	}
	err := e.nextBlock(true)
	if err != nil {
		return err
	}
	if e.state.fullFrameWritten {
		return s.err
	}
	s.wg.Wait()
	s.wWg.Wait()

	if s.err != nil {
		return s.err
	}
	if s.writeErr != nil {
		return s.writeErr
	}

	// Write CRC
	if e.o.crc && s.err == nil {
		// heap alloc.
		var tmp [4]byte
		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
		s.nWritten += 4
	}

	// Add padding with content from crypto/rand.Reader
	if s.err == nil && e.o.pad > 0 {
		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
		if err != nil {
			return err
		}
		_, s.err = s.w.Write(frame)
	}
	return s.err
}

// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(len(src))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= maxCompressedBlockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}

		switch err {
		case errIncompressible:
			if debug {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}

			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
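
// A sketch of using EncodeAll for small, independent payloads, reusing the
// destination buffer between calls to limit allocations; compressPayloads is
// a hypothetical name and the zstd package import is assumed:
//
//	func compressPayloads(payloads [][]byte) ([][]byte, error) {
//		enc, err := zstd.NewWriter(nil) // a nil writer is fine for EncodeAll-only use
//		if err != nil {
//			return nil, err
//		}
//		defer enc.Close()
//		frames := make([][]byte, 0, len(payloads))
//		var buf []byte
//		for _, p := range payloads {
//			buf = enc.EncodeAll(p, buf[:0]) // reuse buf as the append target
//			frames = append(frames, append([]byte(nil), buf...)) // copy, since buf is reused
//		}
//		return frames, nil
//	}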