Forráskód Böngészése

Merge pull request #44327 from thaJeztah/ghsa-ambiguous-pull-by-digest_master

Validate digest in repo for pull by digest
Sebastiaan van Stijn 2 éve
szülő
commit
43b8dffb83

+ 101 - 6
distribution/manifest.go

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"strings"
 
 
 	"github.com/containerd/containerd/content"
 	"github.com/containerd/containerd/content"
 	"github.com/containerd/containerd/errdefs"
 	"github.com/containerd/containerd/errdefs"
@@ -14,15 +15,22 @@ import (
 	"github.com/docker/distribution/manifest/manifestlist"
 	"github.com/docker/distribution/manifest/manifestlist"
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/manifest/schema2"
 	"github.com/docker/distribution/manifest/schema2"
+	"github.com/docker/distribution/reference"
+	"github.com/docker/docker/registry"
 	"github.com/opencontainers/go-digest"
 	"github.com/opencontainers/go-digest"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
 	"github.com/pkg/errors"
 	"github.com/pkg/errors"
+	"github.com/sirupsen/logrus"
 )
 )
 
 
+// labelDistributionSource describes the source blob comes from.
+const labelDistributionSource = "containerd.io/distribution.source"
+
 // This is used by manifestStore to pare down the requirements to implement a
 // This is used by manifestStore to pare down the requirements to implement a
 // full distribution.ManifestService, since `Get` is all we use here.
 // full distribution.ManifestService, since `Get` is all we use here.
 type manifestGetter interface {
 type manifestGetter interface {
 	Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error)
 	Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error)
+	Exists(ctx context.Context, dgst digest.Digest) (bool, error)
 }
 }
 
 
 type manifestStore struct {
 type manifestStore struct {
@@ -39,15 +47,98 @@ type ContentStore interface {
 	content.Provider
 	content.Provider
 	Info(ctx context.Context, dgst digest.Digest) (content.Info, error)
 	Info(ctx context.Context, dgst digest.Digest) (content.Info, error)
 	Abort(ctx context.Context, ref string) error
 	Abort(ctx context.Context, ref string) error
+	Update(ctx context.Context, info content.Info, fieldpaths ...string) (content.Info, error)
+}
+
+func makeDistributionSourceLabel(ref reference.Named) (string, string) {
+	domain := reference.Domain(ref)
+	if domain == "" {
+		domain = registry.DefaultNamespace
+	}
+	repo := reference.Path(ref)
+
+	return fmt.Sprintf("%s.%s", labelDistributionSource, domain), repo
 }
 }
 
 
-func (m *manifestStore) getLocal(ctx context.Context, desc specs.Descriptor) (distribution.Manifest, error) {
+// Taken from https://github.com/containerd/containerd/blob/e079e4a155c86f07bbd602fe6753ecacc78198c2/remotes/docker/handler.go#L84-L108
+func appendDistributionSourceLabel(originLabel, repo string) string {
+	repos := []string{}
+	if originLabel != "" {
+		repos = strings.Split(originLabel, ",")
+	}
+	repos = append(repos, repo)
+
+	// use empty string to present duplicate items
+	for i := 1; i < len(repos); i++ {
+		tmp, j := repos[i], i-1
+		for ; j >= 0 && repos[j] >= tmp; j-- {
+			if repos[j] == tmp {
+				tmp = ""
+			}
+			repos[j+1] = repos[j]
+		}
+		repos[j+1] = tmp
+	}
+
+	i := 0
+	for ; i < len(repos) && repos[i] == ""; i++ {
+	}
+
+	return strings.Join(repos[i:], ",")
+}
+
+func hasDistributionSource(label, repo string) bool {
+	sources := strings.Split(label, ",")
+	for _, s := range sources {
+		if s == repo {
+			return true
+		}
+	}
+	return false
+}
+
+func (m *manifestStore) getLocal(ctx context.Context, desc specs.Descriptor, ref reference.Named) (distribution.Manifest, error) {
 	ra, err := m.local.ReaderAt(ctx, desc)
 	ra, err := m.local.ReaderAt(ctx, desc)
 	if err != nil {
 	if err != nil {
 		return nil, errors.Wrap(err, "error getting content store reader")
 		return nil, errors.Wrap(err, "error getting content store reader")
 	}
 	}
 	defer ra.Close()
 	defer ra.Close()
 
 
+	distKey, distRepo := makeDistributionSourceLabel(ref)
+	info, err := m.local.Info(ctx, desc.Digest)
+	if err != nil {
+		return nil, errors.Wrap(err, "error getting content info")
+	}
+
+	if _, ok := ref.(reference.Canonical); ok {
+		// Since this is specified by digest...
+		// We know we have the content locally, we need to check if we've seen this content at the specified repository before.
+		// If we have, we can just return the manifest from the local content store.
+		// If we haven't, we need to check the remote repository to see if it has the content, otherwise we can end up returning
+		// a manifest that has never even existed in the remote before.
+		if !hasDistributionSource(info.Labels[distKey], distRepo) {
+			logrus.WithField("ref", ref).Debug("found manifest but no mataching source repo is listed, checking with remote")
+			exists, err := m.remote.Exists(ctx, desc.Digest)
+			if err != nil {
+				return nil, errors.Wrap(err, "error checking if remote exists")
+			}
+
+			if !exists {
+				return nil, errors.Wrapf(errdefs.ErrNotFound, "manifest %v not found", desc.Digest)
+			}
+
+		}
+	}
+
+	// Update the distribution sources since we now know the content exists in the remote.
+	if info.Labels == nil {
+		info.Labels = map[string]string{}
+	}
+	info.Labels[distKey] = appendDistributionSourceLabel(info.Labels[distKey], distRepo)
+	if _, err := m.local.Update(ctx, info, "labels."+distKey); err != nil {
+		logrus.WithError(err).WithField("ref", ref).Warn("Could not update content distribution source")
+	}
+
 	r := io.NewSectionReader(ra, 0, ra.Size())
 	r := io.NewSectionReader(ra, 0, ra.Size())
 	data, err := io.ReadAll(r)
 	data, err := io.ReadAll(r)
 	if err != nil {
 	if err != nil {
@@ -58,6 +149,7 @@ func (m *manifestStore) getLocal(ctx context.Context, desc specs.Descriptor) (di
 	if err != nil {
 	if err != nil {
 		return nil, errors.Wrap(err, "error unmarshaling manifest from content store")
 		return nil, errors.Wrap(err, "error unmarshaling manifest from content store")
 	}
 	}
+
 	return manifest, nil
 	return manifest, nil
 }
 }
 
 
@@ -75,7 +167,7 @@ func (m *manifestStore) getMediaType(ctx context.Context, desc specs.Descriptor)
 	return mt, nil
 	return mt, nil
 }
 }
 
 
-func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distribution.Manifest, error) {
+func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor, ref reference.Named) (distribution.Manifest, error) {
 	l := log.G(ctx)
 	l := log.G(ctx)
 
 
 	if desc.MediaType == "" {
 	if desc.MediaType == "" {
@@ -103,7 +195,7 @@ func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distrib
 	if err != nil {
 	if err != nil {
 		if errdefs.IsAlreadyExists(err) {
 		if errdefs.IsAlreadyExists(err) {
 			var manifest distribution.Manifest
 			var manifest distribution.Manifest
-			if manifest, err = m.getLocal(ctx, desc); err == nil {
+			if manifest, err = m.getLocal(ctx, desc, ref); err == nil {
 				return manifest, nil
 				return manifest, nil
 			}
 			}
 		}
 		}
@@ -125,7 +217,7 @@ func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distrib
 
 
 	if w != nil {
 	if w != nil {
 		// if `w` is nil here, something happened with the content store, so don't bother trying to persist.
 		// if `w` is nil here, something happened with the content store, so don't bother trying to persist.
-		if err := m.Put(ctx, manifest, desc, w); err != nil {
+		if err := m.Put(ctx, manifest, desc, w, ref); err != nil {
 			if err := m.local.Abort(ctx, key); err != nil {
 			if err := m.local.Abort(ctx, key); err != nil {
 				l.WithError(err).Warn("error aborting content ingest")
 				l.WithError(err).Warn("error aborting content ingest")
 			}
 			}
@@ -135,7 +227,7 @@ func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distrib
 	return manifest, nil
 	return manifest, nil
 }
 }
 
 
-func (m *manifestStore) Put(ctx context.Context, manifest distribution.Manifest, desc specs.Descriptor, w content.Writer) error {
+func (m *manifestStore) Put(ctx context.Context, manifest distribution.Manifest, desc specs.Descriptor, w content.Writer, ref reference.Named) error {
 	mt, payload, err := manifest.Payload()
 	mt, payload, err := manifest.Payload()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -147,7 +239,10 @@ func (m *manifestStore) Put(ctx context.Context, manifest distribution.Manifest,
 		return errors.Wrap(err, "error writing manifest to content store")
 		return errors.Wrap(err, "error writing manifest to content store")
 	}
 	}
 
 
-	if err := w.Commit(ctx, desc.Size, desc.Digest); err != nil {
+	distKey, distSource := makeDistributionSourceLabel(ref)
+	if err := w.Commit(ctx, desc.Size, desc.Digest, content.WithLabels(map[string]string{
+		distKey: distSource,
+	})); err != nil {
 		return errors.Wrap(err, "error committing manifest to content store")
 		return errors.Wrap(err, "error committing manifest to content store")
 	}
 	}
 	return nil
 	return nil

+ 42 - 20
distribution/manifest_test.go

@@ -17,6 +17,7 @@ import (
 	"github.com/docker/distribution/manifest/ocischema"
 	"github.com/docker/distribution/manifest/ocischema"
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/manifest/schema2"
 	"github.com/docker/distribution/manifest/schema2"
+	"github.com/docker/distribution/reference"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"github.com/opencontainers/go-digest"
 	"github.com/opencontainers/go-digest"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
@@ -39,6 +40,11 @@ func (m *mockManifestGetter) Get(ctx context.Context, dgst digest.Digest, option
 	return manifest, nil
 	return manifest, nil
 }
 }
 
 
+func (m *mockManifestGetter) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
+	_, ok := m.manifests[dgst]
+	return ok, nil
+}
+
 type memoryLabelStore struct {
 type memoryLabelStore struct {
 	mu     sync.Mutex
 	mu     sync.Mutex
 	labels map[digest.Digest]map[string]string
 	labels map[digest.Digest]map[string]string
@@ -76,7 +82,9 @@ func (s *memoryLabelStore) Update(dgst digest.Digest, update map[string]string)
 	for k, v := range update {
 	for k, v := range update {
 		labels[k] = v
 		labels[k] = v
 	}
 	}
-
+	if s.labels == nil {
+		s.labels = map[digest.Digest]map[string]string{}
+	}
 	s.labels[dgst] = labels
 	s.labels[dgst] = labels
 
 
 	return labels, nil
 	return labels, nil
@@ -125,7 +133,7 @@ func TestManifestStore(t *testing.T) {
 	assert.NilError(t, err)
 	assert.NilError(t, err)
 	dgst := digest.Canonical.FromBytes(serialized)
 	dgst := digest.Canonical.FromBytes(serialized)
 
 
-	setupTest := func(t *testing.T) (specs.Descriptor, *mockManifestGetter, *manifestStore, content.Store, func(*testing.T)) {
+	setupTest := func(t *testing.T) (reference.Named, specs.Descriptor, *mockManifestGetter, *manifestStore, content.Store, func(*testing.T)) {
 		root, err := os.MkdirTemp("", strings.ReplaceAll(t.Name(), "/", "_"))
 		root, err := os.MkdirTemp("", strings.ReplaceAll(t.Name(), "/", "_"))
 		assert.NilError(t, err)
 		assert.NilError(t, err)
 		defer func() {
 		defer func() {
@@ -141,7 +149,10 @@ func TestManifestStore(t *testing.T) {
 		store := &manifestStore{local: cs, remote: mg}
 		store := &manifestStore{local: cs, remote: mg}
 		desc := specs.Descriptor{Digest: dgst, MediaType: specs.MediaTypeImageManifest, Size: int64(len(serialized))}
 		desc := specs.Descriptor{Digest: dgst, MediaType: specs.MediaTypeImageManifest, Size: int64(len(serialized))}
 
 
-		return desc, mg, store, cs, func(t *testing.T) {
+		ref, err := reference.Parse("foo/bar")
+		assert.NilError(t, err)
+
+		return ref.(reference.Named), desc, mg, store, cs, func(t *testing.T) {
 			assert.Check(t, os.RemoveAll(root))
 			assert.Check(t, os.RemoveAll(root))
 		}
 		}
 	}
 	}
@@ -181,22 +192,22 @@ func TestManifestStore(t *testing.T) {
 	}
 	}
 
 
 	t.Run("no remote or local", func(t *testing.T) {
 	t.Run("no remote or local", func(t *testing.T) {
-		desc, _, store, cs, teardown := setupTest(t)
+		ref, desc, _, store, cs, teardown := setupTest(t)
 		defer teardown(t)
 		defer teardown(t)
 
 
-		_, err = store.Get(ctx, desc)
+		_, err = store.Get(ctx, desc, ref)
 		checkIngest(t, cs, desc)
 		checkIngest(t, cs, desc)
 		// This error is what our digest getter returns when it doesn't know about the manifest
 		// This error is what our digest getter returns when it doesn't know about the manifest
 		assert.Error(t, err, distribution.ErrManifestUnknown{Tag: dgst.String()}.Error())
 		assert.Error(t, err, distribution.ErrManifestUnknown{Tag: dgst.String()}.Error())
 	})
 	})
 
 
 	t.Run("no local cache", func(t *testing.T) {
 	t.Run("no local cache", func(t *testing.T) {
-		desc, mg, store, cs, teardown := setupTest(t)
+		ref, desc, mg, store, cs, teardown := setupTest(t)
 		defer teardown(t)
 		defer teardown(t)
 
 
 		mg.manifests[desc.Digest] = m
 		mg.manifests[desc.Digest] = m
 
 
-		m2, err := store.Get(ctx, desc)
+		m2, err := store.Get(ctx, desc, ref)
 		checkIngest(t, cs, desc)
 		checkIngest(t, cs, desc)
 		assert.NilError(t, err)
 		assert.NilError(t, err)
 		assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 		assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
@@ -206,23 +217,34 @@ func TestManifestStore(t *testing.T) {
 		assert.NilError(t, err)
 		assert.NilError(t, err)
 		assert.Check(t, cmp.Equal(i.Digest, desc.Digest))
 		assert.Check(t, cmp.Equal(i.Digest, desc.Digest))
 
 
+		distKey, distSource := makeDistributionSourceLabel(ref)
+		assert.Check(t, hasDistributionSource(i.Labels[distKey], distSource))
+
 		// Now check again, this should not hit the remote
 		// Now check again, this should not hit the remote
-		m2, err = store.Get(ctx, desc)
+		m2, err = store.Get(ctx, desc, ref)
 		checkIngest(t, cs, desc)
 		checkIngest(t, cs, desc)
 		assert.NilError(t, err)
 		assert.NilError(t, err)
 		assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 		assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 		assert.Check(t, cmp.Equal(mg.gets, 1))
 		assert.Check(t, cmp.Equal(mg.gets, 1))
+
+		t.Run("digested", func(t *testing.T) {
+			ref, err := reference.WithDigest(ref, desc.Digest)
+			assert.NilError(t, err)
+
+			_, err = store.Get(ctx, desc, ref)
+			assert.NilError(t, err)
+		})
 	})
 	})
 
 
 	t.Run("with local cache", func(t *testing.T) {
 	t.Run("with local cache", func(t *testing.T) {
-		desc, mg, store, cs, teardown := setupTest(t)
+		ref, desc, mg, store, cs, teardown := setupTest(t)
 		defer teardown(t)
 		defer teardown(t)
 
 
 		// first add the manifest to the coontent store
 		// first add the manifest to the coontent store
 		writeManifest(t, cs, desc)
 		writeManifest(t, cs, desc)
 
 
 		// now do the get
 		// now do the get
-		m2, err := store.Get(ctx, desc)
+		m2, err := store.Get(ctx, desc, ref)
 		checkIngest(t, cs, desc)
 		checkIngest(t, cs, desc)
 		assert.NilError(t, err)
 		assert.NilError(t, err)
 		assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 		assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
@@ -236,13 +258,13 @@ func TestManifestStore(t *testing.T) {
 	// This is for the case of pull by digest where we don't know the media type of the manifest until it's actually pulled.
 	// This is for the case of pull by digest where we don't know the media type of the manifest until it's actually pulled.
 	t.Run("unknown media type", func(t *testing.T) {
 	t.Run("unknown media type", func(t *testing.T) {
 		t.Run("no cache", func(t *testing.T) {
 		t.Run("no cache", func(t *testing.T) {
-			desc, mg, store, cs, teardown := setupTest(t)
+			ref, desc, mg, store, cs, teardown := setupTest(t)
 			defer teardown(t)
 			defer teardown(t)
 
 
 			mg.manifests[desc.Digest] = m
 			mg.manifests[desc.Digest] = m
 			desc.MediaType = ""
 			desc.MediaType = ""
 
 
-			m2, err := store.Get(ctx, desc)
+			m2, err := store.Get(ctx, desc, ref)
 			checkIngest(t, cs, desc)
 			checkIngest(t, cs, desc)
 			assert.NilError(t, err)
 			assert.NilError(t, err)
 			assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 			assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
@@ -251,13 +273,13 @@ func TestManifestStore(t *testing.T) {
 
 
 		t.Run("with cache", func(t *testing.T) {
 		t.Run("with cache", func(t *testing.T) {
 			t.Run("cached manifest has media type", func(t *testing.T) {
 			t.Run("cached manifest has media type", func(t *testing.T) {
-				desc, mg, store, cs, teardown := setupTest(t)
+				ref, desc, mg, store, cs, teardown := setupTest(t)
 				defer teardown(t)
 				defer teardown(t)
 
 
 				writeManifest(t, cs, desc)
 				writeManifest(t, cs, desc)
 				desc.MediaType = ""
 				desc.MediaType = ""
 
 
-				m2, err := store.Get(ctx, desc)
+				m2, err := store.Get(ctx, desc, ref)
 				checkIngest(t, cs, desc)
 				checkIngest(t, cs, desc)
 				assert.NilError(t, err)
 				assert.NilError(t, err)
 				assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 				assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
@@ -265,13 +287,13 @@ func TestManifestStore(t *testing.T) {
 			})
 			})
 
 
 			t.Run("cached manifest has no media type", func(t *testing.T) {
 			t.Run("cached manifest has no media type", func(t *testing.T) {
-				desc, mg, store, cs, teardown := setupTest(t)
+				ref, desc, mg, store, cs, teardown := setupTest(t)
 				defer teardown(t)
 				defer teardown(t)
 
 
 				desc.MediaType = ""
 				desc.MediaType = ""
 				writeManifest(t, cs, desc)
 				writeManifest(t, cs, desc)
 
 
-				m2, err := store.Get(ctx, desc)
+				m2, err := store.Get(ctx, desc, ref)
 				checkIngest(t, cs, desc)
 				checkIngest(t, cs, desc)
 				assert.NilError(t, err)
 				assert.NilError(t, err)
 				assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 				assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
@@ -286,14 +308,14 @@ func TestManifestStore(t *testing.T) {
 	// Also makes sure the ingests are aborted.
 	// Also makes sure the ingests are aborted.
 	t.Run("error persisting manifest", func(t *testing.T) {
 	t.Run("error persisting manifest", func(t *testing.T) {
 		t.Run("error on writer", func(t *testing.T) {
 		t.Run("error on writer", func(t *testing.T) {
-			desc, mg, store, cs, teardown := setupTest(t)
+			ref, desc, mg, store, cs, teardown := setupTest(t)
 			defer teardown(t)
 			defer teardown(t)
 			mg.manifests[desc.Digest] = m
 			mg.manifests[desc.Digest] = m
 
 
 			csW := &testingContentStoreWrapper{ContentStore: store.local, errorOnWriter: errors.New("random error")}
 			csW := &testingContentStoreWrapper{ContentStore: store.local, errorOnWriter: errors.New("random error")}
 			store.local = csW
 			store.local = csW
 
 
-			m2, err := store.Get(ctx, desc)
+			m2, err := store.Get(ctx, desc, ref)
 			checkIngest(t, cs, desc)
 			checkIngest(t, cs, desc)
 			assert.NilError(t, err)
 			assert.NilError(t, err)
 			assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 			assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
@@ -305,14 +327,14 @@ func TestManifestStore(t *testing.T) {
 		})
 		})
 
 
 		t.Run("error on commit", func(t *testing.T) {
 		t.Run("error on commit", func(t *testing.T) {
-			desc, mg, store, cs, teardown := setupTest(t)
+			ref, desc, mg, store, cs, teardown := setupTest(t)
 			defer teardown(t)
 			defer teardown(t)
 			mg.manifests[desc.Digest] = m
 			mg.manifests[desc.Digest] = m
 
 
 			csW := &testingContentStoreWrapper{ContentStore: store.local, errorOnCommit: errors.New("random error")}
 			csW := &testingContentStoreWrapper{ContentStore: store.local, errorOnCommit: errors.New("random error")}
 			store.local = csW
 			store.local = csW
 
 
-			m2, err := store.Get(ctx, desc)
+			m2, err := store.Get(ctx, desc, ref)
 			checkIngest(t, cs, desc)
 			checkIngest(t, cs, desc)
 			assert.NilError(t, err)
 			assert.NilError(t, err)
 			assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
 			assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))

+ 3 - 2
distribution/pull_v2.go

@@ -385,7 +385,8 @@ func (p *puller) pullTag(ctx context.Context, ref reference.Named, platform *spe
 		Digest:    dgst,
 		Digest:    dgst,
 		Size:      size,
 		Size:      size,
 	}
 	}
-	manifest, err := p.manifestStore.Get(ctx, desc)
+
+	manifest, err := p.manifestStore.Get(ctx, desc, ref)
 	if err != nil {
 	if err != nil {
 		if isTagged && isNotFound(errors.Cause(err)) {
 		if isTagged && isNotFound(errors.Cause(err)) {
 			logrus.WithField("ref", ref).WithError(err).Debug("Falling back to pull manifest by tag")
 			logrus.WithField("ref", ref).WithError(err).Debug("Falling back to pull manifest by tag")
@@ -843,7 +844,7 @@ func (p *puller) pullManifestList(ctx context.Context, ref reference.Named, mfst
 			Size:      match.Size,
 			Size:      match.Size,
 			MediaType: match.MediaType,
 			MediaType: match.MediaType,
 		}
 		}
-		manifest, err := p.manifestStore.Get(ctx, desc)
+		manifest, err := p.manifestStore.Get(ctx, desc, ref)
 		if err != nil {
 		if err != nil {
 			return "", "", err
 			return "", "", err
 		}
 		}

+ 138 - 0
integration/image/pull_test.go

@@ -2,11 +2,24 @@ package image
 
 
 import (
 import (
 	"context"
 	"context"
+	"encoding/json"
+	"io"
+	"os"
+	"path"
 	"testing"
 	"testing"
 
 
+	"github.com/containerd/containerd"
+	"github.com/containerd/containerd/content"
+	"github.com/containerd/containerd/content/local"
+	"github.com/containerd/containerd/images"
+	"github.com/containerd/containerd/platforms"
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/versions"
 	"github.com/docker/docker/api/types/versions"
 	"github.com/docker/docker/errdefs"
 	"github.com/docker/docker/errdefs"
+	"github.com/docker/docker/testutil/registry"
+	"github.com/opencontainers/go-digest"
+	"github.com/opencontainers/image-spec/specs-go"
+	imagespec "github.com/opencontainers/image-spec/specs-go/v1"
 	"gotest.tools/v3/assert"
 	"gotest.tools/v3/assert"
 	"gotest.tools/v3/skip"
 	"gotest.tools/v3/skip"
 )
 )
@@ -22,3 +35,128 @@ func TestImagePullPlatformInvalid(t *testing.T) {
 	assert.ErrorContains(t, err, "unknown operating system or architecture")
 	assert.ErrorContains(t, err, "unknown operating system or architecture")
 	assert.Assert(t, errdefs.IsInvalidParameter(err))
 	assert.Assert(t, errdefs.IsInvalidParameter(err))
 }
 }
+
+func createTestImage(ctx context.Context, t testing.TB, store content.Store) imagespec.Descriptor {
+	w, err := store.Writer(ctx, content.WithRef("layer"))
+	assert.NilError(t, err)
+	defer w.Close()
+
+	// Empty layer with just a root dir
+	const layer = `./0000775000000000000000000000000014201045023007702 5ustar  rootroot`
+
+	_, err = w.Write([]byte(layer))
+	assert.NilError(t, err)
+
+	err = w.Commit(ctx, int64(len(layer)), digest.FromBytes([]byte(layer)))
+	assert.NilError(t, err)
+
+	layerDigest := w.Digest()
+	w.Close()
+
+	platform := platforms.DefaultSpec()
+
+	img := imagespec.Image{
+		Architecture: platform.Architecture,
+		OS:           platform.OS,
+		RootFS:       imagespec.RootFS{Type: "layers", DiffIDs: []digest.Digest{layerDigest}},
+		Config:       imagespec.ImageConfig{WorkingDir: "/"},
+	}
+	imgJSON, err := json.Marshal(img)
+	assert.NilError(t, err)
+
+	w, err = store.Writer(ctx, content.WithRef("config"))
+	assert.NilError(t, err)
+	defer w.Close()
+	_, err = w.Write(imgJSON)
+	assert.NilError(t, err)
+	assert.NilError(t, w.Commit(ctx, int64(len(imgJSON)), digest.FromBytes(imgJSON)))
+
+	configDigest := w.Digest()
+	w.Close()
+
+	info, err := store.Info(ctx, layerDigest)
+	assert.NilError(t, err)
+
+	manifest := imagespec.Manifest{
+		Versioned: specs.Versioned{
+			SchemaVersion: 2,
+		},
+		MediaType: images.MediaTypeDockerSchema2Manifest,
+		Config: imagespec.Descriptor{
+			MediaType: images.MediaTypeDockerSchema2Config,
+			Digest:    configDigest,
+			Size:      int64(len(imgJSON)),
+		},
+		Layers: []imagespec.Descriptor{{
+			MediaType: images.MediaTypeDockerSchema2Layer,
+			Digest:    layerDigest,
+			Size:      info.Size,
+		}},
+	}
+
+	manifestJSON, err := json.Marshal(manifest)
+	assert.NilError(t, err)
+
+	w, err = store.Writer(ctx, content.WithRef("manifest"))
+	assert.NilError(t, err)
+	defer w.Close()
+	_, err = w.Write(manifestJSON)
+	assert.NilError(t, err)
+	assert.NilError(t, w.Commit(ctx, int64(len(manifestJSON)), digest.FromBytes(manifestJSON)))
+
+	manifestDigest := w.Digest()
+	w.Close()
+
+	return imagespec.Descriptor{
+		MediaType: images.MediaTypeDockerSchema2Manifest,
+		Digest:    manifestDigest,
+		Size:      int64(len(manifestJSON)),
+	}
+}
+
+// Make sure that pulling by an already cached digest but for a different ref (that should not have that digest)
+// verifies with the remote that the digest exists in that repo.
+func TestImagePullStoredfDigestForOtherRepo(t *testing.T) {
+	skip.If(t, testEnv.IsRemoteDaemon, "cannot run daemon when remote daemon")
+	skip.If(t, testEnv.OSType == "windows", "We don't run a test registry on Windows")
+	skip.If(t, testEnv.IsRootless, "Rootless has a different view of localhost (needed for test registry access)")
+	defer setupTest(t)()
+
+	reg := registry.NewV2(t, registry.WithStdout(os.Stdout), registry.WithStderr(os.Stderr))
+	defer reg.Close()
+	reg.WaitReady(t)
+
+	ctx := context.Background()
+
+	// First create an image and upload it to our local registry
+	// Then we'll download it so that we can make sure the content is available in dockerd's manifest cache.
+	// Then we'll try to pull the same digest but with a different repo name.
+
+	dir := t.TempDir()
+	store, err := local.NewStore(dir)
+	assert.NilError(t, err)
+
+	desc := createTestImage(ctx, t, store)
+
+	remote := path.Join(registry.DefaultURL, "test:latest")
+
+	c8dClient, err := containerd.New("", containerd.WithServices(containerd.WithContentStore(store)))
+	assert.NilError(t, err)
+
+	c8dClient.Push(ctx, remote, desc)
+	assert.NilError(t, err)
+
+	client := testEnv.APIClient()
+	rdr, err := client.ImagePull(ctx, remote, types.ImagePullOptions{})
+	assert.NilError(t, err)
+	defer rdr.Close()
+	io.Copy(io.Discard, rdr)
+
+	// Now, pull a totally different repo with a the same digest
+	rdr, err = client.ImagePull(ctx, path.Join(registry.DefaultURL, "other:image@"+desc.Digest.String()), types.ImagePullOptions{})
+	if rdr != nil {
+		rdr.Close()
+	}
+	assert.Assert(t, err != nil, "Expected error, got none: %v", err)
+	assert.Assert(t, errdefs.IsNotFound(err), err)
+}

+ 16 - 0
testutil/registry/ops.go

@@ -1,5 +1,7 @@
 package registry
 package registry
 
 
+import "io"
+
 // Schema1 sets the registry to serve v1 api
 // Schema1 sets the registry to serve v1 api
 func Schema1(c *Config) {
 func Schema1(c *Config) {
 	c.schema1 = true
 	c.schema1 = true
@@ -24,3 +26,17 @@ func URL(registryURL string) func(*Config) {
 		c.registryURL = registryURL
 		c.registryURL = registryURL
 	}
 	}
 }
 }
+
+// WithStdout sets the stdout of the registry command to the passed in writer.
+func WithStdout(w io.Writer) func(c *Config) {
+	return func(c *Config) {
+		c.stdout = w
+	}
+}
+
+// WithStderr sets the stdout of the registry command to the passed in writer.
+func WithStderr(w io.Writer) func(c *Config) {
+	return func(c *Config) {
+		c.stderr = w
+	}
+}