Vendoring libnetwork for rc5
Signed-off-by: Madhu Venugopal <madhu@docker.com>
This commit is contained in:
parent
9c1be541ff
commit
56b78ab2f2
8 changed files with 85 additions and 283 deletions
|
@ -65,7 +65,7 @@ clone git github.com/RackSec/srslog 259aed10dfa74ea2961eddd1d9847619f6e98837
|
|||
clone git github.com/imdario/mergo 0.2.1
|
||||
|
||||
#get libnetwork packages
|
||||
clone git github.com/docker/libnetwork 83ab4deaa2da3deb32cb5e64ceec43801dc17370
|
||||
clone git github.com/docker/libnetwork 6a3feece4ede9473439f0c835a13e666dc2ab857
|
||||
clone git github.com/docker/go-events afb2b9f2c23f33ada1a22b03651775fdc65a5089
|
||||
clone git github.com/armon/go-radix e39d623f12e8e41c7b5529e9a9dd67a1e2261f80
|
||||
clone git github.com/armon/go-metrics eb0af217e5e9747e41dd5303755356b62d28e3ec
|
||||
|
|
|
@ -526,6 +526,8 @@ func (c *controller) Config() config.Config {
|
|||
}
|
||||
|
||||
func (c *controller) isManager() bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.cfg == nil || c.cfg.Daemon.ClusterProvider == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -533,6 +535,8 @@ func (c *controller) isManager() bool {
|
|||
}
|
||||
|
||||
func (c *controller) isAgent() bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.cfg == nil || c.cfg.Daemon.ClusterProvider == nil {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -88,15 +88,15 @@ func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal
|
|||
return types.ForbiddenErrorf("encryption key is not present")
|
||||
}
|
||||
|
||||
lIP := types.GetMinimalIP(net.ParseIP(d.bindAddress))
|
||||
aIP := types.GetMinimalIP(net.ParseIP(d.advertiseAddress))
|
||||
lIP := net.ParseIP(d.bindAddress)
|
||||
aIP := net.ParseIP(d.advertiseAddress)
|
||||
nodes := map[string]net.IP{}
|
||||
|
||||
switch {
|
||||
case isLocal:
|
||||
if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
|
||||
if !lIP.Equal(pEntry.vtep) {
|
||||
nodes[pEntry.vtep.String()] = types.GetMinimalIP(pEntry.vtep)
|
||||
nodes[pEntry.vtep.String()] = pEntry.vtep
|
||||
}
|
||||
return false
|
||||
}); err != nil {
|
||||
|
@ -104,7 +104,7 @@ func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal
|
|||
}
|
||||
default:
|
||||
if len(d.network(nid).endpoints) > 0 {
|
||||
nodes[rIP.String()] = types.GetMinimalIP(rIP)
|
||||
nodes[rIP.String()] = rIP
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -337,7 +337,7 @@ func saExists(sa *netlink.XfrmState) (bool, error) {
|
|||
return false, nil
|
||||
default:
|
||||
err = fmt.Errorf("Error while checking for SA existence: %v", err)
|
||||
log.Debug(err)
|
||||
log.Warn(err)
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
@ -351,7 +351,7 @@ func spExists(sp *netlink.XfrmPolicy) (bool, error) {
|
|||
return false, nil
|
||||
default:
|
||||
err = fmt.Errorf("Error while checking for SP existence: %v", err)
|
||||
log.Debug(err)
|
||||
log.Warn(err)
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
@ -411,7 +411,7 @@ func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
|
|||
newIdx = -1
|
||||
priIdx = -1
|
||||
delIdx = -1
|
||||
lIP = types.GetMinimalIP(net.ParseIP(d.bindAddress))
|
||||
lIP = net.ParseIP(d.bindAddress)
|
||||
)
|
||||
|
||||
d.Lock()
|
||||
|
@ -440,7 +440,7 @@ func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
|
|||
}
|
||||
|
||||
d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
|
||||
rIP := types.GetMinimalIP(net.ParseIP(rIPs))
|
||||
rIP := net.ParseIP(rIPs)
|
||||
return updateNodeKey(lIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
|
||||
})
|
||||
|
||||
|
@ -466,10 +466,9 @@ func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
|
|||
}
|
||||
|
||||
/********************************************************
|
||||
* Steady state: rSA0, rSA1, fSA0, fSP0
|
||||
* Rotation --> %rSA0, +rSA2, +fSA1, +fSP1/-fSP0, -fSA0,
|
||||
* Half state: rSA0, rSA1, rSA2, fSA1, fSP1
|
||||
* Steady state: rSA1, rSA2, fSA1, fSP1
|
||||
* Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
|
||||
* Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
|
||||
* Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
|
||||
*********************************************************/
|
||||
|
||||
// Spis and keys are sorted in such away the one in position 0 is the primary
|
||||
|
@ -488,20 +487,8 @@ func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx,
|
|||
}
|
||||
|
||||
if delIdx != -1 {
|
||||
// %rSA0
|
||||
rSA0 := &netlink.XfrmState{
|
||||
Src: rIP,
|
||||
Dst: lIP,
|
||||
Proto: netlink.XFRM_PROTO_ESP,
|
||||
Spi: spis[delIdx].reverse,
|
||||
Mode: netlink.XFRM_MODE_TRANSPORT,
|
||||
Crypt: &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[delIdx].value},
|
||||
Limits: netlink.XfrmStateLimits{TimeSoft: timeout},
|
||||
}
|
||||
log.Debugf("Updating rSA0{%s}", rSA0)
|
||||
if err := ns.NlHandle().XfrmStateUpdate(rSA0); err != nil {
|
||||
log.Warnf("Failed to update rSA0{%s}: %v", rSA0, err)
|
||||
}
|
||||
// -rSA0
|
||||
programSA(rIP, lIP, spis[delIdx], nil, reverse, false)
|
||||
}
|
||||
|
||||
if newIdx > -1 {
|
||||
|
@ -510,14 +497,14 @@ func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx,
|
|||
}
|
||||
|
||||
if priIdx > 0 {
|
||||
// +fSA1
|
||||
fSA1, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
|
||||
// +fSA2
|
||||
fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
|
||||
|
||||
// +fSP1, -fSP0
|
||||
fullMask := net.CIDRMask(8*len(fSA1.Src), 8*len(fSA1.Src))
|
||||
// +fSP2, -fSP1
|
||||
fullMask := net.CIDRMask(8*len(fSA2.Src), 8*len(fSA2.Src))
|
||||
fSP1 := &netlink.XfrmPolicy{
|
||||
Src: &net.IPNet{IP: fSA1.Src, Mask: fullMask},
|
||||
Dst: &net.IPNet{IP: fSA1.Dst, Mask: fullMask},
|
||||
Src: &net.IPNet{IP: fSA2.Src, Mask: fullMask},
|
||||
Dst: &net.IPNet{IP: fSA2.Dst, Mask: fullMask},
|
||||
Dir: netlink.XFRM_DIR_OUT,
|
||||
Proto: 17,
|
||||
DstPort: 4789,
|
||||
|
@ -526,11 +513,11 @@ func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx,
|
|||
},
|
||||
Tmpls: []netlink.XfrmPolicyTmpl{
|
||||
{
|
||||
Src: fSA1.Src,
|
||||
Dst: fSA1.Dst,
|
||||
Src: fSA2.Src,
|
||||
Dst: fSA2.Dst,
|
||||
Proto: netlink.XFRM_PROTO_ESP,
|
||||
Mode: netlink.XFRM_MODE_TRANSPORT,
|
||||
Spi: fSA1.Spi,
|
||||
Spi: fSA2.Spi,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -539,20 +526,8 @@ func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx,
|
|||
log.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
|
||||
}
|
||||
|
||||
// -fSA0
|
||||
fSA0 := &netlink.XfrmState{
|
||||
Src: lIP,
|
||||
Dst: rIP,
|
||||
Proto: netlink.XFRM_PROTO_ESP,
|
||||
Spi: spis[0].forward,
|
||||
Mode: netlink.XFRM_MODE_TRANSPORT,
|
||||
Crypt: &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[0].value},
|
||||
Limits: netlink.XfrmStateLimits{TimeHard: timeout},
|
||||
}
|
||||
log.Debugf("Removing fSA0{%s}", fSA0)
|
||||
if err := ns.NlHandle().XfrmStateUpdate(fSA0); err != nil {
|
||||
log.Warnf("Failed to remove fSA0{%s}: %v", fSA0, err)
|
||||
}
|
||||
// -fSA1
|
||||
programSA(lIP, rIP, spis[0], nil, forward, false)
|
||||
}
|
||||
|
||||
// swap
|
||||
|
@ -575,7 +550,11 @@ func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx,
|
|||
}
|
||||
|
||||
func (n *network) maxMTU() int {
|
||||
mtu := vxlanVethMTU
|
||||
mtu := 1500
|
||||
if n.mtu != 0 {
|
||||
mtu = n.mtu
|
||||
}
|
||||
mtu -= vxlanEncap
|
||||
if n.secure {
|
||||
// In case of encryption account for the
|
||||
// esp packet espansion and padding
|
||||
|
|
|
@ -63,6 +63,7 @@ type network struct {
|
|||
initErr error
|
||||
subnets []*subnet
|
||||
secure bool
|
||||
mtu int
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -111,9 +112,18 @@ func (d *driver) CreateNetwork(id string, option map[string]interface{}, nInfo d
|
|||
vnis = append(vnis, uint32(vni))
|
||||
}
|
||||
}
|
||||
if _, ok := optMap["secure"]; ok {
|
||||
if _, ok := optMap[secureOption]; ok {
|
||||
n.secure = true
|
||||
}
|
||||
if val, ok := optMap[netlabel.DriverMTU]; ok {
|
||||
var err error
|
||||
if n.mtu, err = strconv.Atoi(val); err != nil {
|
||||
return fmt.Errorf("failed to parse %v: %v", val, err)
|
||||
}
|
||||
if n.mtu < 0 {
|
||||
return fmt.Errorf("invalid MTU value: %v", n.mtu)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we are getting vnis from libnetwork, either we get for
|
||||
|
@ -140,6 +150,13 @@ func (d *driver) CreateNetwork(id string, option map[string]interface{}, nInfo d
|
|||
return fmt.Errorf("failed to update data store for network %v: %v", n.id, err)
|
||||
}
|
||||
|
||||
// Make sure no rule is on the way from any stale secure network
|
||||
if !n.secure {
|
||||
for _, vni := range vnis {
|
||||
programMangle(vni, false)
|
||||
}
|
||||
}
|
||||
|
||||
if nInfo != nil {
|
||||
if err := nInfo.TableEventRegister(ovPeerTable); err != nil {
|
||||
return err
|
||||
|
@ -315,7 +332,7 @@ func networkOnceInit() {
|
|||
return
|
||||
}
|
||||
|
||||
err := createVxlan("testvxlan", 1)
|
||||
err := createVxlan("testvxlan", 1, 0)
|
||||
if err != nil {
|
||||
logrus.Errorf("Failed to create testvxlan interface: %v", err)
|
||||
return
|
||||
|
@ -459,7 +476,7 @@ func (n *network) setupSubnetSandbox(s *subnet, brName, vxlanName string) error
|
|||
return fmt.Errorf("bridge creation in sandbox failed for subnet %q: %v", s.subnetIP.String(), err)
|
||||
}
|
||||
|
||||
err := createVxlan(vxlanName, n.vxlanID(s))
|
||||
err := createVxlan(vxlanName, n.vxlanID(s), n.maxMTU())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -732,6 +749,7 @@ func (n *network) Value() []byte {
|
|||
|
||||
m["secure"] = n.secure
|
||||
m["subnets"] = netJSON
|
||||
m["mtu"] = n.mtu
|
||||
b, err = json.Marshal(m)
|
||||
if err != nil {
|
||||
return []byte{}
|
||||
|
@ -781,6 +799,9 @@ func (n *network) SetValue(value []byte) error {
|
|||
if val, ok := m["secure"]; ok {
|
||||
n.secure = val.(bool)
|
||||
}
|
||||
if val, ok := m["mtu"]; ok {
|
||||
n.mtu = int(val.(float64))
|
||||
}
|
||||
bytes, err := json.Marshal(m["subnets"])
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -52,11 +52,11 @@ func createVethPair() (string, string, error) {
|
|||
return name1, name2, nil
|
||||
}
|
||||
|
||||
func createVxlan(name string, vni uint32) error {
|
||||
func createVxlan(name string, vni uint32, mtu int) error {
|
||||
defer osl.InitOSContext()()
|
||||
|
||||
vxlan := &netlink.Vxlan{
|
||||
LinkAttrs: netlink.LinkAttrs{Name: name},
|
||||
LinkAttrs: netlink.LinkAttrs{Name: name, MTU: mtu},
|
||||
VxlanId: int(vni),
|
||||
Learning: true,
|
||||
Port: vxlanPort,
|
||||
|
|
|
@ -25,7 +25,8 @@ const (
|
|||
vxlanIDStart = 256
|
||||
vxlanIDEnd = (1 << 24) - 1
|
||||
vxlanPort = 4789
|
||||
vxlanVethMTU = 1450
|
||||
vxlanEncap = 50
|
||||
secureOption = "encrypted"
|
||||
)
|
||||
|
||||
var initVxlanIdm = make(chan (bool), 1)
|
||||
|
|
|
@ -536,9 +536,6 @@ func (ep *endpoint) sbJoin(sb *sandbox, options ...EndpointOption) error {
|
|||
}
|
||||
}
|
||||
|
||||
if sb.resolver != nil {
|
||||
sb.resolver.FlushExtServers()
|
||||
}
|
||||
}
|
||||
|
||||
if !sb.needDefaultGW() {
|
||||
|
@ -619,10 +616,6 @@ func (ep *endpoint) Leave(sbox Sandbox, options ...EndpointOption) error {
|
|||
sb.joinLeaveStart()
|
||||
defer sb.joinLeaveEnd()
|
||||
|
||||
if sb.resolver != nil {
|
||||
sb.resolver.FlushExtServers()
|
||||
}
|
||||
|
||||
return ep.sbLeave(sb, false, options...)
|
||||
}
|
||||
|
||||
|
@ -776,9 +769,7 @@ func (ep *endpoint) Delete(force bool) error {
|
|||
}()
|
||||
|
||||
// unwatch for service records
|
||||
if !n.getController().isAgent() {
|
||||
n.getController().unWatchSvcRecord(ep)
|
||||
}
|
||||
n.getController().unWatchSvcRecord(ep)
|
||||
|
||||
if err = ep.deleteEndpoint(force); err != nil && !force {
|
||||
return err
|
||||
|
|
238
vendor/src/github.com/docker/libnetwork/resolver.go
vendored
238
vendor/src/github.com/docker/libnetwork/resolver.go
vendored
|
@ -30,9 +30,6 @@ type Resolver interface {
|
|||
// SetExtServers configures the external nameservers the resolver
|
||||
// should use to forward queries
|
||||
SetExtServers([]string)
|
||||
// FlushExtServers clears the cached UDP connections to external
|
||||
// nameservers
|
||||
FlushExtServers()
|
||||
// ResolverOptions returns resolv.conf options that should be set
|
||||
ResolverOptions() []string
|
||||
}
|
||||
|
@ -48,35 +45,12 @@ const (
|
|||
defaultRespSize = 512
|
||||
maxConcurrent = 100
|
||||
logInterval = 2 * time.Second
|
||||
maxDNSID = 65536
|
||||
)
|
||||
|
||||
type clientConn struct {
|
||||
dnsID uint16
|
||||
respWriter dns.ResponseWriter
|
||||
}
|
||||
|
||||
type extDNSEntry struct {
|
||||
ipStr string
|
||||
extConn net.Conn
|
||||
extOnce sync.Once
|
||||
ipStr string
|
||||
}
|
||||
|
||||
type sboxQuery struct {
|
||||
sboxID string
|
||||
dnsID uint16
|
||||
}
|
||||
|
||||
type clientConnGC struct {
|
||||
toDelete bool
|
||||
client clientConn
|
||||
}
|
||||
|
||||
var (
|
||||
queryGCMutex sync.Mutex
|
||||
queryGC map[sboxQuery]*clientConnGC
|
||||
)
|
||||
|
||||
// resolver implements the Resolver interface
|
||||
type resolver struct {
|
||||
sb *sandbox
|
||||
|
@ -89,34 +63,17 @@ type resolver struct {
|
|||
count int32
|
||||
tStamp time.Time
|
||||
queryLock sync.Mutex
|
||||
client map[uint16]clientConn
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().Unix())
|
||||
queryGC = make(map[sboxQuery]*clientConnGC)
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
for range ticker.C {
|
||||
queryGCMutex.Lock()
|
||||
for query, conn := range queryGC {
|
||||
if !conn.toDelete {
|
||||
conn.toDelete = true
|
||||
continue
|
||||
}
|
||||
delete(queryGC, query)
|
||||
}
|
||||
queryGCMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// NewResolver creates a new instance of the Resolver
|
||||
func NewResolver(sb *sandbox) Resolver {
|
||||
return &resolver{
|
||||
sb: sb,
|
||||
err: fmt.Errorf("setup not done yet"),
|
||||
client: make(map[uint16]clientConn),
|
||||
sb: sb,
|
||||
err: fmt.Errorf("setup not done yet"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -173,20 +130,7 @@ func (r *resolver) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *resolver) FlushExtServers() {
|
||||
for i := 0; i < maxExtDNS; i++ {
|
||||
if r.extDNSList[i].extConn != nil {
|
||||
r.extDNSList[i].extConn.Close()
|
||||
}
|
||||
|
||||
r.extDNSList[i].extConn = nil
|
||||
r.extDNSList[i].extOnce = sync.Once{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) Stop() {
|
||||
r.FlushExtServers()
|
||||
|
||||
if r.server != nil {
|
||||
r.server.Shutdown()
|
||||
}
|
||||
|
@ -355,7 +299,6 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|||
extConn net.Conn
|
||||
resp *dns.Msg
|
||||
err error
|
||||
writer dns.ResponseWriter
|
||||
)
|
||||
|
||||
if query == nil || len(query.Question) == 0 {
|
||||
|
@ -397,10 +340,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|||
if resp.Len() > maxSize {
|
||||
truncateResp(resp, maxSize, proto == "tcp")
|
||||
}
|
||||
writer = w
|
||||
} else {
|
||||
queryID := query.Id
|
||||
extQueryLoop:
|
||||
for i := 0; i < maxExtDNS; i++ {
|
||||
extDNS := &r.extDNSList[i]
|
||||
if extDNS.ipStr == "" {
|
||||
|
@ -411,30 +351,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|||
extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
|
||||
}
|
||||
|
||||
// For udp clients connection is persisted to reuse for further queries.
|
||||
// Accessing extDNS.extConn be a race here between go rouines. Hence the
|
||||
// connection setup is done in a Once block and fetch the extConn again
|
||||
extConn = extDNS.extConn
|
||||
if extConn == nil || proto == "tcp" {
|
||||
if proto == "udp" {
|
||||
extDNS.extOnce.Do(func() {
|
||||
r.sb.execFunc(extConnect)
|
||||
extDNS.extConn = extConn
|
||||
})
|
||||
extConn = extDNS.extConn
|
||||
} else {
|
||||
r.sb.execFunc(extConnect)
|
||||
}
|
||||
if err != nil {
|
||||
log.Debugf("Connect failed, %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// If two go routines are executing in parralel one will
|
||||
// block on the Once.Do and in case of error connecting
|
||||
// to the external server it will end up with a nil err
|
||||
// but extConn also being nil.
|
||||
if extConn == nil {
|
||||
r.sb.execFunc(extConnect)
|
||||
if err != nil {
|
||||
log.Debugf("Connect failed, %s", err)
|
||||
continue
|
||||
}
|
||||
log.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
|
||||
|
@ -443,10 +362,10 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|||
// Timeout has to be set for every IO operation.
|
||||
extConn.SetDeadline(time.Now().Add(extIOTimeout))
|
||||
co := &dns.Conn{Conn: extConn}
|
||||
defer co.Close()
|
||||
|
||||
// forwardQueryStart stores required context to mux multiple client queries over
|
||||
// one connection; and limits the number of outstanding concurrent queries.
|
||||
if r.forwardQueryStart(w, query, queryID) == false {
|
||||
// limits the number of outstanding concurrent queries.
|
||||
if r.forwardQueryStart() == false {
|
||||
old := r.tStamp
|
||||
r.tStamp = time.Now()
|
||||
if r.tStamp.Sub(old) > logInterval {
|
||||
|
@ -455,69 +374,38 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|||
continue
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if proto == "tcp" {
|
||||
co.Close()
|
||||
}
|
||||
}()
|
||||
err = co.WriteMsg(query)
|
||||
if err != nil {
|
||||
r.forwardQueryEnd(w, query)
|
||||
r.forwardQueryEnd()
|
||||
log.Debugf("Send to DNS server failed, %s", err)
|
||||
continue
|
||||
}
|
||||
for {
|
||||
// If a reply comes after a read timeout it will remain in the socket buffer
|
||||
// and will be read after sending next query. To ignore such stale replies
|
||||
// save the query context in a GC queue when read timesout. On the next reply
|
||||
// if the context is present in the GC queue its a old reply. Ignore it and
|
||||
// read again
|
||||
resp, err = co.ReadMsg()
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
r.addQueryToGC(w, query)
|
||||
}
|
||||
r.forwardQueryEnd(w, query)
|
||||
log.Debugf("Read from DNS server failed, %s", err)
|
||||
continue extQueryLoop
|
||||
}
|
||||
|
||||
if !r.checkRespInGC(w, resp) {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Retrieves the context for the forwarded query and returns the client connection
|
||||
// to send the reply to
|
||||
writer = r.forwardQueryEnd(w, resp)
|
||||
if writer == nil {
|
||||
resp, err = co.ReadMsg()
|
||||
// Truncated DNS replies should be sent to the client so that the
|
||||
// client can retry over TCP
|
||||
if err != nil && err != dns.ErrTruncated {
|
||||
r.forwardQueryEnd()
|
||||
log.Debugf("Read from DNS server failed, %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
r.forwardQueryEnd()
|
||||
|
||||
resp.Compress = true
|
||||
break
|
||||
}
|
||||
if resp == nil || writer == nil {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if writer == nil {
|
||||
return
|
||||
}
|
||||
if err = writer.WriteMsg(resp); err != nil {
|
||||
if err = w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("error writing resolver resp, %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID uint16) bool {
|
||||
proto := w.LocalAddr().Network()
|
||||
dnsID := uint16(rand.Intn(maxDNSID))
|
||||
|
||||
cc := clientConn{
|
||||
dnsID: queryID,
|
||||
respWriter: w,
|
||||
}
|
||||
|
||||
func (r *resolver) forwardQueryStart() bool {
|
||||
r.queryLock.Lock()
|
||||
defer r.queryLock.Unlock()
|
||||
|
||||
|
@ -526,74 +414,10 @@ func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID
|
|||
}
|
||||
r.count++
|
||||
|
||||
switch proto {
|
||||
case "tcp":
|
||||
break
|
||||
case "udp":
|
||||
for ok := true; ok == true; dnsID = uint16(rand.Intn(maxDNSID)) {
|
||||
_, ok = r.client[dnsID]
|
||||
}
|
||||
log.Debugf("client dns id %v, changed id %v", queryID, dnsID)
|
||||
r.client[dnsID] = cc
|
||||
msg.Id = dnsID
|
||||
default:
|
||||
log.Errorf("Invalid protocol..")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *resolver) addQueryToGC(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
if w.LocalAddr().Network() != "udp" {
|
||||
return
|
||||
}
|
||||
|
||||
r.queryLock.Lock()
|
||||
cc, ok := r.client[msg.Id]
|
||||
r.queryLock.Unlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
query := sboxQuery{
|
||||
sboxID: r.sb.ID(),
|
||||
dnsID: msg.Id,
|
||||
}
|
||||
clientGC := &clientConnGC{
|
||||
client: cc,
|
||||
}
|
||||
queryGCMutex.Lock()
|
||||
queryGC[query] = clientGC
|
||||
queryGCMutex.Unlock()
|
||||
}
|
||||
|
||||
func (r *resolver) checkRespInGC(w dns.ResponseWriter, msg *dns.Msg) bool {
|
||||
if w.LocalAddr().Network() != "udp" {
|
||||
return false
|
||||
}
|
||||
|
||||
query := sboxQuery{
|
||||
sboxID: r.sb.ID(),
|
||||
dnsID: msg.Id,
|
||||
}
|
||||
|
||||
queryGCMutex.Lock()
|
||||
defer queryGCMutex.Unlock()
|
||||
if _, ok := queryGC[query]; ok {
|
||||
delete(queryGC, query)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter {
|
||||
var (
|
||||
cc clientConn
|
||||
ok bool
|
||||
)
|
||||
proto := w.LocalAddr().Network()
|
||||
|
||||
func (r *resolver) forwardQueryEnd() {
|
||||
r.queryLock.Lock()
|
||||
defer r.queryLock.Unlock()
|
||||
|
||||
|
@ -602,22 +426,4 @@ func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.Respo
|
|||
} else {
|
||||
r.count--
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case "tcp":
|
||||
break
|
||||
case "udp":
|
||||
if cc, ok = r.client[msg.Id]; ok == false {
|
||||
log.Debugf("Can't retrieve client context for dns id %v", msg.Id)
|
||||
return nil
|
||||
}
|
||||
log.Debugf("dns msg id %v, client id %v", msg.Id, cc.dnsID)
|
||||
delete(r.client, msg.Id)
|
||||
msg.Id = cc.dnsID
|
||||
w = cc.respWriter
|
||||
default:
|
||||
log.Errorf("Invalid protocol")
|
||||
return nil
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue