Vendor distribution v2 api

Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
This commit is contained in:
Derek McGowan 2015-03-17 11:03:56 -07:00
parent ded5c73deb
commit f011d722ce
16 changed files with 2970 additions and 92 deletions

View file

@ -43,7 +43,7 @@ clone git github.com/kr/pty 05017fcccf
clone git github.com/gorilla/context 14f550f51a
clone git github.com/gorilla/mux 136d54f81f
clone git github.com/gorilla/mux e444e69cbd
clone git github.com/tchap/go-patricia v1.0.1
@ -68,12 +68,15 @@ if [ "$1" = '--go' ]; then
mv tmp-tar src/code.google.com/p/go/src/pkg/archive/tar
fi
# get digest package from distribution
# get distribution packages
clone git github.com/docker/distribution d957768537c5af40e4f4cd96871f7b2bde9e2923
mv src/github.com/docker/distribution/digest tmp-digest
mv src/github.com/docker/distribution/registry/api tmp-api
rm -rf src/github.com/docker/distribution
mkdir -p src/github.com/docker/distribution
mv tmp-digest src/github.com/docker/distribution/digest
mkdir -p src/github.com/docker/distribution/registry
mv tmp-api src/github.com/docker/distribution/registry/api
clone git github.com/docker/libcontainer c8512754166539461fd860451ff1a0af7491c197
# see src/github.com/docker/libcontainer/update-vendor.sh which is the "source of truth" for libcontainer deps (just like this file)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,9 @@
// Package v2 describes routes, urls and the error codes used in the Docker
// Registry JSON HTTP API V2. In addition to declarations, descriptors are
// provided for routes and error codes that can be used for implementation and
// automatically generating documentation.
//
// Definitions here are considered to be locked down for the V2 registry api.
// Any changes must be considered carefully and should not proceed without a
// change proposal in docker core.
package v2

View file

@ -0,0 +1,194 @@
package v2
import (
"fmt"
"strings"
)
// ErrorCode represents the error type. The errors are serialized via strings
// and the integer format may change and should *never* be exported.
type ErrorCode int
const (
// ErrorCodeUnknown is a catch-all for errors not defined below.
ErrorCodeUnknown ErrorCode = iota
// ErrorCodeUnsupported is returned when an operation is not supported.
ErrorCodeUnsupported
// ErrorCodeUnauthorized is returned if a request is not authorized.
ErrorCodeUnauthorized
// ErrorCodeDigestInvalid is returned when uploading a blob if the
// provided digest does not match the blob contents.
ErrorCodeDigestInvalid
// ErrorCodeSizeInvalid is returned when uploading a blob if the provided
// size does not match the content length.
ErrorCodeSizeInvalid
// ErrorCodeNameInvalid is returned when the name in the manifest does not
// match the provided name.
ErrorCodeNameInvalid
// ErrorCodeTagInvalid is returned when the tag in the manifest does not
// match the provided tag.
ErrorCodeTagInvalid
// ErrorCodeNameUnknown when the repository name is not known.
ErrorCodeNameUnknown
// ErrorCodeManifestUnknown returned when image manifest is unknown.
ErrorCodeManifestUnknown
// ErrorCodeManifestInvalid returned when an image manifest is invalid,
// typically during a PUT operation. This error encompasses all errors
// encountered during manifest validation that aren't signature errors.
ErrorCodeManifestInvalid
// ErrorCodeManifestUnverified is returned when the manifest fails
// signature verfication.
ErrorCodeManifestUnverified
// ErrorCodeBlobUnknown is returned when a blob is unknown to the
// registry. This can happen when the manifest references a nonexistent
// layer or the result is not found by a blob fetch.
ErrorCodeBlobUnknown
// ErrorCodeBlobUploadUnknown is returned when an upload is unknown.
ErrorCodeBlobUploadUnknown
// ErrorCodeBlobUploadInvalid is returned when an upload is invalid.
ErrorCodeBlobUploadInvalid
)
// ParseErrorCode attempts to parse the error code string, returning
// ErrorCodeUnknown if the error is not known.
func ParseErrorCode(s string) ErrorCode {
desc, ok := idToDescriptors[s]
if !ok {
return ErrorCodeUnknown
}
return desc.Code
}
// Descriptor returns the descriptor for the error code.
func (ec ErrorCode) Descriptor() ErrorDescriptor {
d, ok := errorCodeToDescriptors[ec]
if !ok {
return ErrorCodeUnknown.Descriptor()
}
return d
}
// String returns the canonical identifier for this error code.
func (ec ErrorCode) String() string {
return ec.Descriptor().Value
}
// Message returned the human-readable error message for this error code.
func (ec ErrorCode) Message() string {
return ec.Descriptor().Message
}
// MarshalText encodes the receiver into UTF-8-encoded text and returns the
// result.
func (ec ErrorCode) MarshalText() (text []byte, err error) {
return []byte(ec.String()), nil
}
// UnmarshalText decodes the form generated by MarshalText.
func (ec *ErrorCode) UnmarshalText(text []byte) error {
desc, ok := idToDescriptors[string(text)]
if !ok {
desc = ErrorCodeUnknown.Descriptor()
}
*ec = desc.Code
return nil
}
// Error provides a wrapper around ErrorCode with extra Details provided.
type Error struct {
Code ErrorCode `json:"code"`
Message string `json:"message,omitempty"`
Detail interface{} `json:"detail,omitempty"`
}
// Error returns a human readable representation of the error.
func (e Error) Error() string {
return fmt.Sprintf("%s: %s",
strings.ToLower(strings.Replace(e.Code.String(), "_", " ", -1)),
e.Message)
}
// Errors provides the envelope for multiple errors and a few sugar methods
// for use within the application.
type Errors struct {
Errors []Error `json:"errors,omitempty"`
}
// Push pushes an error on to the error stack, with the optional detail
// argument. It is a programming error (ie panic) to push more than one
// detail at a time.
func (errs *Errors) Push(code ErrorCode, details ...interface{}) {
if len(details) > 1 {
panic("please specify zero or one detail items for this error")
}
var detail interface{}
if len(details) > 0 {
detail = details[0]
}
if err, ok := detail.(error); ok {
detail = err.Error()
}
errs.PushErr(Error{
Code: code,
Message: code.Message(),
Detail: detail,
})
}
// PushErr pushes an error interface onto the error stack.
func (errs *Errors) PushErr(err error) {
switch err.(type) {
case Error:
errs.Errors = append(errs.Errors, err.(Error))
default:
errs.Errors = append(errs.Errors, Error{Message: err.Error()})
}
}
func (errs *Errors) Error() string {
switch errs.Len() {
case 0:
return "<nil>"
case 1:
return errs.Errors[0].Error()
default:
msg := "errors:\n"
for _, err := range errs.Errors {
msg += err.Error() + "\n"
}
return msg
}
}
// Clear clears the errors.
func (errs *Errors) Clear() {
errs.Errors = errs.Errors[:0]
}
// Len returns the current number of errors.
func (errs *Errors) Len() int {
return len(errs.Errors)
}

View file

@ -0,0 +1,165 @@
package v2
import (
"encoding/json"
"reflect"
"testing"
"github.com/docker/distribution/digest"
)
// TestErrorCodes ensures that error code format, mappings and
// marshaling/unmarshaling. round trips are stable.
func TestErrorCodes(t *testing.T) {
for _, desc := range errorDescriptors {
if desc.Code.String() != desc.Value {
t.Fatalf("error code string incorrect: %q != %q", desc.Code.String(), desc.Value)
}
if desc.Code.Message() != desc.Message {
t.Fatalf("incorrect message for error code %v: %q != %q", desc.Code, desc.Code.Message(), desc.Message)
}
// Serialize the error code using the json library to ensure that we
// get a string and it works round trip.
p, err := json.Marshal(desc.Code)
if err != nil {
t.Fatalf("error marshaling error code %v: %v", desc.Code, err)
}
if len(p) <= 0 {
t.Fatalf("expected content in marshaled before for error code %v", desc.Code)
}
// First, unmarshal to interface and ensure we have a string.
var ecUnspecified interface{}
if err := json.Unmarshal(p, &ecUnspecified); err != nil {
t.Fatalf("error unmarshaling error code %v: %v", desc.Code, err)
}
if _, ok := ecUnspecified.(string); !ok {
t.Fatalf("expected a string for error code %v on unmarshal got a %T", desc.Code, ecUnspecified)
}
// Now, unmarshal with the error code type and ensure they are equal
var ecUnmarshaled ErrorCode
if err := json.Unmarshal(p, &ecUnmarshaled); err != nil {
t.Fatalf("error unmarshaling error code %v: %v", desc.Code, err)
}
if ecUnmarshaled != desc.Code {
t.Fatalf("unexpected error code during error code marshal/unmarshal: %v != %v", ecUnmarshaled, desc.Code)
}
}
}
// TestErrorsManagement does a quick check of the Errors type to ensure that
// members are properly pushed and marshaled.
func TestErrorsManagement(t *testing.T) {
var errs Errors
errs.Push(ErrorCodeDigestInvalid)
errs.Push(ErrorCodeBlobUnknown,
map[string]digest.Digest{"digest": "sometestblobsumdoesntmatter"})
p, err := json.Marshal(errs)
if err != nil {
t.Fatalf("error marashaling errors: %v", err)
}
expectedJSON := "{\"errors\":[{\"code\":\"DIGEST_INVALID\",\"message\":\"provided digest did not match uploaded content\"},{\"code\":\"BLOB_UNKNOWN\",\"message\":\"blob unknown to registry\",\"detail\":{\"digest\":\"sometestblobsumdoesntmatter\"}}]}"
if string(p) != expectedJSON {
t.Fatalf("unexpected json: %q != %q", string(p), expectedJSON)
}
errs.Clear()
errs.Push(ErrorCodeUnknown)
expectedJSON = "{\"errors\":[{\"code\":\"UNKNOWN\",\"message\":\"unknown error\"}]}"
p, err = json.Marshal(errs)
if err != nil {
t.Fatalf("error marashaling errors: %v", err)
}
if string(p) != expectedJSON {
t.Fatalf("unexpected json: %q != %q", string(p), expectedJSON)
}
}
// TestMarshalUnmarshal ensures that api errors can round trip through json
// without losing information.
func TestMarshalUnmarshal(t *testing.T) {
var errors Errors
for _, testcase := range []struct {
description string
err Error
}{
{
description: "unknown error",
err: Error{
Code: ErrorCodeUnknown,
Message: ErrorCodeUnknown.Descriptor().Message,
},
},
{
description: "unknown manifest",
err: Error{
Code: ErrorCodeManifestUnknown,
Message: ErrorCodeManifestUnknown.Descriptor().Message,
},
},
{
description: "unknown manifest",
err: Error{
Code: ErrorCodeBlobUnknown,
Message: ErrorCodeBlobUnknown.Descriptor().Message,
Detail: map[string]interface{}{"digest": "asdfqwerqwerqwerqwer"},
},
},
} {
fatalf := func(format string, args ...interface{}) {
t.Fatalf(testcase.description+": "+format, args...)
}
unexpectedErr := func(err error) {
fatalf("unexpected error: %v", err)
}
p, err := json.Marshal(testcase.err)
if err != nil {
unexpectedErr(err)
}
var unmarshaled Error
if err := json.Unmarshal(p, &unmarshaled); err != nil {
unexpectedErr(err)
}
if !reflect.DeepEqual(unmarshaled, testcase.err) {
fatalf("errors not equal after round trip: %#v != %#v", unmarshaled, testcase.err)
}
// Roll everything up into an error response envelope.
errors.PushErr(testcase.err)
}
p, err := json.Marshal(errors)
if err != nil {
t.Fatalf("unexpected error marshaling error envelope: %v", err)
}
var unmarshaled Errors
if err := json.Unmarshal(p, &unmarshaled); err != nil {
t.Fatalf("unexpected error unmarshaling error envelope: %v", err)
}
if !reflect.DeepEqual(unmarshaled, errors) {
t.Fatalf("errors not equal after round trip: %#v != %#v", unmarshaled, errors)
}
}

View file

@ -0,0 +1,100 @@
package v2
import (
"fmt"
"regexp"
"strings"
)
// TODO(stevvooe): Move these definitions back to an exported package. While
// they are used with v2 definitions, their relevance expands beyond.
// "distribution/names" is a candidate package.
const (
// RepositoryNameComponentMinLength is the minimum number of characters in a
// single repository name slash-delimited component
RepositoryNameComponentMinLength = 2
// RepositoryNameMinComponents is the minimum number of slash-delimited
// components that a repository name must have
RepositoryNameMinComponents = 1
// RepositoryNameTotalLengthMax is the maximum total number of characters in
// a repository name
RepositoryNameTotalLengthMax = 255
)
// RepositoryNameComponentRegexp restricts registry path component names to
// start with at least one letter or number, with following parts able to
// be separated by one period, dash or underscore.
var RepositoryNameComponentRegexp = regexp.MustCompile(`[a-z0-9]+(?:[._-][a-z0-9]+)*`)
// RepositoryNameComponentAnchoredRegexp is the version of
// RepositoryNameComponentRegexp which must completely match the content
var RepositoryNameComponentAnchoredRegexp = regexp.MustCompile(`^` + RepositoryNameComponentRegexp.String() + `$`)
// RepositoryNameRegexp builds on RepositoryNameComponentRegexp to allow
// multiple path components, separated by a forward slash.
var RepositoryNameRegexp = regexp.MustCompile(`(?:` + RepositoryNameComponentRegexp.String() + `/)*` + RepositoryNameComponentRegexp.String())
// TagNameRegexp matches valid tag names. From docker/docker:graph/tags.go.
var TagNameRegexp = regexp.MustCompile(`[\w][\w.-]{0,127}`)
// TODO(stevvooe): Contribute these exports back to core, so they are shared.
var (
// ErrRepositoryNameComponentShort is returned when a repository name
// contains a component which is shorter than
// RepositoryNameComponentMinLength
ErrRepositoryNameComponentShort = fmt.Errorf("respository name component must be %v or more characters", RepositoryNameComponentMinLength)
// ErrRepositoryNameMissingComponents is returned when a repository name
// contains fewer than RepositoryNameMinComponents components
ErrRepositoryNameMissingComponents = fmt.Errorf("repository name must have at least %v components", RepositoryNameMinComponents)
// ErrRepositoryNameLong is returned when a repository name is longer than
// RepositoryNameTotalLengthMax
ErrRepositoryNameLong = fmt.Errorf("repository name must not be more than %v characters", RepositoryNameTotalLengthMax)
// ErrRepositoryNameComponentInvalid is returned when a repository name does
// not match RepositoryNameComponentRegexp
ErrRepositoryNameComponentInvalid = fmt.Errorf("repository name component must match %q", RepositoryNameComponentRegexp.String())
)
// ValidateRespositoryName ensures the repository name is valid for use in the
// registry. This function accepts a superset of what might be accepted by
// docker core or docker hub. If the name does not pass validation, an error,
// describing the conditions, is returned.
//
// Effectively, the name should comply with the following grammar:
//
// alpha-numeric := /[a-z0-9]+/
// separator := /[._-]/
// component := alpha-numeric [separator alpha-numeric]*
// namespace := component ['/' component]*
//
// The result of the production, known as the "namespace", should be limited
// to 255 characters.
func ValidateRespositoryName(name string) error {
if len(name) > RepositoryNameTotalLengthMax {
return ErrRepositoryNameLong
}
components := strings.Split(name, "/")
if len(components) < RepositoryNameMinComponents {
return ErrRepositoryNameMissingComponents
}
for _, component := range components {
if len(component) < RepositoryNameComponentMinLength {
return ErrRepositoryNameComponentShort
}
if !RepositoryNameComponentAnchoredRegexp.MatchString(component) {
return ErrRepositoryNameComponentInvalid
}
}
return nil
}

View file

@ -0,0 +1,100 @@
package v2
import (
"strings"
"testing"
)
func TestRepositoryNameRegexp(t *testing.T) {
for _, testcase := range []struct {
input string
err error
}{
{
input: "short",
},
{
input: "simple/name",
},
{
input: "library/ubuntu",
},
{
input: "docker/stevvooe/app",
},
{
input: "aa/aa/aa/aa/aa/aa/aa/aa/aa/bb/bb/bb/bb/bb/bb",
},
{
input: "aa/aa/bb/bb/bb",
},
{
input: "a/a/a/b/b",
err: ErrRepositoryNameComponentShort,
},
{
input: "a/a/a/a/",
err: ErrRepositoryNameComponentShort,
},
{
input: "foo.com/bar/baz",
},
{
input: "blog.foo.com/bar/baz",
},
{
input: "asdf",
},
{
input: "asdf$$^/aa",
err: ErrRepositoryNameComponentInvalid,
},
{
input: "aa-a/aa",
},
{
input: "aa/aa",
},
{
input: "a-a/a-a",
},
{
input: "a",
err: ErrRepositoryNameComponentShort,
},
{
input: "a-/a/a/a",
err: ErrRepositoryNameComponentInvalid,
},
{
input: strings.Repeat("a", 255),
},
{
input: strings.Repeat("a", 256),
err: ErrRepositoryNameLong,
},
} {
failf := func(format string, v ...interface{}) {
t.Logf(testcase.input+": "+format, v...)
t.Fail()
}
if err := ValidateRespositoryName(testcase.input); err != testcase.err {
if testcase.err != nil {
if err != nil {
failf("unexpected error for invalid repository: got %v, expected %v", err, testcase.err)
} else {
failf("expected invalid repository: %v", testcase.err)
}
} else {
if err != nil {
// Wrong error returned.
failf("unexpected error validating repository name: %v, expected %v", err, testcase.err)
} else {
failf("unexpected error validating repository name: %v", err)
}
}
}
}
}

View file

@ -0,0 +1,47 @@
package v2
import "github.com/gorilla/mux"
// The following are definitions of the name under which all V2 routes are
// registered. These symbols can be used to look up a route based on the name.
const (
RouteNameBase = "base"
RouteNameManifest = "manifest"
RouteNameTags = "tags"
RouteNameBlob = "blob"
RouteNameBlobUpload = "blob-upload"
RouteNameBlobUploadChunk = "blob-upload-chunk"
)
var allEndpoints = []string{
RouteNameManifest,
RouteNameTags,
RouteNameBlob,
RouteNameBlobUpload,
RouteNameBlobUploadChunk,
}
// Router builds a gorilla router with named routes for the various API
// methods. This can be used directly by both server implementations and
// clients.
func Router() *mux.Router {
return RouterWithPrefix("")
}
// RouterWithPrefix builds a gorilla router with a configured prefix
// on all routes.
func RouterWithPrefix(prefix string) *mux.Router {
rootRouter := mux.NewRouter()
router := rootRouter
if prefix != "" {
router = router.PathPrefix(prefix).Subrouter()
}
router.StrictSlash(true)
for _, descriptor := range routeDescriptors {
router.Path(descriptor.Path).Name(descriptor.Name)
}
return rootRouter
}

View file

@ -0,0 +1,315 @@
package v2
import (
"encoding/json"
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gorilla/mux"
)
type routeTestCase struct {
RequestURI string
ExpectedURI string
Vars map[string]string
RouteName string
StatusCode int
}
// TestRouter registers a test handler with all the routes and ensures that
// each route returns the expected path variables. Not method verification is
// present. This not meant to be exhaustive but as check to ensure that the
// expected variables are extracted.
//
// This may go away as the application structure comes together.
func TestRouter(t *testing.T) {
testCases := []routeTestCase{
{
RouteName: RouteNameBase,
RequestURI: "/v2/",
Vars: map[string]string{},
},
{
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/manifests/bar",
Vars: map[string]string{
"name": "foo",
"reference": "bar",
},
},
{
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/bar/manifests/tag",
Vars: map[string]string{
"name": "foo/bar",
"reference": "tag",
},
},
{
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/bar/manifests/sha256:abcdef01234567890",
Vars: map[string]string{
"name": "foo/bar",
"reference": "sha256:abcdef01234567890",
},
},
{
RouteName: RouteNameTags,
RequestURI: "/v2/foo/bar/tags/list",
Vars: map[string]string{
"name": "foo/bar",
},
},
{
RouteName: RouteNameBlob,
RequestURI: "/v2/foo/bar/blobs/tarsum.dev+foo:abcdef0919234",
Vars: map[string]string{
"name": "foo/bar",
"digest": "tarsum.dev+foo:abcdef0919234",
},
},
{
RouteName: RouteNameBlob,
RequestURI: "/v2/foo/bar/blobs/sha256:abcdef0919234",
Vars: map[string]string{
"name": "foo/bar",
"digest": "sha256:abcdef0919234",
},
},
{
RouteName: RouteNameBlobUpload,
RequestURI: "/v2/foo/bar/blobs/uploads/",
Vars: map[string]string{
"name": "foo/bar",
},
},
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/uuid",
Vars: map[string]string{
"name": "foo/bar",
"uuid": "uuid",
},
},
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
Vars: map[string]string{
"name": "foo/bar",
"uuid": "D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
},
},
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/RDk1MzA2RkEtRkFEMy00RTM2LThENDEtQ0YxQzkzRUY4Mjg2IA==",
Vars: map[string]string{
"name": "foo/bar",
"uuid": "RDk1MzA2RkEtRkFEMy00RTM2LThENDEtQ0YxQzkzRUY4Mjg2IA==",
},
},
{
// Check ambiguity: ensure we can distinguish between tags for
// "foo/bar/image/image" and image for "foo/bar/image" with tag
// "tags"
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/bar/manifests/manifests/tags",
Vars: map[string]string{
"name": "foo/bar/manifests",
"reference": "tags",
},
},
{
// This case presents an ambiguity between foo/bar with tag="tags"
// and list tags for "foo/bar/manifest"
RouteName: RouteNameTags,
RequestURI: "/v2/foo/bar/manifests/tags/list",
Vars: map[string]string{
"name": "foo/bar/manifests",
},
},
}
checkTestRouter(t, testCases, "", true)
checkTestRouter(t, testCases, "/prefix/", true)
}
func TestRouterWithPathTraversals(t *testing.T) {
testCases := []routeTestCase{
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/../../blob/uploads/D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
ExpectedURI: "/blob/uploads/D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
StatusCode: http.StatusNotFound,
},
{
// Testing for path traversal attack handling
RouteName: RouteNameTags,
RequestURI: "/v2/foo/../bar/baz/tags/list",
ExpectedURI: "/v2/bar/baz/tags/list",
Vars: map[string]string{
"name": "bar/baz",
},
},
}
checkTestRouter(t, testCases, "", false)
}
func TestRouterWithBadCharacters(t *testing.T) {
if testing.Short() {
testCases := []routeTestCase{
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/blob/uploads/不95306FA-FAD3-4E36-8D41-CF1C93EF8286",
StatusCode: http.StatusNotFound,
},
{
// Testing for path traversal attack handling
RouteName: RouteNameTags,
RequestURI: "/v2/foo/不bar/tags/list",
StatusCode: http.StatusNotFound,
},
}
checkTestRouter(t, testCases, "", true)
} else {
// in the long version we're going to fuzz the router
// with random UTF8 characters not in the 128 bit ASCII range.
// These are not valid characters for the router and we expect
// 404s on every test.
rand.Seed(time.Now().UTC().UnixNano())
testCases := make([]routeTestCase, 1000)
for idx := range testCases {
testCases[idx] = routeTestCase{
RouteName: RouteNameTags,
RequestURI: fmt.Sprintf("/v2/%v/%v/tags/list", randomString(10), randomString(10)),
StatusCode: http.StatusNotFound,
}
}
checkTestRouter(t, testCases, "", true)
}
}
func checkTestRouter(t *testing.T, testCases []routeTestCase, prefix string, deeplyEqual bool) {
router := RouterWithPrefix(prefix)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
testCase := routeTestCase{
RequestURI: r.RequestURI,
Vars: mux.Vars(r),
RouteName: mux.CurrentRoute(r).GetName(),
}
enc := json.NewEncoder(w)
if err := enc.Encode(testCase); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
// Startup test server
server := httptest.NewServer(router)
for _, testcase := range testCases {
testcase.RequestURI = strings.TrimSuffix(prefix, "/") + testcase.RequestURI
// Register the endpoint
route := router.GetRoute(testcase.RouteName)
if route == nil {
t.Fatalf("route for name %q not found", testcase.RouteName)
}
route.Handler(testHandler)
u := server.URL + testcase.RequestURI
resp, err := http.Get(u)
if err != nil {
t.Fatalf("error issuing get request: %v", err)
}
if testcase.StatusCode == 0 {
// Override default, zero-value
testcase.StatusCode = http.StatusOK
}
if testcase.ExpectedURI == "" {
// Override default, zero-value
testcase.ExpectedURI = testcase.RequestURI
}
if resp.StatusCode != testcase.StatusCode {
t.Fatalf("unexpected status for %s: %v %v", u, resp.Status, resp.StatusCode)
}
if testcase.StatusCode != http.StatusOK {
// We don't care about json response.
continue
}
dec := json.NewDecoder(resp.Body)
var actualRouteInfo routeTestCase
if err := dec.Decode(&actualRouteInfo); err != nil {
t.Fatalf("error reading json response: %v", err)
}
// Needs to be set out of band
actualRouteInfo.StatusCode = resp.StatusCode
if actualRouteInfo.RequestURI != testcase.ExpectedURI {
t.Fatalf("URI %v incorrectly parsed, expected %v", actualRouteInfo.RequestURI, testcase.ExpectedURI)
}
if actualRouteInfo.RouteName != testcase.RouteName {
t.Fatalf("incorrect route %q matched, expected %q", actualRouteInfo.RouteName, testcase.RouteName)
}
// when testing deep equality, the actualRouteInfo has an empty ExpectedURI, we don't want
// that to make the comparison fail. We're otherwise done with the testcase so empty the
// testcase.ExpectedURI
testcase.ExpectedURI = ""
if deeplyEqual && !reflect.DeepEqual(actualRouteInfo, testcase) {
t.Fatalf("actual does not equal expected: %#v != %#v", actualRouteInfo, testcase)
}
}
}
// -------------- START LICENSED CODE --------------
// The following code is derivative of https://github.com/google/gofuzz
// gofuzz is licensed under the Apache License, Version 2.0, January 2004,
// a copy of which can be found in the LICENSE file at the root of this
// repository.
// These functions allow us to generate strings containing only multibyte
// characters that are invalid in our URLs. They are used above for fuzzing
// to ensure we always get 404s on these invalid strings
type charRange struct {
first, last rune
}
// choose returns a random unicode character from the given range, using the
// given randomness source.
func (r *charRange) choose() rune {
count := int64(r.last - r.first)
return r.first + rune(rand.Int63n(count))
}
var unicodeRanges = []charRange{
{'\u00a0', '\u02af'}, // Multi-byte encoded characters
{'\u4e00', '\u9fff'}, // Common CJK (even longer encodings)
}
func randomString(length int) string {
runes := make([]rune, length)
for i := range runes {
runes[i] = unicodeRanges[rand.Intn(len(unicodeRanges))].choose()
}
return string(runes)
}
// -------------- END LICENSED CODE --------------

View file

@ -0,0 +1,217 @@
package v2
import (
"net/http"
"net/url"
"strings"
"github.com/docker/distribution/digest"
"github.com/gorilla/mux"
)
// URLBuilder creates registry API urls from a single base endpoint. It can be
// used to create urls for use in a registry client or server.
//
// All urls will be created from the given base, including the api version.
// For example, if a root of "/foo/" is provided, urls generated will be fall
// under "/foo/v2/...". Most application will only provide a schema, host and
// port, such as "https://localhost:5000/".
type URLBuilder struct {
root *url.URL // url root (ie http://localhost/)
router *mux.Router
}
// NewURLBuilder creates a URLBuilder with provided root url object.
func NewURLBuilder(root *url.URL) *URLBuilder {
return &URLBuilder{
root: root,
router: Router(),
}
}
// NewURLBuilderFromString workes identically to NewURLBuilder except it takes
// a string argument for the root, returning an error if it is not a valid
// url.
func NewURLBuilderFromString(root string) (*URLBuilder, error) {
u, err := url.Parse(root)
if err != nil {
return nil, err
}
return NewURLBuilder(u), nil
}
// NewURLBuilderFromRequest uses information from an *http.Request to
// construct the root url.
func NewURLBuilderFromRequest(r *http.Request) *URLBuilder {
var scheme string
forwardedProto := r.Header.Get("X-Forwarded-Proto")
switch {
case len(forwardedProto) > 0:
scheme = forwardedProto
case r.TLS != nil:
scheme = "https"
case len(r.URL.Scheme) > 0:
scheme = r.URL.Scheme
default:
scheme = "http"
}
host := r.Host
forwardedHost := r.Header.Get("X-Forwarded-Host")
if len(forwardedHost) > 0 {
host = forwardedHost
}
basePath := routeDescriptorsMap[RouteNameBase].Path
requestPath := r.URL.Path
index := strings.Index(requestPath, basePath)
u := &url.URL{
Scheme: scheme,
Host: host,
}
if index > 0 {
// N.B. index+1 is important because we want to include the trailing /
u.Path = requestPath[0 : index+1]
}
return NewURLBuilder(u)
}
// BuildBaseURL constructs a base url for the API, typically just "/v2/".
func (ub *URLBuilder) BuildBaseURL() (string, error) {
route := ub.cloneRoute(RouteNameBase)
baseURL, err := route.URL()
if err != nil {
return "", err
}
return baseURL.String(), nil
}
// BuildTagsURL constructs a url to list the tags in the named repository.
func (ub *URLBuilder) BuildTagsURL(name string) (string, error) {
route := ub.cloneRoute(RouteNameTags)
tagsURL, err := route.URL("name", name)
if err != nil {
return "", err
}
return tagsURL.String(), nil
}
// BuildManifestURL constructs a url for the manifest identified by name and
// reference. The argument reference may be either a tag or digest.
func (ub *URLBuilder) BuildManifestURL(name, reference string) (string, error) {
route := ub.cloneRoute(RouteNameManifest)
manifestURL, err := route.URL("name", name, "reference", reference)
if err != nil {
return "", err
}
return manifestURL.String(), nil
}
// BuildBlobURL constructs the url for the blob identified by name and dgst.
func (ub *URLBuilder) BuildBlobURL(name string, dgst digest.Digest) (string, error) {
route := ub.cloneRoute(RouteNameBlob)
layerURL, err := route.URL("name", name, "digest", dgst.String())
if err != nil {
return "", err
}
return layerURL.String(), nil
}
// BuildBlobUploadURL constructs a url to begin a blob upload in the
// repository identified by name.
func (ub *URLBuilder) BuildBlobUploadURL(name string, values ...url.Values) (string, error) {
route := ub.cloneRoute(RouteNameBlobUpload)
uploadURL, err := route.URL("name", name)
if err != nil {
return "", err
}
return appendValuesURL(uploadURL, values...).String(), nil
}
// BuildBlobUploadChunkURL constructs a url for the upload identified by uuid,
// including any url values. This should generally not be used by clients, as
// this url is provided by server implementations during the blob upload
// process.
func (ub *URLBuilder) BuildBlobUploadChunkURL(name, uuid string, values ...url.Values) (string, error) {
route := ub.cloneRoute(RouteNameBlobUploadChunk)
uploadURL, err := route.URL("name", name, "uuid", uuid)
if err != nil {
return "", err
}
return appendValuesURL(uploadURL, values...).String(), nil
}
// clondedRoute returns a clone of the named route from the router. Routes
// must be cloned to avoid modifying them during url generation.
func (ub *URLBuilder) cloneRoute(name string) clonedRoute {
route := new(mux.Route)
root := new(url.URL)
*route = *ub.router.GetRoute(name) // clone the route
*root = *ub.root
return clonedRoute{Route: route, root: root}
}
type clonedRoute struct {
*mux.Route
root *url.URL
}
func (cr clonedRoute) URL(pairs ...string) (*url.URL, error) {
routeURL, err := cr.Route.URL(pairs...)
if err != nil {
return nil, err
}
if routeURL.Scheme == "" && routeURL.User == nil && routeURL.Host == "" {
routeURL.Path = routeURL.Path[1:]
}
return cr.root.ResolveReference(routeURL), nil
}
// appendValuesURL appends the parameters to the url.
func appendValuesURL(u *url.URL, values ...url.Values) *url.URL {
merged := u.Query()
for _, v := range values {
for k, vv := range v {
merged[k] = append(merged[k], vv...)
}
}
u.RawQuery = merged.Encode()
return u
}
// appendValues appends the parameters to the url. Panics if the string is not
// a url.
func appendValues(u string, values ...url.Values) string {
up, err := url.Parse(u)
if err != nil {
panic(err) // should never happen
}
return appendValuesURL(up, values...).String()
}

View file

@ -0,0 +1,225 @@
package v2
import (
"net/http"
"net/url"
"testing"
)
type urlBuilderTestCase struct {
description string
expectedPath string
build func() (string, error)
}
func makeURLBuilderTestCases(urlBuilder *URLBuilder) []urlBuilderTestCase {
return []urlBuilderTestCase{
{
description: "test base url",
expectedPath: "/v2/",
build: urlBuilder.BuildBaseURL,
},
{
description: "test tags url",
expectedPath: "/v2/foo/bar/tags/list",
build: func() (string, error) {
return urlBuilder.BuildTagsURL("foo/bar")
},
},
{
description: "test manifest url",
expectedPath: "/v2/foo/bar/manifests/tag",
build: func() (string, error) {
return urlBuilder.BuildManifestURL("foo/bar", "tag")
},
},
{
description: "build blob url",
expectedPath: "/v2/foo/bar/blobs/tarsum.v1+sha256:abcdef0123456789",
build: func() (string, error) {
return urlBuilder.BuildBlobURL("foo/bar", "tarsum.v1+sha256:abcdef0123456789")
},
},
{
description: "build blob upload url",
expectedPath: "/v2/foo/bar/blobs/uploads/",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadURL("foo/bar")
},
},
{
description: "build blob upload url with digest and size",
expectedPath: "/v2/foo/bar/blobs/uploads/?digest=tarsum.v1%2Bsha256%3Aabcdef0123456789&size=10000",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadURL("foo/bar", url.Values{
"size": []string{"10000"},
"digest": []string{"tarsum.v1+sha256:abcdef0123456789"},
})
},
},
{
description: "build blob upload chunk url",
expectedPath: "/v2/foo/bar/blobs/uploads/uuid-part",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadChunkURL("foo/bar", "uuid-part")
},
},
{
description: "build blob upload chunk url with digest and size",
expectedPath: "/v2/foo/bar/blobs/uploads/uuid-part?digest=tarsum.v1%2Bsha256%3Aabcdef0123456789&size=10000",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadChunkURL("foo/bar", "uuid-part", url.Values{
"size": []string{"10000"},
"digest": []string{"tarsum.v1+sha256:abcdef0123456789"},
})
},
},
}
}
// TestURLBuilder tests the various url building functions, ensuring they are
// returning the expected values.
func TestURLBuilder(t *testing.T) {
roots := []string{
"http://example.com",
"https://example.com",
"http://localhost:5000",
"https://localhost:5443",
}
for _, root := range roots {
urlBuilder, err := NewURLBuilderFromString(root)
if err != nil {
t.Fatalf("unexpected error creating urlbuilder: %v", err)
}
for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
expectedURL := root + testCase.expectedPath
if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
}
}
}
func TestURLBuilderWithPrefix(t *testing.T) {
roots := []string{
"http://example.com/prefix/",
"https://example.com/prefix/",
"http://localhost:5000/prefix/",
"https://localhost:5443/prefix/",
}
for _, root := range roots {
urlBuilder, err := NewURLBuilderFromString(root)
if err != nil {
t.Fatalf("unexpected error creating urlbuilder: %v", err)
}
for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
expectedURL := root[0:len(root)-1] + testCase.expectedPath
if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
}
}
}
type builderFromRequestTestCase struct {
request *http.Request
base string
}
func TestBuilderFromRequest(t *testing.T) {
u, err := url.Parse("http://example.com")
if err != nil {
t.Fatal(err)
}
forwardedProtoHeader := make(http.Header, 1)
forwardedProtoHeader.Set("X-Forwarded-Proto", "https")
testRequests := []struct {
request *http.Request
base string
}{
{
request: &http.Request{URL: u, Host: u.Host},
base: "http://example.com",
},
{
request: &http.Request{URL: u, Host: u.Host, Header: forwardedProtoHeader},
base: "https://example.com",
},
}
for _, tr := range testRequests {
builder := NewURLBuilderFromRequest(tr.request)
for _, testCase := range makeURLBuilderTestCases(builder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
expectedURL := tr.base + testCase.expectedPath
if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
}
}
}
func TestBuilderFromRequestWithPrefix(t *testing.T) {
u, err := url.Parse("http://example.com/prefix/v2/")
if err != nil {
t.Fatal(err)
}
forwardedProtoHeader := make(http.Header, 1)
forwardedProtoHeader.Set("X-Forwarded-Proto", "https")
testRequests := []struct {
request *http.Request
base string
}{
{
request: &http.Request{URL: u, Host: u.Host},
base: "http://example.com/prefix/",
},
{
request: &http.Request{URL: u, Host: u.Host, Header: forwardedProtoHeader},
base: "https://example.com/prefix/",
},
}
for _, tr := range testRequests {
builder := NewURLBuilderFromRequest(tr.request)
for _, testCase := range makeURLBuilderTestCases(builder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
expectedURL := tr.base[0:len(tr.base)-1] + testCase.expectedPath
if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
}
}
}

View file

@ -87,10 +87,10 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
setCurrentRoute(req, match.Route)
}
if handler == nil {
if r.NotFoundHandler == nil {
r.NotFoundHandler = http.NotFoundHandler()
}
handler = r.NotFoundHandler
if handler == nil {
handler = http.NotFoundHandler()
}
}
if !r.KeepContext {
defer context.Clear(req)

View file

@ -462,6 +462,15 @@ func TestQueries(t *testing.T) {
path: "",
shouldMatch: true,
},
{
title: "Queries route, match with a query string out of order",
route: new(Route).Host("www.example.com").Path("/api").Queries("foo", "bar", "baz", "ding"),
request: newRequest("GET", "http://www.example.com/api?baz=ding&foo=bar"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route, bad query",
route: new(Route).Queries("foo", "bar", "baz", "ding"),
@ -471,6 +480,42 @@ func TestQueries(t *testing.T) {
path: "",
shouldMatch: false,
},
{
title: "Queries route with pattern, match",
route: new(Route).Queries("foo", "{v1}"),
request: newRequest("GET", "http://localhost?foo=bar"),
vars: map[string]string{"v1": "bar"},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route with multiple patterns, match",
route: new(Route).Queries("foo", "{v1}", "baz", "{v2}"),
request: newRequest("GET", "http://localhost?foo=bar&baz=ding"),
vars: map[string]string{"v1": "bar", "v2": "ding"},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route with regexp pattern, match",
route: new(Route).Queries("foo", "{v1:[0-9]+}"),
request: newRequest("GET", "http://localhost?foo=10"),
vars: map[string]string{"v1": "10"},
host: "",
path: "",
shouldMatch: true,
},
{
title: "Queries route with regexp pattern, regexp does not match",
route: new(Route).Queries("foo", "{v1:[0-9]+}"),
request: newRequest("GET", "http://localhost?foo=a"),
vars: map[string]string{},
host: "",
path: "",
shouldMatch: false,
},
}
for _, test := range tests {

View file

@ -329,35 +329,6 @@ var pathMatcherTests = []pathMatcherTest{
},
}
type queryMatcherTest struct {
matcher queryMatcher
url string
result bool
}
var queryMatcherTests = []queryMatcherTest{
{
matcher: queryMatcher(map[string]string{"foo": "bar", "baz": "ding"}),
url: "http://localhost:8080/?foo=bar&baz=ding",
result: true,
},
{
matcher: queryMatcher(map[string]string{"foo": "", "baz": ""}),
url: "http://localhost:8080/?foo=anything&baz=anything",
result: true,
},
{
matcher: queryMatcher(map[string]string{"foo": "ding", "baz": "bar"}),
url: "http://localhost:8080/?foo=bar&baz=ding",
result: false,
},
{
matcher: queryMatcher(map[string]string{"bar": "foo", "ding": "baz"}),
url: "http://localhost:8080/?foo=bar&baz=ding",
result: false,
},
}
type schemeMatcherTest struct {
matcher schemeMatcher
url string
@ -519,23 +490,8 @@ func TestPathMatcher(t *testing.T) {
}
}
func TestQueryMatcher(t *testing.T) {
for _, v := range queryMatcherTests {
request, _ := http.NewRequest("GET", v.url, nil)
var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch)
if result != v.result {
if v.result {
t.Errorf("%#v: should match %v.", v.matcher, v.url)
} else {
t.Errorf("%#v: should not match %v.", v.matcher, v.url)
}
}
}
}
func TestSchemeMatcher(t *testing.T) {
for _, v := range queryMatcherTests {
for _, v := range schemeMatcherTests {
request, _ := http.NewRequest("GET", v.url, nil)
var routeMatch RouteMatch
result := v.matcher.Match(request, &routeMatch)
@ -735,7 +691,7 @@ func TestNewRegexp(t *testing.T) {
}
for pattern, paths := range tests {
p, _ = newRouteRegexp(pattern, false, false, false)
p, _ = newRouteRegexp(pattern, false, false, false, false)
for path, result := range paths {
matches = p.regexp.FindStringSubmatch(path)
if result == nil {

View file

@ -14,7 +14,7 @@ import (
)
// newRouteRegexp parses a route template and returns a routeRegexp,
// used to match a host or path.
// used to match a host, a path or a query string.
//
// It will extract named variables, assemble a regexp to be matched, create
// a "reverse" template to build URLs and compile regexps to validate variable
@ -23,7 +23,7 @@ import (
// Previously we accepted only Python-like identifiers for variable
// names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that
// name and pattern can't be empty, and names can't contain a colon.
func newRouteRegexp(tpl string, matchHost, matchPrefix, strictSlash bool) (*routeRegexp, error) {
func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash bool) (*routeRegexp, error) {
// Check if it is well-formed.
idxs, errBraces := braceIndices(tpl)
if errBraces != nil {
@ -33,11 +33,15 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, strictSlash bool) (*rout
template := tpl
// Now let's parse it.
defaultPattern := "[^/]+"
if matchHost {
if matchQuery {
defaultPattern = "[^?&]+"
matchPrefix = true
} else if matchHost {
defaultPattern = "[^.]+"
matchPrefix, strictSlash = false, false
matchPrefix = false
}
if matchPrefix {
// Only match strict slash if not matching
if matchPrefix || matchHost || matchQuery {
strictSlash = false
}
// Set a flag for strictSlash.
@ -48,7 +52,10 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, strictSlash bool) (*rout
}
varsN := make([]string, len(idxs)/2)
varsR := make([]*regexp.Regexp, len(idxs)/2)
pattern := bytes.NewBufferString("^")
pattern := bytes.NewBufferString("")
if !matchQuery {
pattern.WriteByte('^')
}
reverse := bytes.NewBufferString("")
var end int
var err error
@ -100,6 +107,7 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, strictSlash bool) (*rout
return &routeRegexp{
template: template,
matchHost: matchHost,
matchQuery: matchQuery,
strictSlash: strictSlash,
regexp: reg,
reverse: reverse.String(),
@ -113,8 +121,10 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, strictSlash bool) (*rout
type routeRegexp struct {
// The unmodified template.
template string
// True for host match, false for path match.
// True for host match, false for path or query string match.
matchHost bool
// True for query string match, false for path and host match.
matchQuery bool
// The strictSlash value defined on the route, but disabled if PathPrefix was used.
strictSlash bool
// Expanded regexp.
@ -130,7 +140,11 @@ type routeRegexp struct {
// Match matches the regexp against the URL host or path.
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
if !r.matchHost {
return r.regexp.MatchString(req.URL.Path)
if r.matchQuery {
return r.regexp.MatchString(req.URL.RawQuery)
} else {
return r.regexp.MatchString(req.URL.Path)
}
}
return r.regexp.MatchString(getHost(req))
}
@ -196,8 +210,9 @@ func braceIndices(s string) ([]int, error) {
// routeRegexpGroup groups the route matchers that carry variables.
type routeRegexpGroup struct {
host *routeRegexp
path *routeRegexp
host *routeRegexp
path *routeRegexp
queries []*routeRegexp
}
// setMatch extracts the variables from the URL once a route matches.
@ -234,17 +249,28 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
}
}
}
// Store query string variables.
rawQuery := req.URL.RawQuery
for _, q := range v.queries {
queryVars := q.regexp.FindStringSubmatch(rawQuery)
if queryVars != nil {
for k, v := range q.varsN {
m.Vars[v] = queryVars[k+1]
}
}
}
}
// getHost tries its best to return the request host.
func getHost(r *http.Request) string {
if !r.URL.IsAbs() {
host := r.Host
// Slice off any port information.
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
return host
if r.URL.IsAbs() {
return r.URL.Host
}
return r.URL.Host
host := r.Host
// Slice off any port information.
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
return host
}

View file

@ -135,12 +135,12 @@ func (r *Route) addMatcher(m matcher) *Route {
}
// addRegexpMatcher adds a host or path matcher and builder to a route.
func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix bool) error {
func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix, matchQuery bool) error {
if r.err != nil {
return r.err
}
r.regexp = r.getRegexpGroup()
if !matchHost {
if !matchHost && !matchQuery {
if len(tpl) == 0 || tpl[0] != '/' {
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
}
@ -148,10 +148,15 @@ func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix bool) error
tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl
}
}
rr, err := newRouteRegexp(tpl, matchHost, matchPrefix, r.strictSlash)
rr, err := newRouteRegexp(tpl, matchHost, matchPrefix, matchQuery, r.strictSlash)
if err != nil {
return err
}
for _, q := range r.regexp.queries {
if err = uniqueVars(rr.varsN, q.varsN); err != nil {
return err
}
}
if matchHost {
if r.regexp.path != nil {
if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil {
@ -165,7 +170,11 @@ func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix bool) error
return err
}
}
r.regexp.path = rr
if matchQuery {
r.regexp.queries = append(r.regexp.queries, rr)
} else {
r.regexp.path = rr
}
}
r.addMatcher(rr)
return nil
@ -219,7 +228,7 @@ func (r *Route) Headers(pairs ...string) *Route {
// Variable names must be unique in a given route. They can be retrieved
// calling mux.Vars(request).
func (r *Route) Host(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, true, false)
r.err = r.addRegexpMatcher(tpl, true, false, false)
return r
}
@ -278,7 +287,7 @@ func (r *Route) Methods(methods ...string) *Route {
// Variable names must be unique in a given route. They can be retrieved
// calling mux.Vars(request).
func (r *Route) Path(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, false, false)
r.err = r.addRegexpMatcher(tpl, false, false, false)
return r
}
@ -294,35 +303,42 @@ func (r *Route) Path(tpl string) *Route {
// Also note that the setting of Router.StrictSlash() has no effect on routes
// with a PathPrefix matcher.
func (r *Route) PathPrefix(tpl string) *Route {
r.err = r.addRegexpMatcher(tpl, false, true)
r.err = r.addRegexpMatcher(tpl, false, true, false)
return r
}
// Query ----------------------------------------------------------------------
// queryMatcher matches the request against URL queries.
type queryMatcher map[string]string
func (m queryMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchMap(m, r.URL.Query(), false)
}
// Queries adds a matcher for URL query values.
// It accepts a sequence of key/value pairs. For example:
// It accepts a sequence of key/value pairs. Values may define variables.
// For example:
//
// r := mux.NewRouter()
// r.Queries("foo", "bar", "baz", "ding")
// r.Queries("foo", "bar", "id", "{id:[0-9]+}")
//
// The above route will only match if the URL contains the defined queries
// values, e.g.: ?foo=bar&baz=ding.
// values, e.g.: ?foo=bar&id=42.
//
// It the value is an empty string, it will match any value if the key is set.
//
// Variables can define an optional regexp pattern to me matched:
//
// - {name} matches anything until the next slash.
//
// - {name:pattern} matches the given regexp pattern.
func (r *Route) Queries(pairs ...string) *Route {
if r.err == nil {
var queries map[string]string
queries, r.err = mapFromPairs(pairs...)
return r.addMatcher(queryMatcher(queries))
length := len(pairs)
if length%2 != 0 {
r.err = fmt.Errorf(
"mux: number of parameters must be multiple of 2, got %v", pairs)
return nil
}
for i := 0; i < length; i += 2 {
if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], false, true, true); r.err != nil {
return r
}
}
return r
}
@ -498,8 +514,9 @@ func (r *Route) getRegexpGroup() *routeRegexpGroup {
} else {
// Copy.
r.regexp = &routeRegexpGroup{
host: regexp.host,
path: regexp.path,
host: regexp.host,
path: regexp.path,
queries: regexp.queries,
}
}
}