Forráskód Böngészése

daemon/logger: read the length header correctly

Before this change, if Decode() couldn't read a log record fully,
the subsequent invocation of Decode() would read the record's non-header part
as a header and cause a huge heap allocation.

This change prevents such a case by having the intermediate buffer in
the decoder struct.

Fixes #42125.

Signed-off-by: Kazuyoshi Kato <katokazu@amazon.com>
Kazuyoshi Kato 3 éve
szülő
commit
48d387a757

+ 71 - 48
daemon/logger/local/read.go

@@ -1,12 +1,12 @@
 package local
 
 import (
+	"bytes"
 	"context"
 	"encoding/binary"
+	"fmt"
 	"io"
 
-	"bytes"
-
 	"github.com/docker/docker/api/types/plugins/logdriver"
 	"github.com/docker/docker/daemon/logger"
 	"github.com/docker/docker/daemon/logger/loggerutils"
@@ -14,6 +14,10 @@ import (
 	"github.com/pkg/errors"
 )
 
+// maxMsgLen is the maximum size of the logger.Message after serialization.
+// logger.defaultBufSize caps the size of Line field.
+const maxMsgLen int = 1e6 // 1MB.
+
 func (d *driver) ReadLogs(config logger.ReadConfig) *logger.LogWatcher {
 	logWatcher := logger.NewLogWatcher()
 
@@ -99,7 +103,35 @@ func getTailReader(ctx context.Context, r loggerutils.SizeReaderAt, req int) (io
 type decoder struct {
 	rdr   io.Reader
 	proto *logdriver.LogEntry
-	buf   []byte
+	// buf keeps bytes from rdr.
+	buf []byte
+	// offset is the position in buf.
+	// If offset > 0, buf[offset:] has bytes which are read but haven't used.
+	offset int
+	// nextMsgLen is the length of the next log message.
+	// If nextMsgLen = 0, a new value must be read from rdr.
+	nextMsgLen int
+}
+
+func (d *decoder) readRecord(size int) error {
+	var err error
+	for i := 0; i < maxDecodeRetry; i++ {
+		var n int
+		n, err = io.ReadFull(d.rdr, d.buf[d.offset:size])
+		d.offset += n
+		if err != nil {
+			if err != io.ErrUnexpectedEOF {
+				return err
+			}
+			continue
+		}
+		break
+	}
+	if err != nil {
+		return err
+	}
+	d.offset = 0
+	return nil
 }
 
 func (d *decoder) Decode() (*logger.Message, error) {
@@ -111,44 +143,35 @@ func (d *decoder) Decode() (*logger.Message, error) {
 	if d.buf == nil {
 		d.buf = make([]byte, initialBufSize)
 	}
-	var (
-		read int
-		err  error
-	)
 
-	for i := 0; i < maxDecodeRetry; i++ {
-		var n int
-		n, err = io.ReadFull(d.rdr, d.buf[read:encodeBinaryLen])
+	if d.nextMsgLen == 0 {
+		msgLen, err := d.decodeSizeHeader()
 		if err != nil {
-			if err != io.ErrUnexpectedEOF {
-				return nil, errors.Wrap(err, "error reading log message length")
-			}
-			read += n
-			continue
+			return nil, err
 		}
-		read += n
-		break
-	}
-	if err != nil {
-		return nil, errors.Wrapf(err, "could not read log message length: read: %d, expected: %d", read, encodeBinaryLen)
-	}
 
-	msgLen := int(binary.BigEndian.Uint32(d.buf[:read]))
+		if msgLen > maxMsgLen {
+			return nil, fmt.Errorf("log message is too large (%d > %d)", msgLen, maxMsgLen)
+		}
 
-	if len(d.buf) < msgLen+encodeBinaryLen {
-		d.buf = make([]byte, msgLen+encodeBinaryLen)
-	} else {
-		if msgLen <= initialBufSize {
+		if len(d.buf) < msgLen+encodeBinaryLen {
+			d.buf = make([]byte, msgLen+encodeBinaryLen)
+		} else if msgLen <= initialBufSize {
 			d.buf = d.buf[:initialBufSize]
 		} else {
 			d.buf = d.buf[:msgLen+encodeBinaryLen]
 		}
-	}
 
-	return decodeLogEntry(d.rdr, d.proto, d.buf, msgLen)
+		d.nextMsgLen = msgLen
+	}
+	return d.decodeLogEntry()
 }
 
 func (d *decoder) Reset(rdr io.Reader) {
+	if d.rdr == rdr {
+		return
+	}
+
 	d.rdr = rdr
 	if d.proto != nil {
 		resetProto(d.proto)
@@ -156,6 +179,8 @@ func (d *decoder) Reset(rdr io.Reader) {
 	if d.buf != nil {
 		d.buf = d.buf[:initialBufSize]
 	}
+	d.offset = 0
+	d.nextMsgLen = 0
 }
 
 func (d *decoder) Close() {
@@ -171,34 +196,32 @@ func decodeFunc(rdr io.Reader) loggerutils.Decoder {
 	return &decoder{rdr: rdr}
 }
 
-func decodeLogEntry(rdr io.Reader, proto *logdriver.LogEntry, buf []byte, msgLen int) (*logger.Message, error) {
-	var (
-		read int
-		err  error
-	)
-	for i := 0; i < maxDecodeRetry; i++ {
-		var n int
-		n, err = io.ReadFull(rdr, buf[read:msgLen+encodeBinaryLen])
-		if err != nil {
-			if err != io.ErrUnexpectedEOF {
-				return nil, errors.Wrap(err, "could not decode log entry")
-			}
-			read += n
-			continue
-		}
-		break
+func (d *decoder) decodeSizeHeader() (int, error) {
+	err := d.readRecord(encodeBinaryLen)
+	if err != nil {
+		return 0, errors.Wrap(err, "could not read a size header")
 	}
+
+	msgLen := int(binary.BigEndian.Uint32(d.buf[:encodeBinaryLen]))
+	return msgLen, nil
+}
+
+func (d *decoder) decodeLogEntry() (*logger.Message, error) {
+	msgLen := d.nextMsgLen
+	err := d.readRecord(msgLen + encodeBinaryLen)
 	if err != nil {
-		return nil, errors.Wrapf(err, "could not decode entry: read %d, expected: %d", read, msgLen)
+		return nil, errors.Wrapf(err, "could not read a log entry (size=%d+%d)", msgLen, encodeBinaryLen)
 	}
+	d.nextMsgLen = 0
 
-	if err := proto.Unmarshal(buf[:msgLen]); err != nil {
-		return nil, errors.Wrap(err, "error unmarshalling log entry")
+	if err := d.proto.Unmarshal(d.buf[:msgLen]); err != nil {
+		return nil, errors.Wrapf(err, "error unmarshalling log entry (size=%d)", msgLen)
 	}
 
-	msg := protoToMessage(proto)
+	msg := protoToMessage(d.proto)
 	if msg.PLogMetaData == nil {
 		msg.Line = append(msg.Line, '\n')
 	}
+
 	return msg, nil
 }

+ 52 - 0
daemon/logger/local/read_test.go

@@ -0,0 +1,52 @@
+package local
+
+import (
+	"io"
+	"io/ioutil"
+	"os"
+	"testing"
+
+	"github.com/docker/docker/daemon/logger"
+	"github.com/pkg/errors"
+	"gotest.tools/v3/assert"
+)
+
+func TestDecode(t *testing.T) {
+	marshal := makeMarshaller()
+
+	buf, err := marshal(&logger.Message{Line: []byte("hello")})
+	assert.NilError(t, err)
+
+	for i := 0; i < len(buf); i++ {
+		testDecode(t, buf, i)
+	}
+}
+
+func testDecode(t *testing.T, buf []byte, split int) {
+	fw, err := ioutil.TempFile("", t.Name())
+	assert.NilError(t, err)
+	defer os.Remove(fw.Name())
+
+	fr, err := os.Open(fw.Name())
+	assert.NilError(t, err)
+
+	d := &decoder{rdr: fr}
+
+	if split > 0 {
+		_, err = fw.Write(buf[0:split])
+		assert.NilError(t, err)
+
+		_, err = d.Decode()
+		assert.Assert(t, errors.Is(err, io.EOF))
+
+		_, err = fw.Write(buf[split:])
+		assert.NilError(t, err)
+	} else {
+		_, err = fw.Write(buf)
+		assert.NilError(t, err)
+	}
+
+	message, err := d.Decode()
+	assert.NilError(t, err)
+	assert.Equal(t, "hello\n", string(message.Line))
+}

+ 1 - 0
daemon/logger/loggerutils/logfile.go

@@ -715,6 +715,7 @@ func followLogs(f *os.File, logWatcher *logger.LogWatcher, notifyRotate, notifyE
 			defer func() { oldSize = size }()
 			if size < oldSize { // truncated
 				f.Seek(0, 0)
+				dec.Reset(f)
 				return nil
 			}
 		} else {