power-mailinabox/management/utils.py
2022-10-16 18:26:39 +01:00

234 lines
5.9 KiB
Python

import os.path
# DO NOT import non-standard modules. This module is imported by
# migrate.py which runs on fresh machines before anything is installed
# besides Python.
# THE ENVIRONMENT FILE AT /etc/mailinabox.conf
def load_environment():
    """Return the box configuration read from /etc/mailinabox.conf.

    The file is parsed by load_env_vars_from_file(), which yields an ordered
    mapping of KEY -> VALUE strings.
    """
    conf_path = "/etc/mailinabox.conf"
    return load_env_vars_from_file(conf_path)
def load_env_vars_from_file(fn):
    """Load settings from a KEY=VALUE file into an OrderedDict.

    Each line is stripped of surrounding whitespace and split on its first
    "=", so values may themselves contain "=". If a key occurs more than once
    the first occurrence wins. Blank lines are skipped. A non-blank line that
    has no "=" maps its text to None.

    :param fn: path of the file to read.
    :returns: collections.OrderedDict of key -> value, in file order.
    :raises OSError: if the file cannot be opened.
    """
    import collections
    env = collections.OrderedDict()
    # Use a context manager so the handle is closed deterministically.
    with open(fn) as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Without this skip, a blank line produced a bogus "" -> None
                # entry, which save_environment would write back as "=None".
                continue
            env.setdefault(*line.split("=", 1))
    return env
def save_environment(env):
    """Overwrite /etc/mailinabox.conf with one KEY=VALUE line per entry of env.

    Entries are written in env's iteration order.
    """
    with open("/etc/mailinabox.conf", "w") as conf:
        conf.writelines("%s=%s\n" % (key, value) for key, value in env.items())
# THE SETTINGS FILE AT STORAGE_ROOT/settings.yaml.
def write_settings(config, env):
    """Serialize config as YAML into STORAGE_ROOT/settings.yaml.

    :param config: the settings structure to persist (passed to rtyaml.dump).
    :param env: environment mapping; env['STORAGE_ROOT'] locates the file.
    """
    import rtyaml
    fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml')
    # Serialize BEFORE opening: open(..., "w") truncates immediately, so a
    # failing dump would otherwise leave an empty settings file behind.
    serialized = rtyaml.dump(config)
    with open(fn, "w") as f:
        f.write(serialized)
def load_settings(env):
    """Load STORAGE_ROOT/settings.yaml and return it as a dict.

    Any failure to open or parse the file, or a top-level YAML value that is
    not a mapping, yields an empty dict instead of an error.

    :param env: environment mapping; env['STORAGE_ROOT'] locates the file.
    :raises KeyError: if env has no 'STORAGE_ROOT'.
    """
    import rtyaml
    fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml')
    try:
        # Close the file deterministically once parsed.
        with open(fn, "r") as f:
            config = rtyaml.load(f)
    except Exception:
        # Best-effort: a missing/unreadable/malformed file means "no settings".
        # (Catch Exception, not everything, so Ctrl-C / SystemExit propagate.)
        return {}
    if not isinstance(config, dict):
        return {}
    return config
# UTILITIES
def safe_domain_name(name):
    """Percent-encode a domain name so it is safe to use as a file name on disk.

    safe='' means every reserved character, including "/", is escaped, so the
    result never spans directories.
    """
    from urllib.parse import quote
    encoded = quote(name, safe='')
    return encoded
def sort_domains(domain_names, env):
    """Return domain_names as a list in a nice, grouped display order.

    Domains are grouped by DNS zone: the top-most domain in domain_names that
    encompasses a set of subdomains. The zone equal to (or containing)
    env['PRIMARY_HOSTNAME'] comes first; the other zones follow
    lexicographically. Within a zone, PRIMARY_HOSTNAME comes first, then its
    subdomains, then the rest in right-to-left order of their "."-separated
    labels.

    :param domain_names: any iterable of domain names (list, set, generator).
    :param env: mapping providing 'PRIMARY_HOSTNAME'.
    :returns: a new sorted list.
    """
    # Materialize once: the input is walked twice below, and a generator
    # argument used to be exhausted by the first pass, returning [].
    domain_names = list(domain_names)
    primary = env['PRIMARY_HOSTNAME']

    # Map each domain to the zone that contains it. Walk from shortest to
    # longest since zones are always shorter than their subdomains.
    zones = {}
    distinct_zones = []  # each zone once, in discovery order
    for domain in sorted(domain_names, key=len):
        if domain in zones:
            continue  # duplicate input entry; already classified
        for z in distinct_zones:
            if domain.endswith("." + z):
                # We found a parent domain already in the list.
                zones[domain] = z
                break
        else:
            # 'break' did not occur: there is no parent domain, so it is its
            # own zone.
            zones[domain] = domain
            distinct_zones.append(domain)

    # Sort the zones: PRIMARY_HOSTNAME or the zone that contains it is always
    # first, then plain lexicographic order.
    ordered_zones = sorted(
        distinct_zones,
        key=lambda z: (not (z == primary or primary.endswith("." + z)), z))
    # O(1) rank lookup instead of list.index() on every comparison key.
    zone_rank = {z: i for i, z in enumerate(ordered_zones)}

    # Now sort the domain names that fall within each zone.
    return sorted(
        domain_names,
        key=lambda d: (
            # First by zone.
            zone_rank[zones[d]],
            # PRIMARY_HOSTNAME is always first within the zone that contains it.
            d != primary,
            # Followed by any of its subdomains.
            not d.endswith("." + primary),
            # Then in right-to-left lexicographic order of the .-separated parts.
            list(reversed(d.split("."))),
        ))
def sort_email_addresses(email_addresses, env):
    """Return the unique addresses ordered by domain, then alphabetically.

    Domains are ordered with sort_domains(). Addresses within a domain are
    sorted, and any address not claimed by a domain (e.g. one with no "@")
    is appended at the end in sorted order.
    """
    remaining = set(email_addresses)
    domains = {addr.split("@", 1)[1] for addr in remaining if "@" in addr}
    ordered = []
    for domain in sort_domains(domains, env):
        suffix = "@" + domain
        in_domain = {addr for addr in remaining if addr.endswith(suffix)}
        ordered += sorted(in_domain)
        remaining.difference_update(in_domain)
    # Whatever is left over goes last.
    ordered += sorted(remaining)
    return ordered
def shell(method,
          cmd_args,
          env=None,
          capture_stderr=False,
          return_bytes=False,
          trap=False,
          input=None):
    """A safe way to execute processes via subprocess.<method>.

    :param method: name of a subprocess function, e.g. "check_output".
    :param cmd_args: argument list for the command (no shell involved).
    :param env: extra environment variables for the child. The child's
        environment is exactly these plus a fixed PATH (the parent's
        environment is not inherited). The mapping is copied, never mutated.
    :param capture_stderr: merge the child's stderr into its stdout.
    :param return_bytes: return raw bytes instead of UTF-8-decoded text.
    :param trap: if true, catch CalledProcessError and return
        (returncode, output); on success return (0, output).
    :param input: data for the child's stdin (only used with check_output).
    :returns: the subprocess result, or (code, result) when trap is true.
    """
    import subprocess
    # Copy rather than update in place: the old `env={}` default was a shared
    # mutable object, and callers' dicts (even os.environ) got PATH written
    # into them. Some processes like apt-get require being given a sane PATH.
    child_env = dict(env) if env else {}
    child_env["PATH"] = "/sbin:/bin:/usr/sbin:/usr/bin"
    kwargs = {
        'env': child_env,
        'stderr': None if not capture_stderr else subprocess.STDOUT,
    }
    if method == "check_output" and input is not None:
        kwargs['input'] = input
    if not trap:
        ret = getattr(subprocess, method)(cmd_args, **kwargs)
    else:
        try:
            ret = getattr(subprocess, method)(cmd_args, **kwargs)
            code = 0
        except subprocess.CalledProcessError as e:
            ret = e.output
            code = e.returncode
    if not return_bytes and isinstance(ret, bytes):
        ret = ret.decode("utf8")
    if not trap:
        return ret
    else:
        return code, ret
def create_syslog_handler():
    """Build a SysLogHandler on the /dev/log socket that passes WARNING and above."""
    from logging import WARNING
    from logging.handlers import SysLogHandler
    syslog = SysLogHandler(address='/dev/log')
    syslog.setLevel(WARNING)
    return syslog
def du(path):
    """Return the total size in bytes of all files under path, like `du`.

    Based on http://stackoverflow.com/a/17936789. Sizes come from lstat(), so
    a symlink in the tree counts as the link itself rather than its target.
    A file reachable through several hard links is counted once. Entries that
    cannot be stat'ed (e.g. removed mid-walk) are skipped.
    """
    total_size = 0
    seen = set()
    for dirpath, _dirnames, filenames in os.walk(path):
        for name in filenames:
            try:
                st = os.lstat(os.path.join(dirpath, name))
            except OSError:
                continue
            # Inode numbers are only unique per filesystem. Key on the device
            # too, so distinct files on different mounts are not conflated.
            key = (st.st_dev, st.st_ino)
            if key in seen:
                continue
            seen.add(key)
            total_size += st.st_size
    return total_size
def wait_for_service(port, public, env, timeout):
    """Block until a TCP service accepts connections, up to roughly timeout seconds.

    :param port: TCP port to probe.
    :param public: if true, probe env['PUBLIC_IP']; otherwise 127.0.0.1.
    :param env: mapping providing 'PUBLIC_IP' when public is true.
    :param timeout: overall time budget in seconds. Each connect attempt may
        take up to timeout/3, and retries are spaced min(timeout/4, 1) apart.
    :returns: True once a connection succeeds, False after the budget elapses.
    """
    import socket
    import time
    start = time.perf_counter()
    while True:
        # Context manager closes every probe socket; previously one socket
        # leaked per attempt, including the successful one.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(timeout / 3)
            try:
                s.connect(("127.0.0.1" if not public else env['PUBLIC_IP'], port))
                return True
            except OSError:
                pass
        if time.perf_counter() > start + timeout:
            return False
        time.sleep(min(timeout / 4, 1))
def get_php_version():
    """Return the installed PHP's "MAJOR.MINOR" version from `/usr/bin/php -v`.

    For a banner beginning "PHP 8.1.2 ...", this returns "8.1".
    """
    import re
    banner = shell("check_output", ["/usr/bin/php", "-v"])
    # Parse instead of slicing fixed columns, which broke as soon as the
    # major or minor version had more than one digit (e.g. "8.10" -> "8.1").
    match = re.match(r"PHP (\d+\.\d+)", banner)
    if match:
        return match.group(1)
    # Unexpected banner format: keep the historical fixed-width behaviour.
    return banner[4:7]
def get_os_code():
    """Identify which supported OS release this machine runs.

    Uses `lsb_release -is` (distributor) and `lsb_release -rs` (release).
    Returns "Debian10", "Debian11", "Ubuntu2004" or "Ubuntu2204", or None
    when the distributor/release pair is not one of those.
    """
    supported_releases = {
        ("Debian", "10"): "Debian10",
        ("Debian", "11"): "Debian11",
        ("Ubuntu", "20.04"): "Ubuntu2004",
        ("Ubuntu", "22.04"): "Ubuntu2204",
    }
    distributor = shell("check_output", ["/usr/bin/lsb_release", "-is"]).strip()
    release = shell("check_output", ["/usr/bin/lsb_release", "-rs"]).strip()
    return supported_releases.get((distributor, release))
if __name__ == "__main__":
    # Script mode: print each domain reported by get_web_domains() for this
    # box's configuration, one per line.
    from web_update import get_web_domains
    for web_domain in get_web_domains(load_environment()):
        print(web_domain)