run status checks asynchronously so that they finish faster, since many checks are waiting on network replies and ought not to block the whole thing
This commit is contained in:
parent
8fd98d7db3
commit
7e05d7478f
1 changed files with 105 additions and 50 deletions
|
@ -6,7 +6,7 @@
|
|||
|
||||
__ALL__ = ['check_certificate']
|
||||
|
||||
import os, os.path, re, subprocess, datetime
|
||||
import os, os.path, re, subprocess, datetime, multiprocessing.pool
|
||||
|
||||
import dns.reversename, dns.resolver
|
||||
import dateutil.parser, dateutil.tz
|
||||
|
@ -35,15 +35,17 @@ def run_checks(env, output):
|
|||
|
||||
run_system_checks(env, output)
|
||||
|
||||
# perform other checks
|
||||
run_network_checks(env, output)
|
||||
run_domain_checks(env, output)
|
||||
# perform other checks asynchronously
|
||||
|
||||
pool = multiprocessing.pool.Pool(processes=1)
|
||||
r1 = pool.apply_async(run_network_checks, [env])
|
||||
r2 = run_domain_checks(env)
|
||||
r1.get().playback(output)
|
||||
r2.playback(output)
|
||||
|
||||
def run_services_checks(env, output):
|
||||
# Check that system services are running.
|
||||
|
||||
import socket
|
||||
|
||||
services = [
|
||||
{ "name": "Local DNS (bind9)", "port": 53, "public": False, },
|
||||
#{ "name": "NSD Control", "port": 8952, "public": False, },
|
||||
|
@ -66,33 +68,47 @@ def run_services_checks(env, output):
|
|||
{ "name": "HTTPS Web (nginx)", "port": 443, "public": True, },
|
||||
]
|
||||
|
||||
ok = True
|
||||
|
||||
for service in services:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.settimeout(.1)
|
||||
try:
|
||||
s.connect((
|
||||
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'],
|
||||
service["port"]))
|
||||
except OSError as e:
|
||||
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
|
||||
all_running = True
|
||||
fatal = False
|
||||
pool = multiprocessing.pool.Pool(processes=10)
|
||||
ret = pool.starmap(check_service, ((i, service, env) for i, service in enumerate(services)), chunksize=1)
|
||||
for i, running, fatal2, output2 in sorted(ret):
|
||||
all_running = all_running and running
|
||||
fatal = fatal or fatal2
|
||||
output2.playback(output)
|
||||
|
||||
# Why is nginx not running?
|
||||
if service["port"] in (80, 443):
|
||||
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
|
||||
|
||||
# Flag if local DNS is not running.
|
||||
if service["port"] == 53 and service["public"] == False:
|
||||
ok = False
|
||||
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
if ok:
|
||||
if all_running:
|
||||
output.print_ok("All system services are running.")
|
||||
|
||||
return ok
|
||||
return not fatal
|
||||
|
||||
def check_service(i, service, env):
|
||||
import socket
|
||||
output = BufferedOutput()
|
||||
running = False
|
||||
fatal = False
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.settimeout(1)
|
||||
try:
|
||||
s.connect((
|
||||
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'],
|
||||
service["port"]))
|
||||
running = True
|
||||
|
||||
except OSError as e:
|
||||
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
|
||||
|
||||
# Why is nginx not running?
|
||||
if service["port"] in (80, 443):
|
||||
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
|
||||
|
||||
# Flag if local DNS is not running.
|
||||
if service["port"] == 53 and service["public"] == False:
|
||||
fatal = True
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
return (i, running, fatal, output)
|
||||
|
||||
def run_system_checks(env, output):
|
||||
check_ssh_password(env, output)
|
||||
|
@ -146,9 +162,10 @@ def check_free_disk_space(env, output):
|
|||
else:
|
||||
output.print_error(disk_msg)
|
||||
|
||||
def run_network_checks(env, output):
|
||||
def run_network_checks(env):
|
||||
# Also see setup/network-checks.sh.
|
||||
|
||||
output = BufferedOutput()
|
||||
output.add_heading("Network")
|
||||
|
||||
# Stop if we cannot make an outbound connection on port 25. Many residential
|
||||
|
@ -176,7 +193,9 @@ def run_network_checks(env, output):
|
|||
which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s."""
|
||||
% (env['PUBLIC_IP'], zen, env['PUBLIC_IP']))
|
||||
|
||||
def run_domain_checks(env, output):
|
||||
return output
|
||||
|
||||
def run_domain_checks(env):
|
||||
# Get the list of domains we handle mail for.
|
||||
mail_domains = get_mail_domains(env)
|
||||
|
||||
|
@ -187,24 +206,44 @@ def run_domain_checks(env, output):
|
|||
# Get the list of domains we serve HTTPS for.
|
||||
web_domains = set(get_web_domains(env))
|
||||
|
||||
# Check the domains.
|
||||
for domain in sort_domains(mail_domains | dns_domains | web_domains, env):
|
||||
output.add_heading(domain)
|
||||
domains_to_check = mail_domains | dns_domains | web_domains
|
||||
|
||||
if domain == env["PRIMARY_HOSTNAME"]:
|
||||
check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles)
|
||||
# Serial version:
|
||||
#for domain in sort_domains(domains_to_check, env):
|
||||
# run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
|
||||
|
||||
# Parallelize the checks across a worker pool.
|
||||
args = ((domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
|
||||
for domain in domains_to_check)
|
||||
pool = multiprocessing.pool.Pool(processes=10)
|
||||
ret = pool.starmap(run_domain_checks_on_domain, args, chunksize=1)
|
||||
ret = dict(ret) # (domain, output) => { domain: output }
|
||||
output = BufferedOutput()
|
||||
for domain in sort_domains(ret, env):
|
||||
ret[domain].playback(output)
|
||||
return output
|
||||
|
||||
def run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains):
|
||||
output = BufferedOutput()
|
||||
|
||||
output.add_heading(domain)
|
||||
|
||||
if domain == env["PRIMARY_HOSTNAME"]:
|
||||
check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles)
|
||||
|
||||
if domain in dns_domains:
|
||||
check_dns_zone(domain, env, output, dns_zonefiles)
|
||||
if domain in dns_domains:
|
||||
check_dns_zone(domain, env, output, dns_zonefiles)
|
||||
|
||||
if domain in mail_domains:
|
||||
check_mail_domain(domain, env, output)
|
||||
if domain in mail_domains:
|
||||
check_mail_domain(domain, env, output)
|
||||
|
||||
if domain in web_domains:
|
||||
check_web_domain(domain, env, output)
|
||||
if domain in web_domains:
|
||||
check_web_domain(domain, env, output)
|
||||
|
||||
if domain in dns_domains:
|
||||
check_dns_zone_suggestions(domain, env, output, dns_zonefiles)
|
||||
if domain in dns_domains:
|
||||
check_dns_zone_suggestions(domain, env, output, dns_zonefiles)
|
||||
|
||||
return (domain, output)
|
||||
|
||||
def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
|
||||
# If a DS record is set on the zone containing this domain, check DNSSEC now.
|
||||
|
@ -655,11 +694,12 @@ def list_apt_updates(apt_update=True):
|
|||
return pkgs
|
||||
|
||||
|
||||
try:
|
||||
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
|
||||
except:
|
||||
terminal_columns = 76
|
||||
class ConsoleOutput:
|
||||
try:
|
||||
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
|
||||
except:
|
||||
terminal_columns = 76
|
||||
|
||||
def add_heading(self, heading):
|
||||
print()
|
||||
print(heading)
|
||||
|
@ -680,7 +720,7 @@ class ConsoleOutput:
|
|||
words = re.split("(\s+)", message)
|
||||
linelen = 0
|
||||
for w in words:
|
||||
if linelen + len(w) > terminal_columns-1-len(first_line):
|
||||
if linelen + len(w) > self.terminal_columns-1-len(first_line):
|
||||
print()
|
||||
print(" ", end="")
|
||||
linelen = 0
|
||||
|
@ -693,6 +733,21 @@ class ConsoleOutput:
|
|||
for line in message.split("\n"):
|
||||
self.print_block(line)
|
||||
|
||||
class BufferedOutput:
|
||||
# Record all of the instance method calls so we can play them back later.
|
||||
def __init__(self):
|
||||
self.buf = []
|
||||
def __getattr__(self, attr):
|
||||
if attr not in ("add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"):
|
||||
raise AttributeError
|
||||
# Return a function that just records the call & arguments to our buffer.
|
||||
def w(*args, **kwargs):
|
||||
self.buf.append((attr, args, kwargs))
|
||||
return w
|
||||
def playback(self, output):
|
||||
for attr, args, kwargs in self.buf:
|
||||
getattr(output, attr)(*args, **kwargs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from utils import load_environment
|
||||
|
|
Loading…
Reference in a new issue