123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- //go:build windows
- package socket
- import (
- "errors"
- "fmt"
- "net"
- "sync"
- "syscall"
- "unsafe"
- "github.com/Microsoft/go-winio/pkg/guid"
- "golang.org/x/sys/windows"
- )
- //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
- //sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
- //sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
- //sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
// socketError is the failure sentinel named by the failretval clauses of the //sys
// declarations in this file: a 32-bit all-ones value (Winsock's SOCKET_ERROR, -1)
// expressed as a uintptr.
const socketError uintptr = 0xFFFFFFFF
// Sentinel errors exported by this package. Match them with errors.Is, since they may be
// wrapped with additional context.
//
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
var (
	// ErrBufferSize signals a problem with the size of a buffer.
	ErrBufferSize = errors.New("buffer size")

	// ErrAddrFamily signals a problem with an address family.
	ErrAddrFamily = errors.New("address family")

	// ErrInvalidPointer signals that a pointer was not valid.
	ErrInvalidPointer = errors.New("invalid pointer")

	// ErrSocketClosed signals an operation on a closed socket. It wraps net.ErrClosed, so
	// errors.Is(err, net.ErrClosed) also matches it.
	ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)
- // todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
- // GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
- // If rsa is not large enough, the [windows.WSAEFAULT] is returned.
- func GetSockName(s windows.Handle, rsa RawSockaddr) error {
- ptr, l, err := rsa.Sockaddr()
- if err != nil {
- return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
- }
- // although getsockname returns WSAEFAULT if the buffer is too small, it does not set
- // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
- return getsockname(s, ptr, &l)
- }
- // GetPeerName returns the remote address the socket is connected to.
- //
- // See [GetSockName] for more information.
- func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
- ptr, l, err := rsa.Sockaddr()
- if err != nil {
- return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
- }
- return getpeername(s, ptr, &l)
- }
- func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
- ptr, l, err := rsa.Sockaddr()
- if err != nil {
- return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
- }
- return bind(s, ptr, l)
- }
- // "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
- // their sockaddr interface, so they cannot be used with HvsockAddr.
- // Replicate functionality here from
- // https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
- // The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
- // runtime via a WSAIoctl call:
- // https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
- type runtimeFunc struct {
- id guid.GUID
- once sync.Once
- addr uintptr
- err error
- }
- func (f *runtimeFunc) Load() error {
- f.once.Do(func() {
- var s windows.Handle
- s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
- if f.err != nil {
- return
- }
- defer windows.CloseHandle(s) //nolint:errcheck
- var n uint32
- f.err = windows.WSAIoctl(s,
- windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
- (*byte)(unsafe.Pointer(&f.id)),
- uint32(unsafe.Sizeof(f.id)),
- (*byte)(unsafe.Pointer(&f.addr)),
- uint32(unsafe.Sizeof(f.addr)),
- &n,
- nil, // overlapped
- 0, // completionRoutine
- )
- })
- return f.err
- }
- var (
- // todo: add `AcceptEx` and `GetAcceptExSockaddrs`
- WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
- Data1: 0x25a207b9,
- Data2: 0xddf3,
- Data3: 0x4660,
- Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
- }
- connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
- )
- func ConnectEx(
- fd windows.Handle,
- rsa RawSockaddr,
- sendBuf *byte,
- sendDataLen uint32,
- bytesSent *uint32,
- overlapped *windows.Overlapped,
- ) error {
- if err := connectExFunc.Load(); err != nil {
- return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
- }
- ptr, n, err := rsa.Sockaddr()
- if err != nil {
- return err
- }
- return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
- }
- // BOOL LpfnConnectex(
- // [in] SOCKET s,
- // [in] const sockaddr *name,
- // [in] int namelen,
- // [in, optional] PVOID lpSendBuffer,
- // [in] DWORD dwSendDataLength,
- // [out] LPDWORD lpdwBytesSent,
- // [in] LPOVERLAPPED lpOverlapped
- // )
- func connectEx(
- s windows.Handle,
- name unsafe.Pointer,
- namelen int32,
- sendBuf *byte,
- sendDataLen uint32,
- bytesSent *uint32,
- overlapped *windows.Overlapped,
- ) (err error) {
- // todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
- r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
- 7,
- uintptr(s),
- uintptr(name),
- uintptr(namelen),
- uintptr(unsafe.Pointer(sendBuf)),
- uintptr(sendDataLen),
- uintptr(unsafe.Pointer(bytesSent)),
- uintptr(unsafe.Pointer(overlapped)),
- 0,
- 0)
- if r1 == 0 {
- if e1 != 0 {
- err = error(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return err
- }
|