浏览代码

Improve GetTimestamp parsing

`GetTimestamp()` "assumed" values it could not parse
to be a valid unix timestamp, and would use invalid
values ("hello world") as-is (even testing that
it did so).

This patch validates unix timestamp to be a valid
numeric value, and makes other values invalid.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn 7 年之前
父节点
当前提交
48cfe3f087

+ 9 - 2
api/types/time/timestamp.go

@@ -82,11 +82,14 @@ func GetTimestamp(value string, reference time.Time) (string, error) {
 	}
 
 	if err != nil {
-		// if there is a `-` then it's an RFC3339 like timestamp otherwise assume unixtimestamp
+		// if there is a `-` then it's an RFC3339 like timestamp
 		if strings.Contains(value, "-") {
 			return "", err // was probably an RFC3339 like timestamp but the parser failed with an error
 		}
-		return value, nil // unixtimestamp in and out case (meaning: the value passed at the command line is already in the right format for passing to the server)
+		if _, _, err := parseTimestamp(value); err != nil {
+			return "", fmt.Errorf("failed to parse value as time or duration: %q", value)
+		}
+		return value, nil // unix timestamp in and out case (meaning: the value passed at the command line is already in the right format for passing to the server)
 	}
 
 	return fmt.Sprintf("%d.%09d", t.Unix(), int64(t.Nanosecond())), nil
@@ -104,6 +107,10 @@ func ParseTimestamps(value string, def int64) (int64, int64, error) {
 	if value == "" {
 		return def, 0, nil
 	}
+	return parseTimestamp(value)
+}
+
+func parseTimestamp(value string) (int64, int64, error) {
 	sa := strings.SplitN(value, ".", 2)
 	s, err := strconv.ParseInt(sa[0], 10, 64)
 	if err != nil {

+ 2 - 2
api/types/time/timestamp_test.go

@@ -49,8 +49,8 @@ func TestGetTimestamp(t *testing.T) {
 		{"1.5h", fmt.Sprintf("%d", now.Add(-90*time.Minute).Unix()), false},
 		{"1h30m", fmt.Sprintf("%d", now.Add(-90*time.Minute).Unix()), false},
 
-		// String fallback
-		{"invalid", "invalid", false},
+		{"invalid", "", true},
+		{"", "", true},
 	}
 
 	for _, c := range cases {

+ 3 - 2
client/container_logs.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/docker/docker/api/types"
 	timetypes "github.com/docker/docker/api/types/time"
+	"github.com/pkg/errors"
 )
 
 // ContainerLogs returns the logs generated by a container in an io.ReadCloser.
@@ -45,7 +46,7 @@ func (cli *Client) ContainerLogs(ctx context.Context, container string, options
 	if options.Since != "" {
 		ts, err := timetypes.GetTimestamp(options.Since, time.Now())
 		if err != nil {
-			return nil, err
+			return nil, errors.Wrap(err, `invalid value for "since"`)
 		}
 		query.Set("since", ts)
 	}
@@ -53,7 +54,7 @@ func (cli *Client) ContainerLogs(ctx context.Context, container string, options
 	if options.Until != "" {
 		ts, err := timetypes.GetTimestamp(options.Until, time.Now())
 		if err != nil {
-			return nil, err
+			return nil, errors.Wrap(err, `invalid value for "until"`)
 		}
 		query.Set("until", ts)
 	}

+ 34 - 25
client/container_logs_test.go

@@ -2,6 +2,7 @@ package client // import "github.com/docker/docker/client"
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -12,10 +13,9 @@ import (
 	"testing"
 	"time"
 
-	"context"
-
 	"github.com/docker/docker/api/types"
-	"github.com/docker/docker/internal/testutil"
+	"github.com/gotestyourself/gotestyourself/assert"
+	is "github.com/gotestyourself/gotestyourself/assert/cmp"
 )
 
 func TestContainerLogsNotFoundError(t *testing.T) {
@@ -33,17 +33,15 @@ func TestContainerLogsError(t *testing.T) {
 		client: newMockClient(errorMock(http.StatusInternalServerError, "Server error")),
 	}
 	_, err := client.ContainerLogs(context.Background(), "container_id", types.ContainerLogsOptions{})
-	if err == nil || err.Error() != "Error response from daemon: Server error" {
-		t.Fatalf("expected a Server Error, got %v", err)
-	}
+	assert.Check(t, is.Error(err, "Error response from daemon: Server error"))
 	_, err = client.ContainerLogs(context.Background(), "container_id", types.ContainerLogsOptions{
 		Since: "2006-01-02TZ",
 	})
-	testutil.ErrorContains(t, err, `parsing time "2006-01-02TZ"`)
+	assert.Check(t, is.ErrorContains(err, `parsing time "2006-01-02TZ"`))
 	_, err = client.ContainerLogs(context.Background(), "container_id", types.ContainerLogsOptions{
 		Until: "2006-01-02TZ",
 	})
-	testutil.ErrorContains(t, err, `parsing time "2006-01-02TZ"`)
+	assert.Check(t, is.ErrorContains(err, `parsing time "2006-01-02TZ"`))
 }
 
 func TestContainerLogs(t *testing.T) {
@@ -51,6 +49,7 @@ func TestContainerLogs(t *testing.T) {
 	cases := []struct {
 		options             types.ContainerLogsOptions
 		expectedQueryParams map[string]string
+		expectedError       string
 	}{
 		{
 			expectedQueryParams: map[string]string{
@@ -84,32 +83,44 @@ func TestContainerLogs(t *testing.T) {
 		},
 		{
 			options: types.ContainerLogsOptions{
-				// An complete invalid date, timestamp or go duration will be
-				// passed as is
-				Since: "invalid but valid",
+				// timestamp will be passed as is
+				Since: "1136073600.000000001",
 			},
 			expectedQueryParams: map[string]string{
 				"tail":  "",
-				"since": "invalid but valid",
+				"since": "1136073600.000000001",
 			},
 		},
 		{
 			options: types.ContainerLogsOptions{
-				// An complete invalid date, timestamp or go duration will be
-				// passed as is
-				Until: "invalid but valid",
+				// timestamp will be passed as is
+				Until: "1136073600.000000001",
 			},
 			expectedQueryParams: map[string]string{
 				"tail":  "",
-				"until": "invalid but valid",
+				"until": "1136073600.000000001",
+			},
+		},
+		{
+			options: types.ContainerLogsOptions{
+				// An complete invalid date will not be passed
+				Since: "invalid value",
 			},
+			expectedError: `invalid value for "since": failed to parse value as time or duration: "invalid value"`,
+		},
+		{
+			options: types.ContainerLogsOptions{
+				// An complete invalid date will not be passed
+				Until: "invalid value",
+			},
+			expectedError: `invalid value for "until": failed to parse value as time or duration: "invalid value"`,
 		},
 	}
 	for _, logCase := range cases {
 		client := &Client{
 			client: newMockClient(func(r *http.Request) (*http.Response, error) {
 				if !strings.HasPrefix(r.URL.Path, expectedURL) {
-					return nil, fmt.Errorf("Expected URL '%s', got '%s'", expectedURL, r.URL)
+					return nil, fmt.Errorf("expected URL '%s', got '%s'", expectedURL, r.URL)
 				}
 				// Check query parameters
 				query := r.URL.Query()
@@ -126,17 +137,15 @@ func TestContainerLogs(t *testing.T) {
 			}),
 		}
 		body, err := client.ContainerLogs(context.Background(), "container_id", logCase.options)
-		if err != nil {
-			t.Fatal(err)
+		if logCase.expectedError != "" {
+			assert.Check(t, is.Error(err, logCase.expectedError))
+			continue
 		}
+		assert.NilError(t, err)
 		defer body.Close()
 		content, err := ioutil.ReadAll(body)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if string(content) != "response" {
-			t.Fatalf("expected response to contain 'response', got %s", string(content))
-		}
+		assert.NilError(t, err)
+		assert.Check(t, is.Contains(string(content), "response"))
 	}
 }
 

+ 2 - 1
client/service_logs.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/docker/docker/api/types"
 	timetypes "github.com/docker/docker/api/types/time"
+	"github.com/pkg/errors"
 )
 
 // ServiceLogs returns the logs generated by a service in an io.ReadCloser.
@@ -25,7 +26,7 @@ func (cli *Client) ServiceLogs(ctx context.Context, serviceID string, options ty
 	if options.Since != "" {
 		ts, err := timetypes.GetTimestamp(options.Since, time.Now())
 		if err != nil {
-			return nil, err
+			return nil, errors.Wrap(err, `invalid value for "since"`)
 		}
 		query.Set("since", ts)
 	}

+ 23 - 21
client/service_logs_test.go

@@ -2,6 +2,7 @@ package client // import "github.com/docker/docker/client"
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -12,9 +13,9 @@ import (
 	"testing"
 	"time"
 
-	"context"
-
 	"github.com/docker/docker/api/types"
+	"github.com/gotestyourself/gotestyourself/assert"
+	is "github.com/gotestyourself/gotestyourself/assert/cmp"
 )
 
 func TestServiceLogsError(t *testing.T) {
@@ -22,15 +23,11 @@ func TestServiceLogsError(t *testing.T) {
 		client: newMockClient(errorMock(http.StatusInternalServerError, "Server error")),
 	}
 	_, err := client.ServiceLogs(context.Background(), "service_id", types.ContainerLogsOptions{})
-	if err == nil || err.Error() != "Error response from daemon: Server error" {
-		t.Fatalf("expected a Server Error, got %v", err)
-	}
+	assert.Check(t, is.Error(err, "Error response from daemon: Server error"))
 	_, err = client.ServiceLogs(context.Background(), "service_id", types.ContainerLogsOptions{
 		Since: "2006-01-02TZ",
 	})
-	if err == nil || !strings.Contains(err.Error(), `parsing time "2006-01-02TZ"`) {
-		t.Fatalf("expected a 'parsing time' error, got %v", err)
-	}
+	assert.Check(t, is.ErrorContains(err, `parsing time "2006-01-02TZ"`))
 }
 
 func TestServiceLogs(t *testing.T) {
@@ -38,6 +35,7 @@ func TestServiceLogs(t *testing.T) {
 	cases := []struct {
 		options             types.ContainerLogsOptions
 		expectedQueryParams map[string]string
+		expectedError       string
 	}{
 		{
 			expectedQueryParams: map[string]string{
@@ -71,21 +69,27 @@ func TestServiceLogs(t *testing.T) {
 		},
 		{
 			options: types.ContainerLogsOptions{
-				// An complete invalid date, timestamp or go duration will be
-				// passed as is
-				Since: "invalid but valid",
+				// timestamp will be passed as is
+				Since: "1136073600.000000001",
 			},
 			expectedQueryParams: map[string]string{
 				"tail":  "",
-				"since": "invalid but valid",
+				"since": "1136073600.000000001",
 			},
 		},
+		{
+			options: types.ContainerLogsOptions{
+				// An complete invalid date will not be passed
+				Since: "invalid value",
+			},
+			expectedError: `invalid value for "since": failed to parse value as time or duration: "invalid value"`,
+		},
 	}
 	for _, logCase := range cases {
 		client := &Client{
 			client: newMockClient(func(r *http.Request) (*http.Response, error) {
 				if !strings.HasPrefix(r.URL.Path, expectedURL) {
-					return nil, fmt.Errorf("Expected URL '%s', got '%s'", expectedURL, r.URL)
+					return nil, fmt.Errorf("expected URL '%s', got '%s'", expectedURL, r.URL)
 				}
 				// Check query parameters
 				query := r.URL.Query()
@@ -102,17 +106,15 @@ func TestServiceLogs(t *testing.T) {
 			}),
 		}
 		body, err := client.ServiceLogs(context.Background(), "service_id", logCase.options)
-		if err != nil {
-			t.Fatal(err)
+		if logCase.expectedError != "" {
+			assert.Check(t, is.Error(err, logCase.expectedError))
+			continue
 		}
+		assert.NilError(t, err)
 		defer body.Close()
 		content, err := ioutil.ReadAll(body)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if string(content) != "response" {
-			t.Fatalf("expected response to contain 'response', got %s", string(content))
-		}
+		assert.NilError(t, err)
+		assert.Check(t, is.Contains(string(content), "response"))
 	}
 }