seqdec.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. )
  // seq describes one sequence: litLen literal bytes followed by a match.
// matchLen is stored without the minimum match length (String adds
// zstdMinMatch back), and offset uses the repeat-code encoding: 0 is invalid,
// 1-3 refer to repeat offsets, and larger values are the real offset plus 3.
type seq struct {
	litLen   uint32
	matchLen uint32
	offset   uint32

	// Symbol codes cached for the encoder so each is looked up only once.
	llCode uint8
	mlCode uint8
	ofCode uint8
}
  // seqVals is a decoded sequence held as plain ints: the literal length (ll),
// the match length (ml) and the match offset (mo), which execute measures
// back from the current output position.
type seqVals struct {
	ll int
	ml int
	mo int
}
  21. func (s seq) String() string {
  22. if s.offset <= 3 {
  23. if s.offset == 0 {
  24. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
  25. }
  26. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
  27. }
  28. return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
  29. }
  // seqCompMode selects how a sequence decoding table is obtained.
// NOTE(review): the values presumably mirror the zstd Symbol_Compression_Modes
// field (RFC 8878) - confirm before relying on that mapping.
type seqCompMode uint8

// The four compression modes, with their numeric values fixed at 0 through 3.
const (
	compModePredefined seqCompMode = 0
	compModeRLE        seqCompMode = 1
	compModeFSE        seqCompMode = 2
	compModeRepeat     seqCompMode = 3
)
  37. type sequenceDec struct {
  38. // decoder keeps track of the current state and updates it from the bitstream.
  39. fse *fseDecoder
  40. state fseState
  41. repeat bool
  42. }
  43. // init the state of the decoder with input from stream.
  44. func (s *sequenceDec) init(br *bitReader) error {
  45. if s.fse == nil {
  46. return errors.New("sequence decoder not defined")
  47. }
  48. s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
  49. return nil
  50. }
  51. // sequenceDecs contains all 3 sequence decoders and their state.
  52. type sequenceDecs struct {
  53. litLengths sequenceDec
  54. offsets sequenceDec
  55. matchLengths sequenceDec
  56. prevOffset [3]int
  57. dict []byte
  58. literals []byte
  59. out []byte
  60. nSeqs int
  61. br *bitReader
  62. seqSize int
  63. windowSize int
  64. maxBits uint8
  65. maxSyncLen uint64
  66. }
  67. // initialize all 3 decoders from the stream input.
  68. func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) error {
  69. if err := s.litLengths.init(br); err != nil {
  70. return errors.New("litLengths:" + err.Error())
  71. }
  72. if err := s.offsets.init(br); err != nil {
  73. return errors.New("offsets:" + err.Error())
  74. }
  75. if err := s.matchLengths.init(br); err != nil {
  76. return errors.New("matchLengths:" + err.Error())
  77. }
  78. s.br = br
  79. s.prevOffset = hist.recentOffsets
  80. s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
  81. s.windowSize = hist.windowSize
  82. s.out = out
  83. s.dict = nil
  84. if hist.dict != nil {
  85. s.dict = hist.dict.content
  86. }
  87. return nil
  88. }
  89. func (s *sequenceDecs) freeDecoders() {
  90. if f := s.litLengths.fse; f != nil && !f.preDefined {
  91. fseDecoderPool.Put(f)
  92. s.litLengths.fse = nil
  93. }
  94. if f := s.offsets.fse; f != nil && !f.preDefined {
  95. fseDecoderPool.Put(f)
  96. s.offsets.fse = nil
  97. }
  98. if f := s.matchLengths.fse; f != nil && !f.preDefined {
  99. fseDecoderPool.Put(f)
  100. s.matchLengths.fse = nil
  101. }
  102. }
  103. // execute will execute the decoded sequence with the provided history.
  104. // The sequence must be evaluated before being sent.
  105. func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
  106. if len(s.dict) == 0 {
  107. return s.executeSimple(seqs, hist)
  108. }
  109. // Ensure we have enough output size...
  110. if len(s.out)+s.seqSize > cap(s.out) {
  111. addBytes := s.seqSize + len(s.out)
  112. s.out = append(s.out, make([]byte, addBytes)...)
  113. s.out = s.out[:len(s.out)-addBytes]
  114. }
  115. if debugDecoder {
  116. printf("Execute %d seqs with hist %d, dict %d, literals: %d into %d bytes\n", len(seqs), len(hist), len(s.dict), len(s.literals), s.seqSize)
  117. }
  118. var t = len(s.out)
  119. out := s.out[:t+s.seqSize]
  120. for _, seq := range seqs {
  121. // Add literals
  122. copy(out[t:], s.literals[:seq.ll])
  123. t += seq.ll
  124. s.literals = s.literals[seq.ll:]
  125. // Copy from dictionary...
  126. if seq.mo > t+len(hist) || seq.mo > s.windowSize {
  127. if len(s.dict) == 0 {
  128. return fmt.Errorf("match offset (%d) bigger than current history (%d)", seq.mo, t+len(hist))
  129. }
  130. // we may be in dictionary.
  131. dictO := len(s.dict) - (seq.mo - (t + len(hist)))
  132. if dictO < 0 || dictO >= len(s.dict) {
  133. return fmt.Errorf("match offset (%d) bigger than current history+dict (%d)", seq.mo, t+len(hist)+len(s.dict))
  134. }
  135. end := dictO + seq.ml
  136. if end > len(s.dict) {
  137. n := len(s.dict) - dictO
  138. copy(out[t:], s.dict[dictO:])
  139. t += n
  140. seq.ml -= n
  141. } else {
  142. copy(out[t:], s.dict[dictO:end])
  143. t += end - dictO
  144. continue
  145. }
  146. }
  147. // Copy from history.
  148. if v := seq.mo - t; v > 0 {
  149. // v is the start position in history from end.
  150. start := len(hist) - v
  151. if seq.ml > v {
  152. // Some goes into current block.
  153. // Copy remainder of history
  154. copy(out[t:], hist[start:])
  155. t += v
  156. seq.ml -= v
  157. } else {
  158. copy(out[t:], hist[start:start+seq.ml])
  159. t += seq.ml
  160. continue
  161. }
  162. }
  163. // We must be in current buffer now
  164. if seq.ml > 0 {
  165. start := t - seq.mo
  166. if seq.ml <= t-start {
  167. // No overlap
  168. copy(out[t:], out[start:start+seq.ml])
  169. t += seq.ml
  170. continue
  171. } else {
  172. // Overlapping copy
  173. // Extend destination slice and copy one byte at the time.
  174. src := out[start : start+seq.ml]
  175. dst := out[t:]
  176. dst = dst[:len(src)]
  177. t += len(src)
  178. // Destination is the space we just added.
  179. for i := range src {
  180. dst[i] = src[i]
  181. }
  182. }
  183. }
  184. }
  185. // Add final literals
  186. copy(out[t:], s.literals)
  187. if debugDecoder {
  188. t += len(s.literals)
  189. if t != len(out) {
  190. panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize))
  191. }
  192. }
  193. s.out = out
  194. return nil
  195. }
  // decodeSync decodes s.nSeqs sequences from s.br and executes each one as soon
// as it is decoded, appending literals and match bytes to s.out.
//
// hist holds the bytes that logically precede s.out. A match offset that
// reaches back past the start of out is served from the end of hist. One that
// reaches past hist as well (or exceeds s.windowSize) is served from the tail
// of s.dict, if a dictionary is set.
//
// If s.decodeSyncSimple reports the block as supported, its error is returned
// unchanged and the generic loop below is not run.
//
// Errors:
//   - io.ErrUnexpectedEOF when the bit reader has overread its input.
//   - A formatted error for: a literal count larger than the remaining
//     literals; a match length above maxMatchLen; a zero offset with a
//     non-zero match length; an offset outside history and dictionary; or
//     block output larger than the smaller of maxCompressedBlockSize and
//     s.windowSize.
//
// Otherwise the result of br.close() is returned.
func (s *sequenceDecs) decodeSync(hist []byte) error {
	supported, err := s.decodeSyncSimple(hist)
	if supported {
		return err
	}

	br := s.br
	seqs := s.nSeqs
	// Output of this block starts at startSize; size limits are measured from here.
	startSize := len(s.out)
	// Grab full sizes tables, to avoid bounds checks.
	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
	// Local copies of the three FSE states, advanced at the end of each iteration.
	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
	out := s.out
	// Output for this block is capped at the smaller of maxCompressedBlockSize
	// and the window size.
	maxBlockSize := maxCompressedBlockSize
	if s.windowSize < maxBlockSize {
		maxBlockSize = s.windowSize
	}

	// Sequences are counted down so that i == 0 is the last one.
	for i := seqs - 1; i >= 0; i-- {
		if br.overread() {
			printf("reading sequence %d, exceeded available data\n", seqs-i)
			return io.ErrUnexpectedEOF
		}
		var ll, mo, ml int
		// Fast path using fillFast.
		// NOTE(review): presumably br.off counts input bytes not yet loaded, so
		// above this threshold the fillFast calls cannot run out of input -
		// confirm against bitReader.
		if br.off > 4+((maxOffsetBits+16+16)>>3) {
			// inlined function:
			// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)

			// Final will not read from stream.
			var llB, mlB, moB uint8
			ll, llB = llState.final()
			ml, mlB = mlState.final()
			mo, moB = ofState.final()

			// extra bits are stored in reverse order.
			br.fillFast()
			mo += br.getBits(moB)
			if s.maxBits > 32 {
				br.fillFast()
			}
			ml += br.getBits(mlB)
			ll += br.getBits(llB)

			if moB > 1 {
				// mo is used as-is and pushed onto the recent-offset
				// history (same as adjustOffset for offsetB > 1).
				s.prevOffset[2] = s.prevOffset[1]
				s.prevOffset[1] = s.prevOffset[0]
				s.prevOffset[0] = mo
			} else {
				// mo = s.adjustOffset(mo, ll, moB)
				// Inlined for rather big speedup
				if ll == 0 {
					// There is an exception though, when current sequence's literals_length = 0.
					// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
					// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
					mo++
				}

				if mo == 0 {
					mo = s.prevOffset[0]
				} else {
					var temp int
					if mo == 3 {
						temp = s.prevOffset[0] - 1
					} else {
						temp = s.prevOffset[mo]
					}

					if temp == 0 {
						// 0 is not valid; input is corrupted; force offset to 1
						println("WARNING: temp was 0")
						temp = 1
					}

					if mo != 1 {
						s.prevOffset[2] = s.prevOffset[1]
					}
					s.prevOffset[1] = s.prevOffset[0]
					s.prevOffset[0] = temp
					mo = temp
				}
			}
			br.fillFast()
		} else {
			// Slow path: decode via s.next, refilling with br.fill rather
			// than br.fillFast.
			ll, mo, ml = s.next(br, llState, mlState, ofState)
			br.fill()
		}

		if debugSequences {
			println("Seq", seqs-i-1, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
		}

		// Reject sequences that use more literals than remain, or whose
		// output would exceed the block limit.
		if ll > len(s.literals) {
			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, len(s.literals))
		}
		size := ll + ml + len(out)
		if size-startSize > maxBlockSize {
			return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
		}
		if size > cap(out) {
			// Not enough size, which can happen under high volume block streaming conditions
			// but could be if destination slice is too small for sync operations.
			// over-allocating here can create a large amount of GC pressure so we try to keep
			// it as contained as possible
			used := len(out) - startSize
			addBytes := 256 + ll + ml + used>>2
			// Clamp to max block size.
			if used+addBytes > maxBlockSize {
				addBytes = maxBlockSize - used
			}
			out = append(out, make([]byte, addBytes)...)
			out = out[:len(out)-addBytes]
		}
		if ml > maxMatchLen {
			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
		}

		// Add literals
		out = append(out, s.literals[:ll]...)
		s.literals = s.literals[ll:]

		if mo == 0 && ml > 0 {
			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
		}

		if mo > len(out)+len(hist) || mo > s.windowSize {
			if len(s.dict) == 0 {
				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize)
			}

			// we may be in dictionary.
			// dictO is where the match starts in s.dict: the part of mo that
			// reaches past out and hist, counted back from the end of dict.
			dictO := len(s.dict) - (mo - (len(out) + len(hist)))
			if dictO < 0 || dictO >= len(s.dict) {
				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize)
			}
			end := dictO + ml
			if end > len(s.dict) {
				// Take the dictionary tail; the remaining ml bytes are
				// copied by the history/current-buffer steps below.
				out = append(out, s.dict[dictO:]...)
				ml -= len(s.dict) - dictO
			} else {
				// Whole match served from the dictionary; skip the steps below.
				out = append(out, s.dict[dictO:end]...)
				mo = 0
				ml = 0
			}
		}

		// Copy from history.
		// TODO: Blocks without history could be made to ignore this completely.
		if v := mo - len(out); v > 0 {
			// v is the start position in history from end.
			start := len(hist) - v
			if ml > v {
				// Some goes into current block.
				// Copy remainder of history
				out = append(out, hist[start:]...)
				ml -= v
			} else {
				out = append(out, hist[start:start+ml]...)
				ml = 0
			}
		}
		// We must be in current buffer now
		if ml > 0 {
			start := len(out) - mo
			if ml <= len(out)-start {
				// No overlap
				out = append(out, out[start:start+ml]...)
			} else {
				// Overlapping copy
				// Extend destination slice and copy one byte at the time.
				// Relies on the capacity check earlier in this iteration.
				out = out[:len(out)+ml]
				src := out[start : start+ml]
				// Destination is the space we just added.
				dst := out[len(out)-ml:]
				dst = dst[:len(src)]
				for i := range src {
					dst[i] = src[i]
				}
			}
		}
		if i == 0 {
			// This is the last sequence, so we shouldn't update state.
			break
		}

		// Manually inlined, ~ 5-20% faster
		// Update all 3 states at once. Approx 20% faster.
		// bits packs the state-update bits of all three decoders, most
		// significant first: literal length, match length, offset.
		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
		if nBits == 0 {
			llState = llTable[llState.newState()&maxTableMask]
			mlState = mlTable[mlState.newState()&maxTableMask]
			ofState = ofTable[ofState.newState()&maxTableMask]
		} else {
			bits := br.get32BitsFast(nBits)

			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
			llState = llTable[(llState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits >> (ofState.nbBits() & 31))
			lowBits &= bitMask[mlState.nbBits()&15]
			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
		}
	}

	// The trailing literals must also fit within the block limit.
	if size := len(s.literals) + len(out) - startSize; size > maxBlockSize {
		return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
	}

	// Add final literals
	s.out = append(out, s.literals...)
	return br.close()
}
  // bitMask[n] keeps the low n bits of a value: bitMask[n] == (1<<n)-1.
// decodeSync uses it to isolate each decoder's state-update bits.
var bitMask = [16]uint16{
	0x0000, 0x0001, 0x0003, 0x0007,
	0x000f, 0x001f, 0x003f, 0x007f,
	0x00ff, 0x01ff, 0x03ff, 0x07ff,
	0x0fff, 0x1fff, 0x3fff, 0x7fff,
}
  396. func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
  397. // Final will not read from stream.
  398. ll, llB := llState.final()
  399. ml, mlB := mlState.final()
  400. mo, moB := ofState.final()
  401. // extra bits are stored in reverse order.
  402. br.fill()
  403. if s.maxBits <= 32 {
  404. mo += br.getBits(moB)
  405. ml += br.getBits(mlB)
  406. ll += br.getBits(llB)
  407. } else {
  408. mo += br.getBits(moB)
  409. br.fill()
  410. // matchlength+literal length, max 32 bits
  411. ml += br.getBits(mlB)
  412. ll += br.getBits(llB)
  413. }
  414. mo = s.adjustOffset(mo, ll, moB)
  415. return
  416. }
  417. func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
  418. if offsetB > 1 {
  419. s.prevOffset[2] = s.prevOffset[1]
  420. s.prevOffset[1] = s.prevOffset[0]
  421. s.prevOffset[0] = offset
  422. return offset
  423. }
  424. if litLen == 0 {
  425. // There is an exception though, when current sequence's literals_length = 0.
  426. // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
  427. // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
  428. offset++
  429. }
  430. if offset == 0 {
  431. return s.prevOffset[0]
  432. }
  433. var temp int
  434. if offset == 3 {
  435. temp = s.prevOffset[0] - 1
  436. } else {
  437. temp = s.prevOffset[offset]
  438. }
  439. if temp == 0 {
  440. // 0 is not valid; input is corrupted; force offset to 1
  441. println("temp was 0")
  442. temp = 1
  443. }
  444. if offset != 1 {
  445. s.prevOffset[2] = s.prevOffset[1]
  446. }
  447. s.prevOffset[1] = s.prevOffset[0]
  448. s.prevOffset[0] = temp
  449. return temp
  450. }