diff --git a/network.go b/network.go index 85c6083316..54b8dbfe31 100644 --- a/network.go +++ b/network.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "log" "net" "os/exec" @@ -221,10 +222,55 @@ func (mapper *PortMapper) Map(port int, dest net.TCPAddr) error { if err := mapper.iptablesForward("-A", port, dest); err != nil { return err } + mapper.mapping[port] = dest + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + mapper.Unmap(port) + return err + } + // FIXME: store the listener so we can close it at Unmap + go proxy(listener, "tcp", dest.String()) return nil } +// proxy listens for socket connections on `listener`, and forwards them unmodified +// to `proto:address` +func proxy(listener net.Listener, proto, address string) error { + Debugf("proxying to %s:%s", proto, address) + defer Debugf("Done proxying to %s:%s", proto, address) + for { + Debugf("Listening on %s", listener) + src, err := listener.Accept() + if err != nil { + return err + } + Debugf("Connecting to %s:%s", proto, address) + dst, err := net.Dial(proto, address) + if err != nil { + log.Printf("Error connecting to %s:%s: %s", proto, address, err) + src.Close() + continue + } + Debugf("Connected to backend, splicing") + splice(src, dst) + } + return nil +} + +func halfSplice(dst, src net.Conn) error { + _, err := io.Copy(dst, src) + // FIXME: on EOF from a tcp connection, pass WriteClose() + dst.Close() + src.Close() + return err +} + +func splice(a, b net.Conn) { + go halfSplice(a, b) + go halfSplice(b, a) +} + func (mapper *PortMapper) Unmap(port int) error { dest, ok := mapper.mapping[port] if !ok {