Просмотр исходного кода

create get_spam_score() as a sync function, use a simpler version for running MailHandler. Remove async/await

Son NK 4 лет назад
Родитель
Сommit
abc42df0fb
1 измененных файлов с 75 добавлено и 40 удалено
  1. 75 40
      email_handler.py

+ 75 - 40
email_handler.py

@@ -49,6 +49,7 @@ import aiosmtpd
 import aiospamc
 import arrow
 import spf
+from aiosmtpd.controller import Controller
 from aiosmtpd.smtp import Envelope
 from sqlalchemy.exc import IntegrityError
 
@@ -109,6 +110,7 @@ from app.models import (
     Mailbox,
 )
 from app.pgp_utils import PGPException
+from app.spamassassin_utils import SpamAssassin
 from app.utils import random_string
 from init_app import load_pgp_public_keys
 from server import create_app, create_light_app
@@ -436,9 +438,7 @@ def handle_email_sent_to_ourself(alias, mailbox, msg: Message, user):
     )
 
 
-async def handle_forward(
-    envelope, msg: Message, rcpt_to: str
-) -> List[Tuple[bool, str]]:
+def handle_forward(envelope, msg: Message, rcpt_to: str) -> List[Tuple[bool, str]]:
     """return an array of SMTP status (is_success, smtp_status)
     is_success indicates whether an email has been delivered and
     smtp_status is the SMTP Status ("250 Message accepted", "550 Non-existent email address", etc)
@@ -490,7 +490,7 @@ async def handle_forward(
             return [(False, "550 SL E18 unverified mailbox")]
         else:
             ret.append(
-                await forward_email_to_mailbox(
+                forward_email_to_mailbox(
                     alias, msg, email_log, contact, envelope, mailbox, user
                 )
             )
@@ -502,7 +502,7 @@ async def handle_forward(
                 ret.append((False, "550 SL E19 unverified mailbox"))
             else:
                 ret.append(
-                    await forward_email_to_mailbox(
+                    forward_email_to_mailbox(
                         alias,
                         copy(msg),
                         email_log,
@@ -516,7 +516,7 @@ async def handle_forward(
     return ret
 
 
-async def forward_email_to_mailbox(
+def forward_email_to_mailbox(
     alias,
     msg: Message,
     email_log: EmailLog,
@@ -566,7 +566,7 @@ async def forward_email_to_mailbox(
 
     if SPAMASSASSIN_HOST:
         start = time.time()
-        spam_score = await get_spam_score(msg)
+        spam_score = get_spam_score(msg)
         LOG.d(
             "%s -> %s - spam score %s in %s seconds",
             contact,
@@ -684,7 +684,7 @@ async def forward_email_to_mailbox(
         return True, "250 Message accepted for delivery"
 
 
-async def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
+def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
     """
     return whether an email has been delivered and
     the smtp status ("250 Message accepted", "550 Non-existent email address", etc)
@@ -762,7 +762,7 @@ async def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
     # do not use user.max_spam_score here
     if SPAMASSASSIN_HOST:
         start = time.time()
-        spam_score = await get_spam_score(msg)
+        spam_score = get_spam_score(msg)
         LOG.d(
             "%s -> %s - spam score %s in %s seconds",
             alias,
@@ -1418,7 +1418,7 @@ def handle_sender_email(envelope: Envelope):
     return "250 email to sender accepted"
 
 
-async def handle(envelope: Envelope) -> str:
+def handle(envelope: Envelope) -> str:
     """Return SMTP status"""
 
     # sanitize mail_from, rcpt_tos
@@ -1455,7 +1455,7 @@ async def handle(envelope: Envelope) -> str:
         # recipient starts with "reply+" or "ra+" (ra=reverse-alias) prefix
         if rcpt_to.startswith("reply+") or rcpt_to.startswith("ra+"):
             LOG.debug(">>> Reply phase %s(%s) -> %s", mail_from, msg["From"], rcpt_to)
-            is_delivered, smtp_status = await handle_reply(envelope, msg, rcpt_to)
+            is_delivered, smtp_status = handle_reply(envelope, msg, rcpt_to)
             res.append((is_delivered, smtp_status))
         else:  # Forward case
             LOG.debug(
@@ -1464,9 +1464,7 @@ async def handle(envelope: Envelope) -> str:
                 msg["From"],
                 rcpt_to,
             )
-            for is_delivered, smtp_status in await handle_forward(
-                envelope, msg, rcpt_to
-            ):
+            for is_delivered, smtp_status in handle_forward(envelope, msg, rcpt_to):
                 res.append((is_delivered, smtp_status))
 
     for (is_success, smtp_status) in res:
@@ -1478,7 +1476,7 @@ async def handle(envelope: Envelope) -> str:
     return res[0][1]
 
 
-async def get_spam_score(message: Message) -> float:
+async def get_spam_score_async(message: Message) -> float:
     LOG.debug("get spam score for %s", message[_MESSAGE_ID])
     sa_input = to_bytes(message)
 
@@ -1502,6 +1500,24 @@ async def get_spam_score(message: Message) -> float:
         return -999
 
 
+def get_spam_score(message: Message) -> float:
+    LOG.debug("get spam score for %s", message[_MESSAGE_ID])
+    sa_input = to_bytes(message)
+
+    # Spamassassin requires to have an ending linebreak
+    if not sa_input.endswith(b"\n"):
+        LOG.d("add linebreak to spamassassin input")
+        sa_input += b"\n"
+
+    try:
+        sa = SpamAssassin(sa_input, host=SPAMASSASSIN_HOST)
+        return sa.get_score()
+    except Exception:
+        LOG.exception("SpamAssassin exception")
+        # return a negative score so the message is always considered as ham
+        return -999
+
+
 def sl_sendmail(from_addr, to_addr, msg: Message, mail_options, rcpt_options):
     """replace smtp.sendmail"""
     if POSTFIX_SUBMISSION_TLS:
@@ -1522,12 +1538,9 @@ def sl_sendmail(from_addr, to_addr, msg: Message, mail_options, rcpt_options):
 
 
 class MailHandler:
-    def __init__(self, lock):
-        self.lock = lock
-
     async def handle_DATA(self, server, session, envelope: Envelope):
         try:
-            ret = await self._handle(envelope)
+            ret = self._handle(envelope)
             return ret
         except Exception:
             LOG.exception(
@@ -1537,31 +1550,42 @@ class MailHandler:
             )
             return "421 SL Retry later"
 
-    async def _handle(self, envelope: Envelope):
-        async with self.lock:
-            start = time.time()
-            LOG.info(
-                "===>> New message, mail from %s, rctp tos %s ",
-                envelope.mail_from,
-                envelope.rcpt_tos,
-            )
+    def _handle(self, envelope: Envelope):
+        start = time.time()
+        LOG.info(
+            "===>> New message, mail from %s, rctp tos %s ",
+            envelope.mail_from,
+            envelope.rcpt_tos,
+        )
 
-            app = new_app()
-            with app.app_context():
-                ret = await handle(envelope)
-                LOG.info("takes %s seconds <<===", time.time() - start)
-                return ret
+        app = new_app()
+        with app.app_context():
+            ret = handle(envelope)
+            LOG.info("takes %s seconds <<===", time.time() - start)
+            return ret
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-p", "--port", help="SMTP port to listen for", type=int, default=20381
-    )
-    args = parser.parse_args()
+def main(port: int):
+    """Use aiosmtpd Controller"""
+    controller = Controller(MailHandler(), hostname="0.0.0.0", port=port)
 
-    LOG.info("Listen for port %s", args.port)
+    controller.start()
+    LOG.d("Start mail controller %s %s", controller.hostname, controller.port)
 
+    if LOAD_PGP_EMAIL_HANDLER:
+        LOG.warning("LOAD PGP keys")
+        app = create_app()
+        with app.app_context():
+            load_pgp_public_keys()
+
+    while True:
+        time.sleep(2)
+
+
+def asyncio_main(port: int):
+    """
+    Main entrypoint using asyncio directly without passing by aiosmtpd Controller
+    """
     if LOAD_PGP_EMAIL_HANDLER:
         LOG.warning("LOAD PGP keys")
         app = create_app()
@@ -1577,7 +1601,7 @@ if __name__ == "__main__":
         return aiosmtpd.smtp.SMTP(handler, enable_SMTPUTF8=True)
 
     server = loop.run_until_complete(
-        loop.create_server(factory, host="0.0.0.0", port=args.port)
+        loop.create_server(factory, host="0.0.0.0", port=port)
     )
 
     try:
@@ -1590,3 +1614,14 @@ if __name__ == "__main__":
     server.close()
     loop.run_until_complete(server.wait_closed())
     loop.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-p", "--port", help="SMTP port to listen for", type=int, default=20381
+    )
+    args = parser.parse_args()
+
+    LOG.info("Listen for port %s", args.port)
+    main(port=args.port)