Prechádzať zdrojové kódy

Support reading multiple bytes in escapeProxy

Currently, the escapeProxy works under the assumption that the
underlying reader will always return 1 byte at a time. Even though this
is usually true, it is not always the case, for example when using a pty
and writing multiple bytes to the master before flushing it.

In such cases the proxy reader doesn't work properly. For example with
an escape sequence being `ctrl-p,ctrl-q`, when the underlying reader
returns `ctrl-p,ctrl-q` at once, the escape sequence isn't detected.

This updates the reader to support this use-case and adds unit tests.

Signed-off-by: Bilal Amarni <bilal.amarni@gmail.com>
Bilal Amarni 5 rokov pred
rodič
commit
8dd1490473
2 zmenil súbory, kde vykonal 108 pridanie a 41 odobranie
  1. 43 33
      pkg/term/proxy.go
  2. 65 8
      pkg/term/proxy_test.go

+ 43 - 33
pkg/term/proxy.go

@@ -19,6 +19,7 @@ type escapeProxy struct {
 	escapeKeys   []byte
 	escapeKeyPos int
 	r            io.Reader
+	buf          []byte
 }
 
 // NewEscapeProxy returns a new TTY proxy reader which wraps the given reader
@@ -31,48 +32,57 @@ func NewEscapeProxy(r io.Reader, escapeKeys []byte) io.Reader {
 	}
 }
 
-func (r *escapeProxy) Read(buf []byte) (int, error) {
-	nr, err := r.r.Read(buf)
+func (r *escapeProxy) Read(buf []byte) (n int, err error) {
+	if len(r.escapeKeys) > 0 && r.escapeKeyPos == len(r.escapeKeys) {
+		return 0, EscapeError{}
+	}
 
-	if len(r.escapeKeys) == 0 {
-		return nr, err
+	if len(r.buf) > 0 {
+		n = copy(buf, r.buf)
+		r.buf = r.buf[n:]
 	}
 
-	preserve := func() {
-		// this preserves the original key presses in the passed in buffer
-		nr += r.escapeKeyPos
-		preserve := make([]byte, 0, r.escapeKeyPos+len(buf))
-		preserve = append(preserve, r.escapeKeys[:r.escapeKeyPos]...)
-		preserve = append(preserve, buf...)
-		r.escapeKeyPos = 0
-		copy(buf[0:nr], preserve)
+	nr, err := r.r.Read(buf[n:])
+	n += nr
+	if len(r.escapeKeys) == 0 {
+		return n, err
 	}
 
-	if nr != 1 || err != nil {
-		if r.escapeKeyPos > 0 {
-			preserve()
+	for i := 0; i < n; i++ {
+		if buf[i] == r.escapeKeys[r.escapeKeyPos] {
+			r.escapeKeyPos++
+
+			// Check if the full escape sequence is matched.
+			if r.escapeKeyPos == len(r.escapeKeys) {
+				n = i + 1 - r.escapeKeyPos
+				if n < 0 {
+					n = 0
+				}
+				return n, EscapeError{}
+			}
+			continue
 		}
-		return nr, err
-	}
 
-	if buf[0] != r.escapeKeys[r.escapeKeyPos] {
-		if r.escapeKeyPos > 0 {
-			preserve()
+		// If we need to prepend a partial escape sequence from the previous
+		// read, make sure the new buffer size doesn't exceed len(buf).
+		// Otherwise, preserve any extra data in a buffer for the next read.
+		if i < r.escapeKeyPos {
+			preserve := make([]byte, 0, r.escapeKeyPos+n)
+			preserve = append(preserve, r.escapeKeys[:r.escapeKeyPos]...)
+			preserve = append(preserve, buf[:n]...)
+			n = copy(buf, preserve)
+			i += r.escapeKeyPos
+			r.buf = append(r.buf, preserve[n:]...)
 		}
-		return nr, nil
+		r.escapeKeyPos = 0
 	}
 
-	if r.escapeKeyPos == len(r.escapeKeys)-1 {
-		return 0, EscapeError{}
+	// If we're in the middle of reading an escape sequence, make sure we don't
+	// let the caller read it. If later on we find that this is not the escape
+	// sequence, we'll prepend it back to buf.
+	n -= r.escapeKeyPos
+	if n < 0 {
+		n = 0
 	}
-
-	// Looks like we've got an escape key, but we need to match again on the next
-	// read.
-	// Store the current escape key we found so we can look for the next one on
-	// the next read.
-	// Since this is an escape key, make sure we don't let the caller read it
-	// If later on we find that this is not the escape sequence, we'll add the
-	// keys back
-	r.escapeKeyPos++
-	return nr - r.escapeKeyPos, nil
+	return n, err
 }

+ 65 - 8
pkg/term/proxy_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestEscapeProxyRead(t *testing.T) {
-	t.Run("no escape keys, keys a", func(t *testing.T) {
+	t.Run("no escape keys, keys [a]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("")
 		keys, _ := ToBytes("a")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -21,7 +21,7 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.DeepEqual(t, keys, buf)
 	})
 
-	t.Run("no escape keys, keys a,b,c", func(t *testing.T) {
+	t.Run("no escape keys, keys [a,b,c]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("")
 		keys, _ := ToBytes("a,b,c")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -46,7 +46,7 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.Check(t, is.Len(buf, 0))
 	})
 
-	t.Run("DEL escape key, keys a,b,c,+", func(t *testing.T) {
+	t.Run("DEL escape key, keys [a,b,c,+]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("DEL")
 		keys, _ := ToBytes("a,b,c,+")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -71,7 +71,7 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.Check(t, is.Len(buf, 0))
 	})
 
-	t.Run("ctrl-x,ctrl-@ escape key, keys DEL", func(t *testing.T) {
+	t.Run("ctrl-x,ctrl-@ escape key, keys [DEL]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("ctrl-x,ctrl-@")
 		keys, _ := ToBytes("DEL")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -83,7 +83,7 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.DeepEqual(t, keys, buf)
 	})
 
-	t.Run("ctrl-c escape key, keys ctrl-c", func(t *testing.T) {
+	t.Run("ctrl-c escape key, keys [ctrl-c]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("ctrl-c")
 		keys, _ := ToBytes("ctrl-c")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -95,7 +95,7 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.DeepEqual(t, keys, buf)
 	})
 
-	t.Run("ctrl-c,ctrl-z escape key, keys ctrl-c,ctrl-z", func(t *testing.T) {
+	t.Run("ctrl-c,ctrl-z escape key, keys [ctrl-c],[ctrl-z]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("ctrl-c,ctrl-z")
 		keys, _ := ToBytes("ctrl-c,ctrl-z")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -112,7 +112,19 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.DeepEqual(t, keys[1:], buf)
 	})
 
-	t.Run("ctrl-c,ctrl-z escape key, keys ctrl-c,DEL,+", func(t *testing.T) {
+	t.Run("ctrl-c,ctrl-z escape key, keys [ctrl-c,ctrl-z]", func(t *testing.T) {
+		escapeKeys, _ := ToBytes("ctrl-c,ctrl-z")
+		keys, _ := ToBytes("ctrl-c,ctrl-z")
+		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
+
+		buf := make([]byte, 2)
+		nr, err := reader.Read(buf)
+		assert.Error(t, err, "read escape sequence")
+		assert.Equal(t, nr, 0, "nr should be equal to 0")
+		assert.DeepEqual(t, keys, buf)
+	})
+
+	t.Run("ctrl-c,ctrl-z escape key, keys [ctrl-c],[DEL,+]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("ctrl-c,ctrl-z")
 		keys, _ := ToBytes("ctrl-c,DEL,+")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -130,7 +142,7 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.DeepEqual(t, keys, buf)
 	})
 
-	t.Run("ctrl-c,ctrl-z escape key, keys ctrl-c,DEL", func(t *testing.T) {
+	t.Run("ctrl-c,ctrl-z escape key, keys [ctrl-c],[DEL]", func(t *testing.T) {
 		escapeKeys, _ := ToBytes("ctrl-c,ctrl-z")
 		keys, _ := ToBytes("ctrl-c,DEL")
 		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
@@ -148,4 +160,49 @@ func TestEscapeProxyRead(t *testing.T) {
 		assert.DeepEqual(t, keys, buf)
 	})
 
+	t.Run("a,b,c,d escape key, keys [a,b],[c,d]", func(t *testing.T) {
+		escapeKeys, _ := ToBytes("a,b,c,d")
+		keys, _ := ToBytes("a,b,c,d")
+		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
+
+		buf := make([]byte, 2)
+		nr, err := reader.Read(buf)
+		assert.NilError(t, err)
+		assert.Equal(t, 0, nr)
+		assert.DeepEqual(t, keys[0:2], buf)
+
+		buf = make([]byte, 2)
+		nr, err = reader.Read(buf)
+		assert.Error(t, err, "read escape sequence")
+		assert.Equal(t, 0, nr)
+		assert.DeepEqual(t, keys[2:4], buf)
+	})
+
+	t.Run("ctrl-p,ctrl-q escape key, keys [ctrl-p],[a],[ctrl-p,ctrl-q]", func(t *testing.T) {
+		escapeKeys, _ := ToBytes("ctrl-p,ctrl-q")
+		keys, _ := ToBytes("ctrl-p,a,ctrl-p,ctrl-q")
+		reader := NewEscapeProxy(bytes.NewReader(keys), escapeKeys)
+
+		buf := make([]byte, 1)
+		nr, err := reader.Read(buf)
+		assert.NilError(t, err)
+		assert.Equal(t, 0, nr)
+
+		buf = make([]byte, 1)
+		nr, err = reader.Read(buf)
+		assert.NilError(t, err)
+		assert.Equal(t, 1, nr)
+		assert.DeepEqual(t, keys[:1], buf)
+
+		buf = make([]byte, 2)
+		nr, err = reader.Read(buf)
+		assert.NilError(t, err)
+		assert.Equal(t, 1, nr)
+		assert.DeepEqual(t, keys[1:3], buf)
+
+		buf = make([]byte, 2)
+		nr, err = reader.Read(buf)
+		assert.Error(t, err, "read escape sequence")
+		assert.Equal(t, 0, nr)
+	})
 }