|
@@ -0,0 +1,177 @@
|
|
|
+// Copyright 2023 The Go Authors. All rights reserved.
|
|
|
+// Use of this source code is governed by a BSD-style
|
|
|
+// license that can be found in the LICENSE file.
|
|
|
+
|
|
|
+package dynamicpb
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+ "sync/atomic"
|
|
|
+
|
|
|
+ "google.golang.org/protobuf/internal/errors"
|
|
|
+ "google.golang.org/protobuf/reflect/protoreflect"
|
|
|
+ "google.golang.org/protobuf/reflect/protoregistry"
|
|
|
+)
|
|
|
+
|
|
|
+type extField struct {
|
|
|
+ name protoreflect.FullName
|
|
|
+ number protoreflect.FieldNumber
|
|
|
+}
|
|
|
+
|
|
|
+// A Types is a collection of dynamically constructed descriptors.
|
|
|
+// Its methods are safe for concurrent use.
|
|
|
+//
|
|
|
+// Types implements protoregistry.MessageTypeResolver and protoregistry.ExtensionTypeResolver.
|
|
|
+// A Types may be used as a proto.UnmarshalOptions.Resolver.
|
|
|
+type Types struct {
|
|
|
+ files *protoregistry.Files
|
|
|
+
|
|
|
+ extMu sync.Mutex
|
|
|
+ atomicExtFiles uint64
|
|
|
+ extensionsByMessage map[extField]protoreflect.ExtensionDescriptor
|
|
|
+}
|
|
|
+
|
|
|
+// NewTypes creates a new Types registry with the provided files.
|
|
|
+// The Files registry is retained, and changes to Files will be reflected in Types.
|
|
|
+// It is not safe to concurrently change the Files while calling Types methods.
|
|
|
+func NewTypes(f *protoregistry.Files) *Types {
|
|
|
+ return &Types{
|
|
|
+ files: f,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// FindEnumByName looks up an enum by its full name;
|
|
|
+// e.g., "google.protobuf.Field.Kind".
|
|
|
+//
|
|
|
+// This returns (nil, protoregistry.NotFound) if not found.
|
|
|
+func (t *Types) FindEnumByName(name protoreflect.FullName) (protoreflect.EnumType, error) {
|
|
|
+ d, err := t.files.FindDescriptorByName(name)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ ed, ok := d.(protoreflect.EnumDescriptor)
|
|
|
+ if !ok {
|
|
|
+ return nil, errors.New("found wrong type: got %v, want enum", descName(d))
|
|
|
+ }
|
|
|
+ return NewEnumType(ed), nil
|
|
|
+}
|
|
|
+
|
|
|
+// FindExtensionByName looks up an extension field by the field's full name.
|
|
|
+// Note that this is the full name of the field as determined by
|
|
|
+// where the extension is declared and is unrelated to the full name of the
|
|
|
+// message being extended.
|
|
|
+//
|
|
|
+// This returns (nil, protoregistry.NotFound) if not found.
|
|
|
+func (t *Types) FindExtensionByName(name protoreflect.FullName) (protoreflect.ExtensionType, error) {
|
|
|
+ d, err := t.files.FindDescriptorByName(name)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ xd, ok := d.(protoreflect.ExtensionDescriptor)
|
|
|
+ if !ok {
|
|
|
+ return nil, errors.New("found wrong type: got %v, want extension", descName(d))
|
|
|
+ }
|
|
|
+ return NewExtensionType(xd), nil
|
|
|
+}
|
|
|
+
|
|
|
+// FindExtensionByNumber looks up an extension field by the field number
|
|
|
+// within some parent message, identified by full name.
|
|
|
+//
|
|
|
+// This returns (nil, protoregistry.NotFound) if not found.
|
|
|
+func (t *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
|
|
|
+ // Construct the extension number map lazily, since not every user will need it.
|
|
|
+ // Update the map if new files are added to the registry.
|
|
|
+ if atomic.LoadUint64(&t.atomicExtFiles) != uint64(t.files.NumFiles()) {
|
|
|
+ t.updateExtensions()
|
|
|
+ }
|
|
|
+ xd := t.extensionsByMessage[extField{message, field}]
|
|
|
+ if xd == nil {
|
|
|
+ return nil, protoregistry.NotFound
|
|
|
+ }
|
|
|
+ return NewExtensionType(xd), nil
|
|
|
+}
|
|
|
+
|
|
|
+// FindMessageByName looks up a message by its full name;
|
|
|
+// e.g. "google.protobuf.Any".
|
|
|
+//
|
|
|
+// This returns (nil, protoregistry.NotFound) if not found.
|
|
|
+func (t *Types) FindMessageByName(name protoreflect.FullName) (protoreflect.MessageType, error) {
|
|
|
+ d, err := t.files.FindDescriptorByName(name)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ md, ok := d.(protoreflect.MessageDescriptor)
|
|
|
+ if !ok {
|
|
|
+ return nil, errors.New("found wrong type: got %v, want message", descName(d))
|
|
|
+ }
|
|
|
+ return NewMessageType(md), nil
|
|
|
+}
|
|
|
+
|
|
|
+// FindMessageByURL looks up a message by a URL identifier.
|
|
|
+// See documentation on google.protobuf.Any.type_url for the URL format.
|
|
|
+//
|
|
|
+// This returns (nil, protoregistry.NotFound) if not found.
|
|
|
+func (t *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
|
|
|
+ // This function is similar to FindMessageByName but
|
|
|
+ // truncates anything before and including '/' in the URL.
|
|
|
+ message := protoreflect.FullName(url)
|
|
|
+ if i := strings.LastIndexByte(url, '/'); i >= 0 {
|
|
|
+ message = message[i+len("/"):]
|
|
|
+ }
|
|
|
+ return t.FindMessageByName(message)
|
|
|
+}
|
|
|
+
|
|
|
+func (t *Types) updateExtensions() {
|
|
|
+ t.extMu.Lock()
|
|
|
+ defer t.extMu.Unlock()
|
|
|
+ if atomic.LoadUint64(&t.atomicExtFiles) == uint64(t.files.NumFiles()) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer atomic.StoreUint64(&t.atomicExtFiles, uint64(t.files.NumFiles()))
|
|
|
+ t.files.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
|
|
|
+ t.registerExtensions(fd.Extensions())
|
|
|
+ t.registerExtensionsInMessages(fd.Messages())
|
|
|
+ return true
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func (t *Types) registerExtensionsInMessages(mds protoreflect.MessageDescriptors) {
|
|
|
+ count := mds.Len()
|
|
|
+ for i := 0; i < count; i++ {
|
|
|
+ md := mds.Get(i)
|
|
|
+ t.registerExtensions(md.Extensions())
|
|
|
+ t.registerExtensionsInMessages(md.Messages())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (t *Types) registerExtensions(xds protoreflect.ExtensionDescriptors) {
|
|
|
+ count := xds.Len()
|
|
|
+ for i := 0; i < count; i++ {
|
|
|
+ xd := xds.Get(i)
|
|
|
+ field := xd.Number()
|
|
|
+ message := xd.ContainingMessage().FullName()
|
|
|
+ if t.extensionsByMessage == nil {
|
|
|
+ t.extensionsByMessage = make(map[extField]protoreflect.ExtensionDescriptor)
|
|
|
+ }
|
|
|
+ t.extensionsByMessage[extField{message, field}] = xd
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func descName(d protoreflect.Descriptor) string {
|
|
|
+ switch d.(type) {
|
|
|
+ case protoreflect.EnumDescriptor:
|
|
|
+ return "enum"
|
|
|
+ case protoreflect.EnumValueDescriptor:
|
|
|
+ return "enum value"
|
|
|
+ case protoreflect.MessageDescriptor:
|
|
|
+ return "message"
|
|
|
+ case protoreflect.ExtensionDescriptor:
|
|
|
+ return "extension"
|
|
|
+ case protoreflect.ServiceDescriptor:
|
|
|
+ return "service"
|
|
|
+ default:
|
|
|
+ return fmt.Sprintf("%T", d)
|
|
|
+ }
|
|
|
+}
|