Compare commits

..

2 commits

Author SHA1 Message Date
Son
57a7565c12 show admin menu if user is admin 2022-08-03 16:00:00 +02:00
Son
badfe2f752 admin can stop a paddle sub 2022-08-03 15:59:28 +02:00
275 changed files with 329131 additions and 18956 deletions

View file

@ -1,5 +1,5 @@
{ {
"template": "${{CHANGELOG}}\n\n<details>\n<summary>Uncategorized</summary>\n\n${{UNCATEGORIZED}}\n</details>", "template": "${{CHANGELOG}}",
"pr_template": "- ${{TITLE}} #${{NUMBER}}", "pr_template": "- ${{TITLE}} #${{NUMBER}}",
"empty_template": "- no changes", "empty_template": "- no changes",
"categories": [ "categories": [
@ -20,4 +20,4 @@
"tag_resolver": { "tag_resolver": {
"method": "semver" "method": "semver"
} }
} }

View file

@ -1,44 +1,15 @@
name: Test and lint name: Run tests & Publish to Docker Registry
on: on:
push: push:
jobs: jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Check out repo
uses: actions/checkout@v3
- name: Install poetry
run: pipx install poetry
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'poetry'
- name: Install OS dependencies
if: ${{ matrix.python-version }} == '3.10'
run: |
sudo apt update
sudo apt install -y libre2-dev libpq-dev
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction
- name: Check formatting & linting
run: |
poetry run pre-commit run --all-files
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
max-parallel: 4 max-parallel: 4
matrix: matrix:
python-version: ["3.10"] python-version: ["3.9", "3.10"]
# service containers to run with `postgres-job` # service containers to run with `postgres-job`
services: services:
@ -67,16 +38,27 @@ jobs:
--health-retries 5 --health-retries 5
steps: steps:
- name: Check out repo - name: Check out repository
uses: actions/checkout@v3 uses: actions/checkout@v2
- name: Install poetry - name: Set up Python ${{ matrix.python-version }}
run: pipx install poetry uses: actions/setup-python@v2
- uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
cache: 'poetry'
- name: Install poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Run caching
id: cached-poetry-dependencies
uses: actions/cache@v2
with:
path: .venv
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install OS dependencies - name: Install OS dependencies
if: ${{ matrix.python-version }} == '3.10' if: ${{ matrix.python-version }} == '3.10'
@ -86,13 +68,14 @@ jobs:
- name: Install dependencies - name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
- name: Install library
run: poetry install --no-interaction run: poetry install --no-interaction
- name: Check formatting & linting
- name: Start Redis v6 run: |
uses: superchargejs/redis-github-action@1.1.0 poetry run pre-commit run --all-files
with:
redis-version: 6
- name: Run db migration - name: Run db migration
run: | run: |
@ -117,7 +100,7 @@ jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: ['test', 'lint'] needs: ['test']
if: github.event_name == 'push' && (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/v')) if: github.event_name == 'push' && (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/v'))
steps: steps:
@ -135,19 +118,7 @@ jobs:
# We need to checkout the repository in order for the "Create Sentry release" to work # We need to checkout the repository in order for the "Create Sentry release" to work
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Create Sentry release
uses: getsentry/action-release@v1
env:
SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }}
SENTRY_ORG: ${{ secrets.SENTRY_ORG }}
SENTRY_PROJECT: ${{ secrets.SENTRY_PROJECT }}
with:
ignore_missing: true
ignore_empty: true
- name: Prepare version file - name: Prepare version file
run: | run: |
@ -161,6 +132,12 @@ jobs:
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
- name: Create Sentry release
uses: getsentry/action-release@v1
env:
SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }}
SENTRY_ORG: ${{ secrets.SENTRY_ORG }}
SENTRY_PROJECT: ${{ secrets.SENTRY_PROJECT }}
#- name: Send Telegram message #- name: Send Telegram message
# uses: appleboy/telegram-action@master # uses: appleboy/telegram-action@master

1
.gitignore vendored
View file

@ -15,4 +15,3 @@ venv/
.coverage .coverage
htmlcov htmlcov
adhoc adhoc
.env.*

View file

@ -7,19 +7,21 @@ repos:
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
- repo: https://github.com/Riverside-Healthcare/djLint - repo: https://github.com/Riverside-Healthcare/djLint
rev: v1.3.0 rev: v1.3.0
hooks: hooks:
- id: djlint-jinja - id: djlint-jinja
files: '.*\.html' files: '.*\.html'
entry: djlint --reformat entry: djlint --reformat
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/PyCQA/pylint
# Ruff version. rev: v2.14.4
rev: v0.1.5
hooks: hooks:
# Run the linter. - id: pylint
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format

View file

@ -34,7 +34,7 @@ poetry install
On Mac, sometimes you might need to install some other packages via `brew`: On Mac, sometimes you might need to install some other packages via `brew`:
```bash ```bash
brew install pkg-config libffi openssl postgresql@13 brew install pkg-config libffi openssl postgresql
``` ```
You also need to install `gpg` tool, on Mac it can be done with: You also need to install `gpg` tool, on Mac it can be done with:
@ -62,8 +62,6 @@ To install it in your development environment.
## Run tests ## Run tests
For most tests, you will need to have ``redis`` installed and started on your machine (listening on port 6379).
```bash ```bash
sh scripts/run-test.sh sh scripts/run-test.sh
``` ```
@ -82,16 +80,10 @@ To run the code locally, please create a local setting file based on `example.en
cp example.env .env cp example.env .env
``` ```
You need to edit your .env to reflect the postgres exposed port, edit the `DB_URI` to:
```
DB_URI=postgresql://myuser:mypassword@localhost:35432/simplelogin
```
Run the postgres database: Run the postgres database:
```bash ```bash
docker run -e POSTGRES_PASSWORD=mypassword -e POSTGRES_USER=myuser -e POSTGRES_DB=simplelogin -p 15432:5432 postgres:13 docker run -e POSTGRES_PASSWORD=mypassword -e POSTGRES_USER=myuser -e POSTGRES_DB=simplelogin -p 35432:5432 postgres:13
``` ```
To run the server: To run the server:
@ -169,12 +161,6 @@ For HTML templates, we use `djlint`. Before creating a pull request, please run
poetry run djlint --check templates poetry run djlint --check templates
``` ```
If some files aren't properly formatted, you can format all files with
```bash
poetry run djlint --reformat .
```
## Test sending email ## Test sending email
[swaks](http://www.jetmore.org/john/code/swaks/) is used for sending test emails to the `email_handler`. [swaks](http://www.jetmore.org/john/code/swaks/) is used for sending test emails to the `email_handler`.
@ -212,11 +198,4 @@ python email_handler.py
swaks --to e1@sl.local --from hey@google.com --server 127.0.0.1:20381 swaks --to e1@sl.local --from hey@google.com --server 127.0.0.1:20381
``` ```
Now open http://localhost:1080/ (or http://localhost:1080/ for MailHog), you should see the forwarded email. Now open http://localhost:1080/ (or http://localhost:1080/ for MailHog), you should see the forwarded email.
## Job runner
Some features require a job handler (such as GDPR data export). To test such feature you need to run the job_runner
```bash
python job_runner.py
```

View file

@ -2,7 +2,7 @@
FROM node:10.17.0-alpine AS npm FROM node:10.17.0-alpine AS npm
WORKDIR /code WORKDIR /code
COPY ./static/package*.json /code/static/ COPY ./static/package*.json /code/static/
RUN cd /code/static && npm ci RUN cd /code/static && npm install
# Main image # Main image
FROM python:3.10 FROM python:3.10
@ -13,7 +13,7 @@ ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1 ENV PYTHONUNBUFFERED 1
# Add poetry to PATH # Add poetry to PATH
ENV PATH="${PATH}:/root/.local/bin" ENV PATH="${PATH}:/root/.poetry/bin"
WORKDIR /code WORKDIR /code
@ -23,15 +23,15 @@ COPY poetry.lock pyproject.toml ./
# Install and setup poetry # Install and setup poetry
RUN pip install -U pip \ RUN pip install -U pip \
&& apt-get update \ && apt-get update \
&& apt install -y curl netcat-traditional gcc python3-dev gnupg git libre2-dev cmake ninja-build\ && apt install -y curl netcat gcc python3-dev gnupg git libre2-dev \
&& curl -sSL https://install.python-poetry.org | python3 - \ && curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python - \
# Remove curl and netcat from the image # Remove curl and netcat from the image
&& apt-get purge -y curl netcat-traditional \ && apt-get purge -y curl netcat \
# Run poetry # Run poetry
&& poetry config virtualenvs.create false \ && poetry config virtualenvs.create false \
&& poetry install --no-interaction --no-ansi --no-root \ && poetry install --no-interaction --no-ansi --no-root \
# Clear apt cache \ # Clear apt cache \
&& apt-get purge -y libre2-dev cmake ninja-build\ && apt-get purge -y libre2-dev \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*

View file

@ -15,8 +15,8 @@
<img src="https://img.shields.io/github/license/simple-login/app"> <img src="https://img.shields.io/github/license/simple-login/app">
</a> </a>
<a href="https://twitter.com/simplelogin"> <a href="https://twitter.com/simple_login">
<img src="https://img.shields.io/twitter/follow/simplelogin?style=social"> <img src="https://img.shields.io/twitter/follow/simple_login?style=social">
</a> </a>
</p> </p>
@ -334,12 +334,6 @@ smtpd_recipient_restrictions =
permit permit
``` ```
Check that the ssl certificates `/etc/ssl/certs/ssl-cert-snakeoil.pem` and `/etc/ssl/private/ssl-cert-snakeoil.key` exist. Depending on the linux distribution you are using they may or may not be present. If they are not, you will need to generate them with this command:
```bash
openssl req -x509 -nodes -days 3650 -newkey rsa:2048 -keyout /etc/ssl/private/ssl-cert-snakeoil.key -out /etc/ssl/certs/ssl-cert-snakeoil.pem
```
Create the `/etc/postfix/pgsql-relay-domains.cf` file with the following content. Create the `/etc/postfix/pgsql-relay-domains.cf` file with the following content.
Make sure that the database config is correctly set, replace `mydomain.com` with your domain, update 'myuser' and 'mypassword' with your postgres credentials. Make sure that the database config is correctly set, replace `mydomain.com` with your domain, update 'myuser' and 'mypassword' with your postgres credentials.

View file

@ -5,23 +5,17 @@ from typing import Optional
from arrow import Arrow from arrow import Arrow
from newrelic import agent from newrelic import agent
from sqlalchemy import or_
from app.db import Session from app.db import Session
from app.email_utils import send_welcome_email from app.email_utils import send_welcome_email
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
from app.errors import ( from app.errors import AccountAlreadyLinkedToAnotherPartnerException
AccountAlreadyLinkedToAnotherPartnerException,
AccountIsUsingAliasAsEmail,
AccountAlreadyLinkedToAnotherUserException,
)
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
PartnerSubscription, PartnerSubscription,
Partner, Partner,
PartnerUser, PartnerUser,
User, User,
Alias,
) )
from app.utils import random_string from app.utils import random_string
@ -132,9 +126,8 @@ class ClientMergeStrategy(ABC):
class NewUserStrategy(ClientMergeStrategy): class NewUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult: def process(self) -> LinkResult:
# Will create a new SL User with a random password # Will create a new SL User with a random password
canonical_email = canonicalize_email(self.link_request.email)
new_user = User.create( new_user = User.create(
email=canonical_email, email=self.link_request.email,
name=self.link_request.name, name=self.link_request.name,
password=random_string(20), password=random_string(20),
activated=True, activated=True,
@ -168,6 +161,7 @@ class NewUserStrategy(ClientMergeStrategy):
class ExistingUnlinkedUserStrategy(ClientMergeStrategy): class ExistingUnlinkedUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult: def process(self) -> LinkResult:
partner_user = ensure_partner_user_exists_for_user( partner_user = ensure_partner_user_exists_for_user(
self.link_request, self.user, self.partner self.link_request, self.user, self.partner
) )
@ -181,7 +175,7 @@ class ExistingUnlinkedUserStrategy(ClientMergeStrategy):
class LinkedWithAnotherPartnerUserStrategy(ClientMergeStrategy): class LinkedWithAnotherPartnerUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult: def process(self) -> LinkResult:
raise AccountAlreadyLinkedToAnotherUserException() raise AccountAlreadyLinkedToAnotherPartnerException()
def get_login_strategy( def get_login_strategy(
@ -198,12 +192,6 @@ def get_login_strategy(
return ExistingUnlinkedUserStrategy(link_request, user, partner) return ExistingUnlinkedUserStrategy(link_request, user, partner)
def check_alias(email: str) -> bool:
alias = Alias.get_by(email=email)
if alias is not None:
raise AccountIsUsingAliasAsEmail()
def process_login_case( def process_login_case(
link_request: PartnerLinkRequest, partner: Partner link_request: PartnerLinkRequest, partner: Partner
) -> LinkResult: ) -> LinkResult:
@ -214,21 +202,9 @@ def process_login_case(
partner_id=partner.id, external_user_id=link_request.external_user_id partner_id=partner.id, external_user_id=link_request.external_user_id
) )
if partner_user is None: if partner_user is None:
canonical_email = canonicalize_email(link_request.email)
# We didn't find any SimpleLogin user registered with that partner user id # We didn't find any SimpleLogin user registered with that partner user id
# Make sure they aren't using an alias as their link email
check_alias(link_request.email)
check_alias(canonical_email)
# Try to find it using the partner's e-mail address # Try to find it using the partner's e-mail address
users = User.filter( user = User.get_by(email=link_request.email)
or_(User.email == link_request.email, User.email == canonical_email)
).all()
if len(users) > 1:
user = [user for user in users if user.email == canonical_email][0]
elif len(users) == 1:
user = users[0]
else:
user = None
return get_login_strategy(link_request, user, partner).process() return get_login_strategy(link_request, user, partner).process()
else: else:
# We found the SL user registered with that partner user id # We found the SL user registered with that partner user id
@ -300,7 +276,7 @@ def process_link_case(
return link_user(link_request, current_user, partner) return link_user(link_request, current_user, partner)
# There is a SL user registered with the partner. Check if is the current one # There is a SL user registered with the partner. Check if is the current one
if partner_user.user_id == current_user.id: if partner_user.id == current_user.id:
# Update plan # Update plan
set_plan_for_partner_user(partner_user, link_request.plan) set_plan_for_partner_user(partner_user, link_request.plan)
# It's the same user. No need to do anything # It's the same user. No need to do anything
@ -309,4 +285,5 @@ def process_link_case(
strategy="Link", strategy="Link",
) )
else: else:
return switch_already_linked_user(link_request, partner_user, current_user) return switch_already_linked_user(link_request, partner_user, current_user)

View file

@ -34,7 +34,6 @@ from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_add
class SLModelView(sqla.ModelView): class SLModelView(sqla.ModelView):
column_default_sort = ("id", True) column_default_sort = ("id", True)
column_display_pk = True column_display_pk = True
page_size = 100
can_edit = False can_edit = False
can_create = False can_create = False
@ -94,10 +93,6 @@ class SLAdminIndexView(AdminIndexView):
return redirect("/admin/user") return redirect("/admin/user")
def _user_upgrade_channel_formatter(view, context, model, name):
return Markup(model.upgrade_channel)
class UserAdmin(SLModelView): class UserAdmin(SLModelView):
column_searchable_list = ["email", "id"] column_searchable_list = ["email", "id"]
column_exclude_list = [ column_exclude_list = [
@ -115,38 +110,6 @@ class UserAdmin(SLModelView):
ret.insert(0, "upgrade_channel") ret.insert(0, "upgrade_channel")
return ret return ret
column_formatters = {
"upgrade_channel": _user_upgrade_channel_formatter,
}
@action(
"disable_user",
"Disable user",
"Are you sure you want to disable the selected users?",
)
def action_disable_user(self, ids):
for user in User.filter(User.id.in_(ids)):
user.disabled = True
flash(f"Disabled user {user.id}")
AdminAuditLog.disable_user(current_user.id, user.id)
Session.commit()
@action(
"enable_user",
"Enable user",
"Are you sure you want to enable the selected users?",
)
def action_enable_user(self, ids):
for user in User.filter(User.id.in_(ids)):
user.disabled = False
flash(f"Enabled user {user.id}")
AdminAuditLog.enable_user(current_user.id, user.id)
Session.commit()
@action( @action(
"education_upgrade", "education_upgrade",
"Education upgrade", "Education upgrade",
@ -256,17 +219,6 @@ class UserAdmin(SLModelView):
Session.commit() Session.commit()
@action(
"clear_delete_on",
"Remove scheduled deletion of user",
"This will remove the scheduled deletion for this users",
)
def clean_delete_on(self, ids):
for user in User.filter(User.id.in_(ids)):
user.delete_on = None
Session.commit()
# @action( # @action(
# "login_as", # "login_as",
# "Login as this user", # "Login as this user",
@ -611,26 +563,6 @@ class NewsletterAdmin(SLModelView):
else: else:
flash(error_msg, "error") flash(error_msg, "error")
@action(
"clone_newsletter",
"Clone this newsletter",
)
def clone_newsletter(self, newsletter_ids):
if len(newsletter_ids) != 1:
flash("you can only select 1 newsletter", "error")
return
newsletter_id = newsletter_ids[0]
newsletter: Newsletter = Newsletter.get(newsletter_id)
new_newsletter = Newsletter.create(
subject=newsletter.subject,
html=newsletter.html,
plain_text=newsletter.plain_text,
commit=True,
)
flash(f"Newsletter {new_newsletter.subject} has been cloned", "success")
class NewsletterUserAdmin(SLModelView): class NewsletterUserAdmin(SLModelView):
column_searchable_list = ["id"] column_searchable_list = ["id"]
@ -639,20 +571,3 @@ class NewsletterUserAdmin(SLModelView):
can_edit = False can_edit = False
can_create = False can_create = False
class DailyMetricAdmin(SLModelView):
column_exclude_list = ["created_at", "updated_at", "id"]
can_export = True
class MetricAdmin(SLModelView):
column_exclude_list = ["created_at", "updated_at", "id"]
can_export = True
class InvalidMailboxDomainAdmin(SLModelView):
can_create = True
can_delete = True

View file

@ -6,7 +6,8 @@ from typing import Optional
import itsdangerous import itsdangerous
from app import config from app import config
from app.log import LOG from app.log import LOG
from app.models import User, AliasOptions, SLDomain from app.models import User
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET) signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@ -42,9 +43,7 @@ def check_suffix_signature(signed_suffix: str) -> Optional[str]:
return None return None
def verify_prefix_suffix( def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
user: User, alias_prefix, alias_suffix, alias_options: Optional[AliasOptions] = None
) -> bool:
"""verify if user could create an alias with the given prefix and suffix""" """verify if user could create an alias with the given prefix and suffix"""
if not alias_prefix or not alias_suffix: # should be caught on frontend if not alias_prefix or not alias_suffix: # should be caught on frontend
return False return False
@ -57,7 +56,7 @@ def verify_prefix_suffix(
alias_domain_prefix, alias_domain = alias_suffix.split("@", 1) alias_domain_prefix, alias_domain = alias_suffix.split("@", 1)
# alias_domain must be either one of user custom domains or built-in domains # alias_domain must be either one of user custom domains or built-in domains
if alias_domain not in user.available_alias_domains(alias_options=alias_options): if alias_domain not in user.available_alias_domains():
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False return False
@ -65,11 +64,12 @@ def verify_prefix_suffix(
# 1) alias_suffix must start with "." and # 1) alias_suffix must start with "." and
# 2) alias_domain_prefix must come from the word list # 2) alias_domain_prefix must come from the word list
if ( if (
alias_domain in user.available_sl_domains(alias_options=alias_options) alias_domain in user.available_sl_domains()
and alias_domain not in user_custom_domains and alias_domain not in user_custom_domains
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty # when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
and not config.DISABLE_ALIAS_SUFFIX and not config.DISABLE_ALIAS_SUFFIX
): ):
if not alias_domain_prefix.startswith("."): if not alias_domain_prefix.startswith("."):
LOG.e("User %s submits a wrong alias suffix %s", user, alias_suffix) LOG.e("User %s submits a wrong alias suffix %s", user, alias_suffix)
return False return False
@ -80,18 +80,14 @@ def verify_prefix_suffix(
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False return False
if alias_domain not in user.available_sl_domains( if alias_domain not in user.available_sl_domains():
alias_options=alias_options
):
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False return False
return True return True
def get_alias_suffixes( def get_alias_suffixes(user: User) -> [AliasSuffix]:
user: User, alias_options: Optional[AliasOptions] = None
) -> [AliasSuffix]:
""" """
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up. Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
""" """
@ -103,9 +99,7 @@ def get_alias_suffixes(
# for each user domain, generate both the domain and a random suffix version # for each user domain, generate both the domain and a random suffix version
for custom_domain in user_custom_domains: for custom_domain in user_custom_domains:
if custom_domain.random_prefix_generation: if custom_domain.random_prefix_generation:
suffix = ( suffix = "." + user.get_random_alias_suffix() + "@" + custom_domain.domain
f".{user.get_random_alias_suffix(custom_domain)}@{custom_domain.domain}"
)
alias_suffix = AliasSuffix( alias_suffix = AliasSuffix(
is_custom=True, is_custom=True,
suffix=suffix, suffix=suffix,
@ -119,7 +113,7 @@ def get_alias_suffixes(
else: else:
alias_suffixes.append(alias_suffix) alias_suffixes.append(alias_suffix)
suffix = f"@{custom_domain.domain}" suffix = "@" + custom_domain.domain
alias_suffix = AliasSuffix( alias_suffix = AliasSuffix(
is_custom=True, is_custom=True,
suffix=suffix, suffix=suffix,
@ -140,13 +134,16 @@ def get_alias_suffixes(
alias_suffixes.append(alias_suffix) alias_suffixes.append(alias_suffix)
# then SimpleLogin domain # then SimpleLogin domain
sl_domains = user.get_sl_domains(alias_options=alias_options) for sl_domain in user.get_sl_domains():
default_domain_found = False suffix = (
for sl_domain in sl_domains: (
prefix = ( ""
"" if config.DISABLE_ALIAS_SUFFIX else f".{user.get_random_alias_suffix()}" if config.DISABLE_ALIAS_SUFFIX
else "." + user.get_random_alias_suffix()
)
+ "@"
+ sl_domain.domain
) )
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix( alias_suffix = AliasSuffix(
is_custom=False, is_custom=False,
suffix=suffix, suffix=suffix,
@ -155,36 +152,11 @@ def get_alias_suffixes(
domain=sl_domain.domain, domain=sl_domain.domain,
mx_verified=True, mx_verified=True,
) )
# No default or this is not the default
if (
user.default_alias_public_domain_id is None
or user.default_alias_public_domain_id != sl_domain.id
):
alias_suffixes.append(alias_suffix)
else:
default_domain_found = True
alias_suffixes.insert(0, alias_suffix)
if not default_domain_found: # put the default domain to top
domain_conditions = {"id": user.default_alias_public_domain_id, "hidden": False} if user.default_alias_public_domain_id == sl_domain.id:
if not user.is_premium():
domain_conditions["premium_only"] = False
sl_domain = SLDomain.get_by(**domain_conditions)
if sl_domain:
prefix = (
""
if config.DISABLE_ALIAS_SUFFIX
else f".{user.get_random_alias_suffix()}"
)
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=False,
suffix=suffix,
signed_suffix=signer.sign(suffix).decode(),
is_premium=sl_domain.premium_only,
domain=sl_domain.domain,
mx_verified=True,
)
alias_suffixes.insert(0, alias_suffix) alias_suffixes.insert(0, alias_suffix)
else:
alias_suffixes.append(alias_suffix)
return alias_suffixes return alias_suffixes

View file

@ -1,11 +1,8 @@
import csv
from io import StringIO
import re import re
from typing import Optional, Tuple from typing import Optional, Tuple
from email_validator import validate_email, EmailNotValidError from email_validator import validate_email, EmailNotValidError
from sqlalchemy.exc import IntegrityError, DataError from sqlalchemy.exc import IntegrityError, DataError
from flask import make_response
from app.config import ( from app.config import (
BOUNCE_PREFIX_FOR_REPLY_PHASE, BOUNCE_PREFIX_FOR_REPLY_PHASE,
@ -21,8 +18,6 @@ from app.email_utils import (
send_cannot_create_directory_alias_disabled, send_cannot_create_directory_alias_disabled,
get_email_local_part, get_email_local_part,
send_cannot_create_domain_alias, send_cannot_create_domain_alias,
send_email,
render,
) )
from app.errors import AliasInTrashError from app.errors import AliasInTrashError
from app.log import LOG from app.log import LOG
@ -38,8 +33,6 @@ from app.models import (
EmailLog, EmailLog,
Contact, Contact,
AutoCreateRule, AutoCreateRule,
AliasUsedOn,
ClientUser,
) )
from app.regex_utils import regex_match from app.regex_utils import regex_match
@ -61,8 +54,6 @@ def get_user_if_alias_would_auto_create(
domain_and_rule = check_if_alias_can_be_auto_created_for_custom_domain( domain_and_rule = check_if_alias_can_be_auto_created_for_custom_domain(
address, notify_user=notify_user address, notify_user=notify_user
) )
if DomainDeletedAlias.get_by(email=address):
return None
if domain_and_rule: if domain_and_rule:
return domain_and_rule[0].user return domain_and_rule[0].user
directory = check_if_alias_can_be_auto_created_for_a_directory( directory = check_if_alias_can_be_auto_created_for_a_directory(
@ -94,7 +85,6 @@ def check_if_alias_can_be_auto_created_for_custom_domain(
return None return None
if not user.can_create_new_alias(): if not user.can_create_new_alias():
LOG.d(f"{user} can't create new custom-domain alias {address}")
if notify_user: if notify_user:
send_cannot_create_domain_alias(custom_domain.user, address, alias_domain) send_cannot_create_domain_alias(custom_domain.user, address, alias_domain)
return None return None
@ -156,7 +146,6 @@ def check_if_alias_can_be_auto_created_for_a_directory(
return None return None
if not user.can_create_new_alias(): if not user.can_create_new_alias():
LOG.d(f"{user} can't create new directory alias {address}")
if notify_user: if notify_user:
send_cannot_create_directory_alias(user, address, directory_name) send_cannot_create_directory_alias(user, address, directory_name)
return None return None
@ -373,88 +362,3 @@ def check_alias_prefix(alias_prefix) -> bool:
return False return False
return True return True
def alias_export_csv(user, csv_direct_export=False):
"""
Get user aliases as importable CSV file
Output:
Importable CSV file
"""
data = [["alias", "note", "enabled", "mailboxes"]]
for alias in Alias.filter_by(user_id=user.id).all(): # type: Alias
# Always put the main mailbox first
# It is seen a primary while importing
alias_mailboxes = alias.mailboxes
alias_mailboxes.insert(
0, alias_mailboxes.pop(alias_mailboxes.index(alias.mailbox))
)
mailboxes = " ".join([mailbox.email for mailbox in alias_mailboxes])
data.append([alias.email, alias.note, alias.enabled, mailboxes])
si = StringIO()
cw = csv.writer(si)
cw.writerows(data)
if csv_direct_export:
return si.getvalue()
output = make_response(si.getvalue())
output.headers["Content-Disposition"] = "attachment; filename=aliases.csv"
output.headers["Content-type"] = "text/csv"
return output
def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
# cannot transfer alias which is used for receiving newsletter
if User.get_by(newsletter_alias_id=alias.id):
raise Exception("Cannot transfer alias that's used to receive newsletter")
# update user_id
Session.query(Contact).filter(Contact.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update(
{"user_id": new_user.id}
)
# remove existing mailboxes from the alias
Session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete()
# set mailboxes
alias.mailbox_id = new_mailboxes.pop().id
for mb in new_mailboxes:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id)
# alias has never been transferred before
if not alias.original_owner_id:
alias.original_owner_id = alias.user_id
# inform previous owner
old_user = alias.user
send_email(
old_user.email,
f"Alias {alias.email} has been received",
render(
"transactional/alias-transferred.txt",
alias=alias,
),
render(
"transactional/alias-transferred.html",
alias=alias,
),
)
# now the alias belongs to the new user
alias.user_id = new_user.id
# set some fields back to default
alias.disable_pgp = False
alias.pinned = False
Session.commit()

View file

@ -16,22 +16,3 @@ from .views import (
sudo, sudo,
user, user,
) )
__all__ = [
"alias_options",
"new_custom_alias",
"custom_domain",
"new_random_alias",
"user_info",
"auth",
"auth_mfa",
"alias",
"apple",
"mailbox",
"notification",
"setting",
"export",
"phone",
"sudo",
"user",
]

View file

@ -24,7 +24,6 @@ from app.errors import (
ErrContactAlreadyExists, ErrContactAlreadyExists,
ErrAddressInvalid, ErrAddressInvalid,
) )
from app.extensions import limiter
from app.models import Alias, Contact, Mailbox, AliasMailbox from app.models import Alias, Contact, Mailbox, AliasMailbox
@ -72,9 +71,6 @@ def get_aliases():
@api_bp.route("/v2/aliases", methods=["GET", "POST"]) @api_bp.route("/v2/aliases", methods=["GET", "POST"])
@limiter.limit(
"5/minute",
)
@require_api_auth @require_api_auth
def get_aliases_v2(): def get_aliases_v2():
""" """

View file

@ -9,7 +9,6 @@ from requests import RequestException
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import APPLE_API_SECRET, MACAPP_APPLE_API_SECRET from app.config import APPLE_API_SECRET, MACAPP_APPLE_API_SECRET
from app.subscription_webhook import execute_subscription_webhook
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import PlanEnum, AppleSubscription from app.models import PlanEnum, AppleSubscription
@ -41,17 +40,15 @@ def apple_process_payment():
LOG.d("request for /apple/process_payment from %s", user) LOG.d("request for /apple/process_payment from %s", user)
data = request.get_json() data = request.get_json()
receipt_data = data.get("receipt_data") receipt_data = data.get("receipt_data")
is_macapp = "is_macapp" in data and data["is_macapp"] is True is_macapp = "is_macapp" in data
if is_macapp: if is_macapp:
LOG.d("Use Macapp secret")
password = MACAPP_APPLE_API_SECRET password = MACAPP_APPLE_API_SECRET
else: else:
password = APPLE_API_SECRET password = APPLE_API_SECRET
apple_sub = verify_receipt(receipt_data, user, password) apple_sub = verify_receipt(receipt_data, user, password)
if apple_sub: if apple_sub:
execute_subscription_webhook(user)
return jsonify(ok=True), 200 return jsonify(ok=True), 200
return jsonify(error="Processing failed"), 400 return jsonify(error="Processing failed"), 400
@ -284,7 +281,6 @@ def apple_update_notification():
apple_sub.plan = plan apple_sub.plan = plan
apple_sub.product_id = transaction["product_id"] apple_sub.product_id = transaction["product_id"]
Session.commit() Session.commit()
execute_subscription_webhook(user)
return jsonify(ok=True), 200 return jsonify(ok=True), 200
else: else:
LOG.w( LOG.w(
@ -478,7 +474,7 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
# } # }
if data["status"] != 0: if data["status"] != 0:
LOG.e( LOG.w(
"verifyReceipt status !=0, probably invalid receipt. User %s, data %s", "verifyReceipt status !=0, probably invalid receipt. User %s, data %s",
user, user,
data, data,
@ -525,10 +521,9 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
if apple_sub: if apple_sub:
LOG.d( LOG.d(
"Update AppleSubscription for user %s, expired at %s (%s), plan %s", "Update AppleSubscription for user %s, expired at %s, plan %s",
user, user,
expires_date, expires_date,
expires_date.humanize(),
plan, plan,
) )
apple_sub.receipt_data = receipt_data apple_sub.receipt_data = receipt_data
@ -557,7 +552,6 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
product_id=latest_transaction["product_id"], product_id=latest_transaction["product_id"],
) )
execute_subscription_webhook(user)
Session.commit() Session.commit()
return apple_sub return apple_sub

View file

@ -23,7 +23,7 @@ from app.events.auth_event import LoginEvent, RegisterEvent
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User, ApiKey, SocialAuth, AccountActivation from app.models import User, ApiKey, SocialAuth, AccountActivation
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
@api_bp.route("/auth/login", methods=["POST"]) @api_bp.route("/auth/login", methods=["POST"])
@ -49,13 +49,11 @@ def auth_login():
if not data: if not data:
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
email = sanitize_email(data.get("email"))
password = data.get("password") password = data.get("password")
device = data.get("device") device = data.get("device")
email = sanitize_email(data.get("email")) user = User.filter_by(email=email).first()
canonical_email = canonicalize_email(data.get("email"))
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if not user or not user.check_password(password): if not user or not user.check_password(password):
LoginEvent(LoginEvent.ActionType.failed, LoginEvent.Source.api).send() LoginEvent(LoginEvent.ActionType.failed, LoginEvent.Source.api).send()
@ -63,11 +61,6 @@ def auth_login():
elif user.disabled: elif user.disabled:
LoginEvent(LoginEvent.ActionType.disabled_login, LoginEvent.Source.api).send() LoginEvent(LoginEvent.ActionType.disabled_login, LoginEvent.Source.api).send()
return jsonify(error="Account disabled"), 400 return jsonify(error="Account disabled"), 400
elif user.delete_on is not None:
LoginEvent(
LoginEvent.ActionType.scheduled_to_be_deleted, LoginEvent.Source.api
).send()
return jsonify(error="Account scheduled for deletion"), 400
elif not user.activated: elif not user.activated:
LoginEvent(LoginEvent.ActionType.not_activated, LoginEvent.Source.api).send() LoginEvent(LoginEvent.ActionType.not_activated, LoginEvent.Source.api).send()
return jsonify(error="Account not activated"), 422 return jsonify(error="Account not activated"), 422
@ -96,8 +89,7 @@ def auth_register():
if not data: if not data:
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
dirty_email = data.get("email") email = sanitize_email(data.get("email"))
email = canonicalize_email(dirty_email)
password = data.get("password") password = data.get("password")
if DISABLE_REGISTRATION: if DISABLE_REGISTRATION:
@ -118,7 +110,7 @@ def auth_register():
return jsonify(error="password too long"), 400 return jsonify(error="password too long"), 400
LOG.d("create user %s", email) LOG.d("create user %s", email)
user = User.create(email=email, name=dirty_email, password=password) user = User.create(email=email, name="", password=password)
Session.flush() Session.flush()
# create activation code # create activation code
@ -156,10 +148,9 @@ def auth_activate():
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
email = sanitize_email(data.get("email")) email = sanitize_email(data.get("email"))
canonical_email = canonicalize_email(data.get("email"))
code = data.get("code") code = data.get("code")
user = User.get_by(email=email) or User.get_by(email=canonical_email) user = User.get_by(email=email)
# do not use a different message to avoid exposing existing email # do not use a different message to avoid exposing existing email
if not user or user.activated: if not user or user.activated:
@ -205,9 +196,7 @@ def auth_reactivate():
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
email = sanitize_email(data.get("email")) email = sanitize_email(data.get("email"))
canonical_email = canonicalize_email(data.get("email")) user = User.get_by(email=email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
# do not use a different message to avoid exposing existing email # do not use a different message to avoid exposing existing email
if not user or user.activated: if not user or user.activated:
@ -362,7 +351,7 @@ def auth_payload(user, device) -> dict:
@api_bp.route("/auth/forgot_password", methods=["POST"]) @api_bp.route("/auth/forgot_password", methods=["POST"])
@limiter.limit("2/minute") @limiter.limit("10/minute")
def forgot_password(): def forgot_password():
""" """
User forgot password User forgot password
@ -378,9 +367,8 @@ def forgot_password():
return jsonify(error="request body must contain email"), 400 return jsonify(error="request body must contain email"), 400
email = sanitize_email(data.get("email")) email = sanitize_email(data.get("email"))
canonical_email = canonicalize_email(data.get("email"))
user = User.get_by(email=email) or User.get_by(email=canonical_email) user = User.get_by(email=email)
if user: if user:
send_reset_password_email(user) send_reset_password_email(user)

View file

@ -55,7 +55,7 @@ def auth_mfa():
) )
totp = pyotp.TOTP(user.otp_secret) totp = pyotp.TOTP(user.otp_secret)
if not totp.verify(mfa_token, valid_window=2): if not totp.verify(mfa_token):
send_invalid_totp_login_email(user, "TOTP") send_invalid_totp_login_email(user, "TOTP")
return jsonify(error="Wrong TOTP Token"), 400 return jsonify(error="Wrong TOTP Token"), 400

View file

@ -1,9 +1,12 @@
import csv
from io import StringIO
from flask import g from flask import g
from flask import jsonify from flask import jsonify
from flask import make_response
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.models import Alias, Client, CustomDomain from app.models import Alias, Client, CustomDomain
from app.alias_utils import alias_export_csv
@api_bp.route("/export/data", methods=["GET"]) @api_bp.route("/export/data", methods=["GET"])
@ -46,4 +49,24 @@ def export_aliases():
Importable CSV file Importable CSV file
""" """
return alias_export_csv(g.user) user = g.user
data = [["alias", "note", "enabled", "mailboxes"]]
for alias in Alias.filter_by(user_id=user.id).all(): # type: Alias
# Always put the main mailbox first
# It is seen a primary while importing
alias_mailboxes = alias.mailboxes
alias_mailboxes.insert(
0, alias_mailboxes.pop(alias_mailboxes.index(alias.mailbox))
)
mailboxes = " ".join([mailbox.email for mailbox in alias_mailboxes])
data.append([alias.email, alias.note, alias.enabled, mailboxes])
si = StringIO()
cw = csv.writer(si)
cw.writerows(data)
output = make_response(si.getvalue())
output.headers["Content-Disposition"] = "attachment; filename=aliases.csv"
output.headers["Content-type"] = "text/csv"
return output

View file

@ -13,8 +13,8 @@ from app.db import Session
from app.email_utils import ( from app.email_utils import (
mailbox_already_used, mailbox_already_used,
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
is_valid_email,
) )
from app.email_validation import is_valid_email
from app.log import LOG from app.log import LOG
from app.models import Mailbox, Job from app.models import Mailbox, Job
from app.utils import sanitize_email from app.utils import sanitize_email
@ -45,7 +45,7 @@ def create_mailbox():
mailbox_email = sanitize_email(request.get_json().get("email")) mailbox_email = sanitize_email(request.get_json().get("email"))
if not user.is_premium(): if not user.is_premium():
return jsonify(error="Only premium plan can add additional mailbox"), 400 return jsonify(error=f"Only premium plan can add additional mailbox"), 400
if not is_valid_email(mailbox_email): if not is_valid_email(mailbox_email):
return jsonify(error=f"{mailbox_email} invalid"), 400 return jsonify(error=f"{mailbox_email} invalid"), 400
@ -78,9 +78,6 @@ def delete_mailbox(mailbox_id):
Delete mailbox Delete mailbox
Input: Input:
mailbox_id: in url mailbox_id: in url
(optional) transfer_aliases_to: in body. Id of the new mailbox for the aliases.
If omitted or the value is set to -1,
the aliases of the mailbox will be deleted too.
Output: Output:
200 if deleted successfully 200 if deleted successfully
@ -94,36 +91,11 @@ def delete_mailbox(mailbox_id):
if mailbox.id == user.default_mailbox_id: if mailbox.id == user.default_mailbox_id:
return jsonify(error="You cannot delete the default mailbox"), 400 return jsonify(error="You cannot delete the default mailbox"), 400
data = request.get_json() or {}
transfer_mailbox_id = data.get("transfer_aliases_to")
if transfer_mailbox_id and int(transfer_mailbox_id) >= 0:
transfer_mailbox = Mailbox.get(transfer_mailbox_id)
if not transfer_mailbox or transfer_mailbox.user_id != user.id:
return (
jsonify(error="You must transfer the aliases to a mailbox you own."),
403,
)
if transfer_mailbox_id == mailbox_id:
return (
jsonify(
error="You can not transfer the aliases to the mailbox you want to delete."
),
400,
)
if not transfer_mailbox.verified:
return jsonify(error="Your new mailbox is not verified"), 400
# Schedule delete account job # Schedule delete account job
LOG.w("schedule delete mailbox job for %s", mailbox) LOG.w("schedule delete mailbox job for %s", mailbox)
Job.create( Job.create(
name=JOB_DELETE_MAILBOX, name=JOB_DELETE_MAILBOX,
payload={ payload={"mailbox_id": mailbox.id},
"mailbox_id": mailbox.id,
"transfer_mailbox_id": transfer_mailbox_id,
},
run_at=arrow.now(), run_at=arrow.now(),
commit=True, commit=True,
) )

View file

@ -1,7 +1,6 @@
from flask import g from flask import g
from flask import jsonify, request from flask import jsonify, request
from app import parallel_limiter
from app.alias_suffix import check_suffix_signature, verify_prefix_suffix from app.alias_suffix import check_suffix_signature, verify_prefix_suffix
from app.alias_utils import check_alias_prefix from app.alias_utils import check_alias_prefix
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
@ -28,7 +27,6 @@ from app.utils import convert_to_id
@api_bp.route("/v2/alias/custom/new", methods=["POST"]) @api_bp.route("/v2/alias/custom/new", methods=["POST"])
@limiter.limit(ALIAS_LIMIT) @limiter.limit(ALIAS_LIMIT)
@require_api_auth @require_api_auth
@parallel_limiter.lock(name="alias_creation")
def new_custom_alias_v2(): def new_custom_alias_v2():
""" """
Create a new custom alias Create a new custom alias
@ -115,7 +113,6 @@ def new_custom_alias_v2():
@api_bp.route("/v3/alias/custom/new", methods=["POST"]) @api_bp.route("/v3/alias/custom/new", methods=["POST"])
@limiter.limit(ALIAS_LIMIT) @limiter.limit(ALIAS_LIMIT)
@require_api_auth @require_api_auth
@parallel_limiter.lock(name="alias_creation")
def new_custom_alias_v3(): def new_custom_alias_v3():
""" """
Create a new custom alias Create a new custom alias
@ -150,7 +147,7 @@ def new_custom_alias_v3():
if not data: if not data:
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
if not isinstance(data, dict): if type(data) is not dict:
return jsonify(error="request body does not follow the required format"), 400 return jsonify(error="request body does not follow the required format"), 400
alias_prefix = data.get("alias_prefix", "").strip().lower().replace(" ", "") alias_prefix = data.get("alias_prefix", "").strip().lower().replace(" ", "")
@ -168,7 +165,7 @@ def new_custom_alias_v3():
return jsonify(error="alias prefix invalid format or too long"), 400 return jsonify(error="alias prefix invalid format or too long"), 400
# check if mailbox is not tempered with # check if mailbox is not tempered with
if not isinstance(mailbox_ids, list): if type(mailbox_ids) is not list:
return jsonify(error="mailbox_ids must be an array of id"), 400 return jsonify(error="mailbox_ids must be an array of id"), 400
mailboxes = [] mailboxes = []
for mailbox_id in mailbox_ids: for mailbox_id in mailbox_ids:

View file

@ -2,7 +2,6 @@ import tldextract
from flask import g from flask import g
from flask import jsonify, request from flask import jsonify, request
from app import parallel_limiter
from app.alias_suffix import get_alias_suffixes from app.alias_suffix import get_alias_suffixes
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.api.serializer import ( from app.api.serializer import (
@ -21,7 +20,6 @@ from app.utils import convert_to_id
@api_bp.route("/alias/random/new", methods=["POST"]) @api_bp.route("/alias/random/new", methods=["POST"])
@limiter.limit(ALIAS_LIMIT) @limiter.limit(ALIAS_LIMIT)
@require_api_auth @require_api_auth
@parallel_limiter.lock(name="alias_creation")
def new_random_alias(): def new_random_alias():
""" """
Create a new random alias Create a new random alias

View file

@ -12,7 +12,6 @@ from app.models import (
SenderFormatEnum, SenderFormatEnum,
AliasSuffixEnum, AliasSuffixEnum,
) )
from app.proton.utils import perform_proton_account_unlink
def setting_to_dict(user: User): def setting_to_dict(user: User):
@ -138,11 +137,3 @@ def get_available_domains_for_random_alias_v2():
] ]
return jsonify(ret) return jsonify(ret)
@api_bp.route("/setting/unlink_proton_account", methods=["DELETE"])
@require_api_auth
def unlink_proton_account():
user = g.user
perform_proton_account_unlink(user)
return jsonify({"ok": True})

View file

@ -1,11 +1,10 @@
from flask import jsonify, g from flask import jsonify, g
from sqlalchemy_utils.types.arrow import arrow from sqlalchemy_utils.types.arrow import arrow
from app.api.base import api_bp, require_api_sudo, require_api_auth from app.api.base import api_bp, require_api_sudo
from app import config from app import config
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import Job, ApiToCookieToken from app.models import Job
@api_bp.route("/user", methods=["DELETE"]) @api_bp.route("/user", methods=["DELETE"])
@ -24,23 +23,3 @@ def delete_user():
commit=True, commit=True,
) )
return jsonify(ok=True) return jsonify(ok=True)
@api_bp.route("/user/cookie_token", methods=["GET"])
@require_api_auth
@limiter.limit("5/minute")
def get_api_session_token():
"""
Get a temporary token to exchange it for a cookie based session
Output:
200 and a temporary random token
{
token: "asdli3ldq39h9hd3",
}
"""
token = ApiToCookieToken.create(
user=g.user,
api_key_id=g.api_key.id,
commit=True,
)
return jsonify({"token": token.code})

View file

@ -1,29 +1,17 @@
import base64 import base64
import dataclasses
from io import BytesIO from io import BytesIO
from typing import Optional
from flask import jsonify, g, request, make_response from flask import jsonify, g, request, make_response
from flask_login import logout_user
from app import s3, config from app import s3
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import SESSION_COOKIE_NAME from app.config import SESSION_COOKIE_NAME
from app.dashboard.views.index import get_stats
from app.db import Session from app.db import Session
from app.models import ApiKey, File, PartnerUser, User from app.models import ApiKey, File, User
from app.proton.utils import get_proton_partner
from app.session import logout_session
from app.utils import random_string from app.utils import random_string
def get_connected_proton_address(user: User) -> Optional[str]:
proton_partner = get_proton_partner()
partner_user = PartnerUser.get_by(user_id=user.id, partner_id=proton_partner.id)
if partner_user is None:
return None
return partner_user.partner_email
def user_to_dict(user: User) -> dict: def user_to_dict(user: User) -> dict:
ret = { ret = {
"name": user.name or "", "name": user.name or "",
@ -31,12 +19,8 @@ def user_to_dict(user: User) -> dict:
"email": user.email, "email": user.email,
"in_trial": user.in_trial(), "in_trial": user.in_trial(),
"max_alias_free_plan": user.max_alias_for_free_account(), "max_alias_free_plan": user.max_alias_for_free_account(),
"connected_proton_address": None,
} }
if config.CONNECT_WITH_PROTON:
ret["connected_proton_address"] = get_connected_proton_address(user)
if user.profile_picture_id: if user.profile_picture_id:
ret["profile_picture_url"] = user.profile_picture.get_url() ret["profile_picture_url"] = user.profile_picture.get_url()
else: else:
@ -57,7 +41,6 @@ def user_info():
- email - email
- in_trial - in_trial
- max_alias_free - max_alias_free
- is_connected_with_proton
""" """
user = g.user user = g.user
@ -133,27 +116,8 @@ def logout():
Output: Output:
- 200 - 200
""" """
logout_session() logout_user()
response = make_response(jsonify(msg="User is logged out"), 200) response = make_response(jsonify(msg="User is logged out"), 200)
response.delete_cookie(SESSION_COOKIE_NAME) response.delete_cookie(SESSION_COOKIE_NAME)
return response return response
@api_bp.route("/stats")
@require_api_auth
def user_stats():
"""
Return stats
Output as json
- nb_alias
- nb_forward
- nb_reply
- nb_block
"""
user = g.user
stats = get_stats(user)
return jsonify(dataclasses.asdict(stats))

View file

@ -15,25 +15,4 @@ from .views import (
fido, fido,
social, social,
recovery, recovery,
api_to_cookie,
) )
__all__ = [
"login",
"logout",
"register",
"activate",
"resend_activation",
"reset_password",
"forgot_password",
"github",
"google",
"facebook",
"proton",
"change_email",
"mfa",
"fido",
"social",
"recovery",
"api_to_cookie",
]

View file

@ -1,30 +0,0 @@
import arrow
from flask import redirect, url_for, request, flash
from flask_login import login_user
from app.auth.base import auth_bp
from app.models import ApiToCookieToken
from app.utils import sanitize_next_url
@auth_bp.route("/api_to_cookie", methods=["GET"])
def api_to_cookie():
code = request.args.get("token")
if not code:
flash("Missing token", "error")
return redirect(url_for("auth.login"))
token = ApiToCookieToken.get_by(code=code)
if not token or token.created_at < arrow.now().shift(minutes=-5):
flash("Missing token", "error")
return redirect(url_for("auth.login"))
user = token.user
ApiToCookieToken.delete(token.id, commit=True)
login_user(user)
next_url = sanitize_next_url(request.args.get("next"))
if next_url:
return redirect(next_url)
else:
return redirect(url_for("dashboard.index"))

View file

@ -62,7 +62,7 @@ def fido():
browser = MfaBrowser.get_by(token=request.cookies.get("mfa")) browser = MfaBrowser.get_by(token=request.cookies.get("mfa"))
if browser and not browser.is_expired() and browser.user_id == user.id: if browser and not browser.is_expired() and browser.user_id == user.id:
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
return redirect(next_url or url_for("dashboard.index")) return redirect(next_url or url_for("dashboard.index"))
else: else:
@ -110,7 +110,7 @@ def fido():
session["sudo_time"] = int(time()) session["sudo_time"] = int(time())
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
response = make_response(redirect(next_url or url_for("dashboard.index"))) response = make_response(redirect(next_url or url_for("dashboard.index")))

View file

@ -1,4 +1,4 @@
from flask import request, render_template, flash, g from flask import request, render_template, redirect, url_for, flash, g
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
@ -7,7 +7,7 @@ from app.dashboard.views.setting import send_reset_password_email
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User from app.models import User
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
class ForgotPasswordForm(FlaskForm): class ForgotPasswordForm(FlaskForm):
@ -16,7 +16,7 @@ class ForgotPasswordForm(FlaskForm):
@auth_bp.route("/forgot_password", methods=["GET", "POST"]) @auth_bp.route("/forgot_password", methods=["GET", "POST"])
@limiter.limit( @limiter.limit(
"10/hour", deduct_when=lambda r: hasattr(g, "deduct_limit") and g.deduct_limit "10/minute", deduct_when=lambda r: hasattr(g, "deduct_limit") and g.deduct_limit
) )
def forgot_password(): def forgot_password():
form = ForgotPasswordForm(request.form) form = ForgotPasswordForm(request.form)
@ -25,17 +25,16 @@ def forgot_password():
# Trigger rate limiter # Trigger rate limiter
g.deduct_limit = True g.deduct_limit = True
email = sanitize_email(form.email.data)
flash( flash(
"If your email is correct, you are going to receive an email to reset your password", "If your email is correct, you are going to receive an email to reset your password",
"success", "success",
) )
user = User.get_by(email=email)
email = sanitize_email(form.email.data)
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if user: if user:
LOG.d("Send forgot password email to %s", user) LOG.d("Send forgot password email to %s", user)
send_reset_password_email(user) send_reset_password_email(user)
return redirect(url_for("auth.forgot_password"))
return render_template("auth/forgot_password.html", form=form) return render_template("auth/forgot_password.html", form=form)

View file

@ -10,7 +10,7 @@ from app.events.auth_event import LoginEvent
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User from app.models import User
from app.utils import sanitize_email, sanitize_next_url, canonicalize_email from app.utils import sanitize_email, sanitize_next_url
class LoginForm(FlaskForm): class LoginForm(FlaskForm):
@ -38,9 +38,7 @@ def login():
show_resend_activation = False show_resend_activation = False
if form.validate_on_submit(): if form.validate_on_submit():
email = sanitize_email(form.email.data) user = User.filter_by(email=sanitize_email(form.email.data)).first()
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if not user or not user.check_password(form.password.data): if not user or not user.check_password(form.password.data):
# Trigger rate limiter # Trigger rate limiter
@ -54,12 +52,6 @@ def login():
"error", "error",
) )
LoginEvent(LoginEvent.ActionType.disabled_login).send() LoginEvent(LoginEvent.ActionType.disabled_login).send()
elif user.delete_on is not None:
flash(
f"Your account is scheduled to be deleted on {user.delete_on}",
"error",
)
LoginEvent(LoginEvent.ActionType.scheduled_to_be_deleted).send()
elif not user.activated: elif not user.activated:
show_resend_activation = True show_resend_activation = True
flash( flash(

View file

@ -1,13 +1,13 @@
from flask import redirect, url_for, flash, make_response from flask import redirect, url_for, flash, make_response
from flask_login import logout_user
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import SESSION_COOKIE_NAME from app.config import SESSION_COOKIE_NAME
from app.session import logout_session
@auth_bp.route("/logout") @auth_bp.route("/logout")
def logout(): def logout():
logout_session() logout_user()
flash("You are logged out", "success") flash("You are logged out", "success")
response = make_response(redirect(url_for("auth.login"))) response = make_response(redirect(url_for("auth.login")))
response.delete_cookie(SESSION_COOKIE_NAME) response.delete_cookie(SESSION_COOKIE_NAME)

View file

@ -55,7 +55,7 @@ def mfa():
browser = MfaBrowser.get_by(token=request.cookies.get("mfa")) browser = MfaBrowser.get_by(token=request.cookies.get("mfa"))
if browser and not browser.is_expired() and browser.user_id == user.id: if browser and not browser.is_expired() and browser.user_id == user.id:
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
return redirect(next_url or url_for("dashboard.index")) return redirect(next_url or url_for("dashboard.index"))
else: else:
@ -67,13 +67,13 @@ def mfa():
token = otp_token_form.token.data.replace(" ", "") token = otp_token_form.token.data.replace(" ", "")
if totp.verify(token, valid_window=2) and user.last_otp != token: if totp.verify(token) and user.last_otp != token:
del session[MFA_USER_ID] del session[MFA_USER_ID]
user.last_otp = token user.last_otp = token
Session.commit() Session.commit()
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
response = make_response(redirect(next_url or url_for("dashboard.index"))) response = make_response(redirect(next_url or url_for("dashboard.index")))

View file

@ -3,7 +3,6 @@ from flask import request, session, redirect, flash, url_for
from flask_limiter.util import get_remote_address from flask_limiter.util import get_remote_address
from flask_login import current_user from flask_login import current_user
from requests_oauthlib import OAuth2Session from requests_oauthlib import OAuth2Session
from typing import Optional
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login from app.auth.views.login_utils import after_login
@ -24,7 +23,7 @@ from app.proton.proton_callback_handler import (
Action, Action,
) )
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import sanitize_next_url, sanitize_scheme from app.utils import sanitize_next_url
_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize" _authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
_token_url = PROTON_BASE_URL + "/oauth/token" _token_url = PROTON_BASE_URL + "/oauth/token"
@ -35,7 +34,6 @@ _redirect_uri = URL + "/auth/proton/callback"
SESSION_ACTION_KEY = "oauth_action" SESSION_ACTION_KEY = "oauth_action"
SESSION_STATE_KEY = "oauth_state" SESSION_STATE_KEY = "oauth_state"
DEFAULT_SCHEME = "auth.simplelogin"
def get_api_key_for_user(user: User) -> str: def get_api_key_for_user(user: User) -> str:
@ -47,16 +45,13 @@ def get_api_key_for_user(user: User) -> str:
return ak.code return ak.code
def extract_action() -> Optional[Action]: def extract_action() -> Action:
action = request.args.get("action") action = request.args.get("action")
if action is not None: if action is not None:
if action == "link": if action == "link":
return Action.Link return Action.Link
elif action == "login":
return Action.Login
else: else:
LOG.w(f"Unknown action received: {action}") raise Exception(f"Unknown action: {action}")
return None
return Action.Login return Action.Login
@ -74,24 +69,12 @@ def proton_login():
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
action = extract_action()
if action is None:
return redirect(url_for("auth.login"))
if action == Action.Link and not current_user.is_authenticated:
return redirect(url_for("auth.login"))
next_url = sanitize_next_url(request.args.get("next")) next_url = sanitize_next_url(request.args.get("next"))
if next_url: if next_url:
session["oauth_next"] = next_url session["oauth_next"] = next_url
elif "oauth_next" in session: elif "oauth_next" in session:
del session["oauth_next"] del session["oauth_next"]
scheme = sanitize_scheme(request.args.get("scheme"))
if scheme:
session["oauth_scheme"] = scheme
elif "oauth_scheme" in session:
del session["oauth_scheme"]
mode = request.args.get("mode", "session") mode = request.args.get("mode", "session")
if mode == "apikey": if mode == "apikey":
session["oauth_mode"] = "apikey" session["oauth_mode"] = "apikey"
@ -103,7 +86,7 @@ def proton_login():
# State is used to prevent CSRF, keep this for later. # State is used to prevent CSRF, keep this for later.
session[SESSION_STATE_KEY] = state session[SESSION_STATE_KEY] = state
session[SESSION_ACTION_KEY] = action.value session[SESSION_ACTION_KEY] = extract_action().value
return redirect(authorization_url) return redirect(authorization_url)
@ -163,7 +146,6 @@ def proton_callback():
handler = ProtonCallbackHandler(proton_client) handler = ProtonCallbackHandler(proton_client)
proton_partner = get_proton_partner() proton_partner = get_proton_partner()
next_url = session.get("oauth_next")
if action == Action.Login: if action == Action.Login:
res = handler.handle_login(proton_partner) res = handler.handle_login(proton_partner)
elif action == Action.Link: elif action == Action.Link:
@ -174,17 +156,15 @@ def proton_callback():
if res.flash_message is not None: if res.flash_message is not None:
flash(res.flash_message, res.flash_category) flash(res.flash_message, res.flash_category)
oauth_scheme = session.get("oauth_scheme")
if session.get("oauth_mode", "session") == "apikey": if session.get("oauth_mode", "session") == "apikey":
apikey = get_api_key_for_user(res.user) apikey = get_api_key_for_user(res.user)
scheme = oauth_scheme or DEFAULT_SCHEME return redirect(f"auth.simplelogin://callback?apikey={apikey}")
return redirect(f"{scheme}:///login?apikey={apikey}")
if res.redirect_to_login: if res.redirect_to_login:
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
if next_url and next_url[0] == "/" and oauth_scheme: if res.redirect:
next_url = f"{oauth_scheme}://{next_url}" return after_login(res.user, res.redirect, login_from_proton=True)
redirect_url = next_url or res.redirect next_url = session.get("oauth_next")
return after_login(res.user, redirect_url, login_from_proton=True) return after_login(res.user, next_url, login_from_proton=True)

View file

@ -42,7 +42,7 @@ def recovery_route():
if recovery_form.validate_on_submit(): if recovery_form.validate_on_submit():
code = recovery_form.code.data code = recovery_form.code.data
recovery_code = RecoveryCode.find_by_user_code(user, code) recovery_code = RecoveryCode.get_by(user_id=user.id, code=code)
if recovery_code: if recovery_code:
if recovery_code.used: if recovery_code.used:
@ -53,7 +53,7 @@ def recovery_route():
del session[MFA_USER_ID] del session[MFA_USER_ID]
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
recovery_code.used = True recovery_code.used = True
recovery_code.used_at = arrow.now() recovery_code.used_at = arrow.now()

View file

@ -16,8 +16,8 @@ from app.email_utils import (
) )
from app.events.auth_event import RegisterEvent from app.events.auth_event import RegisterEvent
from app.log import LOG from app.log import LOG
from app.models import User, ActivationCode, DailyMetric from app.models import User, ActivationCode
from app.utils import random_string, encode_url, sanitize_email, canonicalize_email from app.utils import random_string, encode_url, sanitize_email
class RegisterForm(FlaskForm): class RegisterForm(FlaskForm):
@ -70,22 +70,19 @@ def register():
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY, HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
) )
email = canonicalize_email(form.email.data) email = sanitize_email(form.email.data)
if not email_can_be_used_as_mailbox(email): if not email_can_be_used_as_mailbox(email):
flash("You cannot use this email address as your personal inbox.", "error") flash("You cannot use this email address as your personal inbox.", "error")
RegisterEvent(RegisterEvent.ActionType.email_in_use).send() RegisterEvent(RegisterEvent.ActionType.email_in_use).send()
else: else:
sanitized_email = sanitize_email(form.email.data) if personal_email_already_used(email):
if personal_email_already_used(email) or personal_email_already_used(
sanitized_email
):
flash(f"Email {email} already used", "error") flash(f"Email {email} already used", "error")
RegisterEvent(RegisterEvent.ActionType.email_in_use).send() RegisterEvent(RegisterEvent.ActionType.email_in_use).send()
else: else:
LOG.d("create user %s", email) LOG.d("create user %s", email)
user = User.create( user = User.create(
email=email, email=email,
name=form.email.data, name="",
password=form.password.data, password=form.password.data,
referral=get_referral(), referral=get_referral(),
) )
@ -94,8 +91,6 @@ def register():
try: try:
send_activation_email(user, next_url) send_activation_email(user, next_url)
RegisterEvent(RegisterEvent.ActionType.success).send() RegisterEvent(RegisterEvent.ActionType.success).send()
DailyMetric.get_or_create_today_metric().nb_new_web_non_proton_user += 1
Session.commit()
except Exception: except Exception:
flash("Invalid email, are you sure the email is correct?", "error") flash("Invalid email, are you sure the email is correct?", "error")
RegisterEvent(RegisterEvent.ActionType.invalid_email).send() RegisterEvent(RegisterEvent.ActionType.invalid_email).send()

View file

@ -7,7 +7,7 @@ from app.auth.views.register import send_activation_email
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User from app.models import User
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
class ResendActivationForm(FlaskForm): class ResendActivationForm(FlaskForm):
@ -20,9 +20,7 @@ def resend_activation():
form = ResendActivationForm(request.form) form = ResendActivationForm(request.form)
if form.validate_on_submit(): if form.validate_on_submit():
email = sanitize_email(form.email.data) user = User.filter_by(email=sanitize_email(form.email.data)).first()
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if not user: if not user:
flash("There is no such email", "warning") flash("There is no such email", "warning")

View file

@ -60,8 +60,8 @@ def reset_password():
# this can be served to activate user too # this can be served to activate user too
user.activated = True user.activated = True
# remove all reset password codes # remove the reset password code
ResetPasswordCode.filter_by(user_id=user.id).delete() ResetPasswordCode.delete(reset_password_code.id)
# change the alternative_id to log user out on other browsers # change the alternative_id to log user out on other browsers
user.alternative_id = str(uuid.uuid4()) user.alternative_id = str(uuid.uuid4())

View file

@ -8,6 +8,7 @@ from urllib.parse import urlparse
from dotenv import load_dotenv from dotenv import load_dotenv
ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
@ -111,16 +112,13 @@ POSTFIX_SERVER = os.environ.get("POSTFIX_SERVER", "240.0.0.1")
DISABLE_REGISTRATION = "DISABLE_REGISTRATION" in os.environ DISABLE_REGISTRATION = "DISABLE_REGISTRATION" in os.environ
# allow using a different postfix port, useful when developing locally # allow using a different postfix port, useful when developing locally
POSTFIX_PORT = 25
if "POSTFIX_PORT" in os.environ:
POSTFIX_PORT = int(os.environ["POSTFIX_PORT"])
# Use port 587 instead of 25 when sending emails through Postfix # Use port 587 instead of 25 when sending emails through Postfix
# Useful when calling Postfix from an external network # Useful when calling Postfix from an external network
POSTFIX_SUBMISSION_TLS = "POSTFIX_SUBMISSION_TLS" in os.environ POSTFIX_SUBMISSION_TLS = "POSTFIX_SUBMISSION_TLS" in os.environ
if POSTFIX_SUBMISSION_TLS:
default_postfix_port = 587
else:
default_postfix_port = 25
POSTFIX_PORT = int(os.environ.get("POSTFIX_PORT", default_postfix_port))
POSTFIX_TIMEOUT = os.environ.get("POSTFIX_TIMEOUT", 3)
# ["domain1.com", "domain2.com"] # ["domain1.com", "domain2.com"]
OTHER_ALIAS_DOMAINS = sl_getenv("OTHER_ALIAS_DOMAINS", list) OTHER_ALIAS_DOMAINS = sl_getenv("OTHER_ALIAS_DOMAINS", list)
@ -163,7 +161,6 @@ if "DKIM_PRIVATE_KEY_PATH" in os.environ:
# Database # Database
DB_URI = os.environ["DB_URI"] DB_URI = os.environ["DB_URI"]
DB_CONN_NAME = os.environ.get("DB_CONN_NAME", "webapp")
# Flask secret # Flask secret
FLASK_SECRET = os.environ["FLASK_SECRET"] FLASK_SECRET = os.environ["FLASK_SECRET"]
@ -357,7 +354,6 @@ ALERT_COMPLAINT_TRANSACTIONAL_PHASE = "alert_complaint_transactional_phase"
ALERT_QUARANTINE_DMARC = "alert_quarantine_dmarc" ALERT_QUARANTINE_DMARC = "alert_quarantine_dmarc"
ALERT_DUAL_SUBSCRIPTION_WITH_PARTNER = "alert_dual_sub_with_partner" ALERT_DUAL_SUBSCRIPTION_WITH_PARTNER = "alert_dual_sub_with_partner"
ALERT_WARN_MULTIPLE_SUBSCRIPTIONS = "alert_multiple_subscription"
# <<<<< END ALERT EMAIL >>>> # <<<<< END ALERT EMAIL >>>>
@ -498,44 +494,3 @@ JOB_TAKEN_RETRY_WAIT_MINS = 30
# MEM_STORE # MEM_STORE
MEM_STORE_URI = os.environ.get("MEM_STORE_URI", None) MEM_STORE_URI = os.environ.get("MEM_STORE_URI", None)
# Recovery codes hash salt
RECOVERY_CODE_HMAC_SECRET = os.environ.get("RECOVERY_CODE_HMAC_SECRET") or (
FLASK_SECRET + "generatearandomtoken"
)
if not RECOVERY_CODE_HMAC_SECRET or len(RECOVERY_CODE_HMAC_SECRET) < 16:
raise RuntimeError(
"Please define RECOVERY_CODE_HMAC_SECRET in your configuration with a random string at least 16 chars long"
)
# the minimum rspamd spam score above which emails that fail DMARC should be quarantined
if "MIN_RSPAMD_SCORE_FOR_FAILED_DMARC" in os.environ:
MIN_RSPAMD_SCORE_FOR_FAILED_DMARC = float(
os.environ["MIN_RSPAMD_SCORE_FOR_FAILED_DMARC"]
)
else:
MIN_RSPAMD_SCORE_FOR_FAILED_DMARC = None
# run over all reverse alias for an alias and replace them with sender address
ENABLE_ALL_REVERSE_ALIAS_REPLACEMENT = (
"ENABLE_ALL_REVERSE_ALIAS_REPLACEMENT" in os.environ
)
if ENABLE_ALL_REVERSE_ALIAS_REPLACEMENT:
# max number of reverse alias that can be replaced
MAX_NB_REVERSE_ALIAS_REPLACEMENT = int(
os.environ["MAX_NB_REVERSE_ALIAS_REPLACEMENT"]
)
# Only used for tests
SKIP_MX_LOOKUP_ON_CHECK = False
DISABLE_RATE_LIMIT = "DISABLE_RATE_LIMIT" in os.environ
SUBSCRIPTION_CHANGE_WEBHOOK = os.environ.get("SUBSCRIPTION_CHANGE_WEBHOOK", None)
MAX_API_KEYS = int(os.environ.get("MAX_API_KEYS", 30))
UPCLOUD_USERNAME = os.environ.get("UPCLOUD_USERNAME", None)
UPCLOUD_PASSWORD = os.environ.get("UPCLOUD_PASSWORD", None)
UPCLOUD_DB_ID = os.environ.get("UPCLOUD_DB_ID", None)

View file

@ -1,37 +0,0 @@
from app.db import Session
from app.dns_utils import get_cname_record
from app.models import CustomDomain
class CustomDomainValidation:
def __init__(self, dkim_domain: str):
self.dkim_domain = dkim_domain
self._dkim_records = {
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
for key in ("dkim", "dkim02", "dkim03")
}
def get_dkim_records(self) -> {str: str}:
"""
Get a list of dkim records to set up. It will be
"""
return self._dkim_records
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
"""
Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
invalid_records = {}
for prefix, expected_record in self.get_dkim_records():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = get_cname_record(custom_record)
if dkim_record != expected_record:
invalid_records[custom_record] = dkim_record or "empty"
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
if custom_domain.dkim_verified:
return invalid_records
custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit()
return invalid_records

View file

@ -6,7 +6,6 @@ from .views import (
subdomain, subdomain,
billing, billing,
alias_log, alias_log,
alias_export,
unsubscribe, unsubscribe,
api_key, api_key,
custom_domain, custom_domain,
@ -24,6 +23,7 @@ from .views import (
mailbox_detail, mailbox_detail,
refused_email, refused_email,
referral, referral,
recovery_code,
contact_detail, contact_detail,
setup_done, setup_done,
batch_import, batch_import,
@ -33,39 +33,3 @@ from .views import (
notification, notification,
support, support,
) )
__all__ = [
"index",
"pricing",
"setting",
"custom_alias",
"subdomain",
"billing",
"alias_log",
"alias_export",
"unsubscribe",
"api_key",
"custom_domain",
"alias_contact_manager",
"enter_sudo",
"mfa_setup",
"mfa_cancel",
"fido_setup",
"coupon",
"fido_manage",
"domain_detail",
"lifetime_licence",
"directory",
"mailbox",
"mailbox_detail",
"refused_email",
"referral",
"contact_detail",
"setup_done",
"batch_import",
"alias_transfer",
"app",
"delete_account",
"notification",
"support",
]

View file

@ -9,14 +9,14 @@ from sqlalchemy import and_, func, case
from wtforms import StringField, validators, ValidationError from wtforms import StringField, validators, ValidationError
# Need to import directly from config to allow modification from the tests # Need to import directly from config to allow modification from the tests
from app import config, parallel_limiter from app import config
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.email_utils import ( from app.email_utils import (
is_valid_email,
generate_reply_email, generate_reply_email,
parse_full_address, parse_full_address,
) )
from app.email_validation import is_valid_email
from app.errors import ( from app.errors import (
CannotCreateContactForReverseAlias, CannotCreateContactForReverseAlias,
ErrContactErrorUpgradeNeeded, ErrContactErrorUpgradeNeeded,
@ -25,7 +25,7 @@ from app.errors import (
) )
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, EmailLog, User from app.models import Alias, Contact, EmailLog, User
from app.utils import sanitize_email, CSRFValidationForm from app.utils import sanitize_email
def email_validator(): def email_validator():
@ -90,7 +90,7 @@ def create_contact(user: User, alias: Alias, contact_address: str) -> Contact:
alias_id=alias.id, alias_id=alias.id,
website_email=contact_email, website_email=contact_email,
name=contact_name, name=contact_name,
reply_email=generate_reply_email(contact_email, alias), reply_email=generate_reply_email(contact_email, user),
) )
LOG.d( LOG.d(
@ -231,7 +231,6 @@ def delete_contact(alias: Alias, contact_id: int):
@dashboard_bp.route("/alias_contact_manager/<int:alias_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/alias_contact_manager/<int:alias_id>/", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(name="contact_creation")
def alias_contact_manager(alias_id): def alias_contact_manager(alias_id):
highlight_contact_id = None highlight_contact_id = None
if request.args.get("highlight_contact_id"): if request.args.get("highlight_contact_id"):
@ -259,12 +258,8 @@ def alias_contact_manager(alias_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
new_contact_form = NewContactForm() new_contact_form = NewContactForm()
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
if new_contact_form.validate(): if new_contact_form.validate():
contact_address = new_contact_form.email.data.strip() contact_address = new_contact_form.email.data.strip()
@ -328,5 +323,4 @@ def alias_contact_manager(alias_id):
query=query, query=query,
nb_contact=nb_contact, nb_contact=nb_contact,
can_create_contacts=user_can_create_contacts(current_user), can_create_contacts=user_can_create_contacts(current_user),
csrf_form=csrf_form,
) )

View file

@ -1,9 +0,0 @@
from app.dashboard.base import dashboard_bp
from flask_login import login_required, current_user
from app.alias_utils import alias_export_csv
@dashboard_bp.route("/alias_export", methods=["GET"])
@login_required
def alias_export_route():
return alias_export_csv(current_user)

View file

@ -87,6 +87,6 @@ def get_alias_log(alias: Alias, page_id=0) -> [AliasLog]:
contact=contact, contact=contact,
) )
logs.append(al) logs.append(al)
logs = sorted(logs, key=lambda log: log.when, reverse=True) logs = sorted(logs, key=lambda l: l.when, reverse=True)
return logs return logs

View file

@ -7,17 +7,76 @@ from flask import render_template, redirect, url_for, flash, request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app import config from app import config
from app.alias_utils import transfer_alias
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.email_utils import send_email, render
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
Contact,
AliasUsedOn,
AliasMailbox,
User,
ClientUser,
) )
from app.models import Mailbox from app.models import Mailbox
from app.utils import CSRFValidationForm
def transfer(alias, new_user, new_mailboxes: [Mailbox]):
# cannot transfer alias which is used for receiving newsletter
if User.get_by(newsletter_alias_id=alias.id):
raise Exception("Cannot transfer alias that's used to receive newsletter")
# update user_id
Session.query(Contact).filter(Contact.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update(
{"user_id": new_user.id}
)
# remove existing mailboxes from the alias
Session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete()
# set mailboxes
alias.mailbox_id = new_mailboxes.pop().id
for mb in new_mailboxes:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id)
# alias has never been transferred before
if not alias.original_owner_id:
alias.original_owner_id = alias.user_id
# inform previous owner
old_user = alias.user
send_email(
old_user.email,
f"Alias {alias.email} has been received",
render(
"transactional/alias-transferred.txt",
alias=alias,
),
render(
"transactional/alias-transferred.html",
alias=alias,
),
)
# now the alias belongs to the new user
alias.user_id = new_user.id
# set some fields back to default
alias.disable_pgp = False
alias.pinned = False
Session.commit()
def hmac_alias_transfer_token(transfer_token: str) -> str: def hmac_alias_transfer_token(transfer_token: str) -> str:
@ -46,12 +105,8 @@ def alias_transfer_send_route(alias_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
alias_transfer_url = None alias_transfer_url = None
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
# generate a new transfer_token # generate a new transfer_token
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}" transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}"
@ -78,7 +133,6 @@ def alias_transfer_send_route(alias_id):
alias_transfer_url=alias_transfer_url, alias_transfer_url=alias_transfer_url,
link_active=alias.transfer_token_expiration is not None link_active=alias.transfer_token_expiration is not None
and alias.transfer_token_expiration > arrow.utcnow(), and alias.transfer_token_expiration > arrow.utcnow(),
csrf_form=csrf_form,
) )
@ -154,13 +208,7 @@ def alias_transfer_receive_route():
mailboxes, mailboxes,
token, token,
) )
transfer_alias(alias, current_user, mailboxes) transfer(alias, current_user, mailboxes)
# reset transfer token
alias.transfer_token = None
alias.transfer_token_expiration = None
Session.commit()
flash(f"You are now owner of {alias.email}", "success") flash(f"You are now owner of {alias.email}", "success")
return redirect(url_for("dashboard.index", highlight_alias_id=alias.id)) return redirect(url_for("dashboard.index", highlight_alias_id=alias.id))

View file

@ -3,47 +3,19 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app import config
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.extensions import limiter
from app.models import ApiKey from app.models import ApiKey
from app.utils import CSRFValidationForm
class NewApiKeyForm(FlaskForm): class NewApiKeyForm(FlaskForm):
name = StringField("Name", validators=[validators.DataRequired()]) name = StringField("Name", validators=[validators.DataRequired()])
def clean_up_unused_or_old_api_keys(user_id: int):
total_keys = ApiKey.filter_by(user_id=user_id).count()
if total_keys <= config.MAX_API_KEYS:
return
# Remove oldest unused
for api_key in (
ApiKey.filter_by(user_id=user_id, last_used=None)
.order_by(ApiKey.created_at.asc())
.all()
):
Session.delete(api_key)
total_keys -= 1
if total_keys <= config.MAX_API_KEYS:
return
# Clean up oldest used
for api_key in (
ApiKey.filter_by(user_id=user_id).order_by(ApiKey.last_used.asc()).all()
):
Session.delete(api_key)
total_keys -= 1
if total_keys <= config.MAX_API_KEYS:
return
@dashboard_bp.route("/api_key", methods=["GET", "POST"]) @dashboard_bp.route("/api_key", methods=["GET", "POST"])
@login_required @login_required
@sudo_required @sudo_required
@limiter.limit("10/hour")
def api_key(): def api_key():
api_keys = ( api_keys = (
ApiKey.filter(ApiKey.user_id == current_user.id) ApiKey.filter(ApiKey.user_id == current_user.id)
@ -51,13 +23,9 @@ def api_key():
.all() .all()
) )
csrf_form = CSRFValidationForm()
new_api_key_form = NewApiKeyForm() new_api_key_form = NewApiKeyForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "delete": if request.form.get("form-name") == "delete":
api_key_id = request.form.get("api-key-id") api_key_id = request.form.get("api-key-id")
@ -77,7 +45,6 @@ def api_key():
elif request.form.get("form-name") == "create": elif request.form.get("form-name") == "create":
if new_api_key_form.validate(): if new_api_key_form.validate():
clean_up_unused_or_old_api_keys(current_user.id)
new_api_key = ApiKey.create( new_api_key = ApiKey.create(
name=new_api_key_form.name.data, user_id=current_user.id name=new_api_key_form.name.data, user_id=current_user.id
) )
@ -95,8 +62,5 @@ def api_key():
return redirect(url_for("dashboard.api_key")) return redirect(url_for("dashboard.api_key"))
return render_template( return render_template(
"dashboard/api_key.html", "dashboard/api_key.html", api_keys=api_keys, new_api_key_form=new_api_key_form
api_keys=api_keys,
new_api_key_form=new_api_key_form,
csrf_form=csrf_form,
) )

View file

@ -1,9 +1,14 @@
from app.db import Session
"""
List of apps that user has used via the "Sign in with SimpleLogin"
"""
from flask import render_template, request, flash, redirect from flask import render_template, request, flash, redirect
from flask_login import login_required, current_user from flask_login import login_required, current_user
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.models import ( from app.models import (
ClientUser, ClientUser,
) )
@ -12,10 +17,6 @@ from app.models import (
@dashboard_bp.route("/app", methods=["GET", "POST"]) @dashboard_bp.route("/app", methods=["GET", "POST"])
@login_required @login_required
def app_route(): def app_route():
"""
List of apps that user has used via the "Sign in with SimpleLogin"
"""
client_users = ( client_users = (
ClientUser.filter_by(user_id=current_user.id) ClientUser.filter_by(user_id=current_user.id)
.options(joinedload(ClientUser.client)) .options(joinedload(ClientUser.client))

View file

@ -8,7 +8,7 @@ from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import File, BatchImport, Job from app.models import File, BatchImport, Job
from app.utils import random_string, CSRFValidationForm from app.utils import random_string
@dashboard_bp.route("/batch_import", methods=["GET", "POST"]) @dashboard_bp.route("/batch_import", methods=["GET", "POST"])
@ -25,25 +25,16 @@ def batch_import_route():
) )
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
batch_imports = BatchImport.filter_by( batch_imports = BatchImport.filter_by(user_id=current_user.id).all()
user_id=current_user.id, processed=False
).all()
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if len(batch_imports) > 10: if len(batch_imports) > 10:
flash( flash(
"You have too many imports already. Wait until some get cleaned up", "You have too many imports already. Wait until some get cleaned up",
"error", "error",
) )
return render_template( return render_template(
"dashboard/batch_import.html", "dashboard/batch_import.html", batch_imports=batch_imports
batch_imports=batch_imports,
csrf_form=csrf_form,
) )
alias_file = request.files["alias-file"] alias_file = request.files["alias-file"]
@ -73,6 +64,4 @@ def batch_import_route():
return redirect(url_for("dashboard.batch_import_route")) return redirect(url_for("dashboard.batch_import_route"))
return render_template( return render_template("dashboard/batch_import.html", batch_imports=batch_imports)
"dashboard/batch_import.html", batch_imports=batch_imports, csrf_form=csrf_form
)

View file

@ -1,7 +1,5 @@
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from wtforms import StringField, validators
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -9,14 +7,6 @@ from app.models import Contact
from app.pgp_utils import PGPException, load_public_key_and_check from app.pgp_utils import PGPException, load_public_key_and_check
class PGPContactForm(FlaskForm):
action = StringField(
"action",
validators=[validators.DataRequired(), validators.AnyOf(("save", "remove"))],
)
pgp = StringField("pgp", validators=[validators.Optional()])
@dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"])
@login_required @login_required
def contact_detail_route(contact_id): def contact_detail_route(contact_id):
@ -26,41 +16,33 @@ def contact_detail_route(contact_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
alias = contact.alias alias = contact.alias
pgp_form = PGPContactForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "pgp": if request.form.get("form-name") == "pgp":
if not pgp_form.validate(): if request.form.get("action") == "save":
flash("Invalid request", "warning")
return redirect(request.url)
if pgp_form.action.data == "save":
if not current_user.is_premium(): if not current_user.is_premium():
flash("Only premium plan can add PGP Key", "warning") flash("Only premium plan can add PGP Key", "warning")
return redirect( return redirect(
url_for("dashboard.contact_detail_route", contact_id=contact_id) url_for("dashboard.contact_detail_route", contact_id=contact_id)
) )
if not pgp_form.pgp.data:
flash("Invalid pgp key") contact.pgp_public_key = request.form.get("pgp")
try:
contact.pgp_finger_print = load_public_key_and_check(
contact.pgp_public_key
)
except PGPException:
flash("Cannot add the public key, please verify it", "error")
else: else:
contact.pgp_public_key = pgp_form.pgp.data Session.commit()
try: flash(
contact.pgp_finger_print = load_public_key_and_check( f"PGP public key for {contact.email} is saved successfully",
contact.pgp_public_key "success",
) )
except PGPException: return redirect(
flash("Cannot add the public key, please verify it", "error") url_for("dashboard.contact_detail_route", contact_id=contact_id)
else: )
Session.commit() elif request.form.get("action") == "remove":
flash(
f"PGP public key for {contact.email} is saved successfully",
"success",
)
return redirect(
url_for(
"dashboard.contact_detail_route", contact_id=contact_id
)
)
elif pgp_form.action.data == "remove":
# Free user can decide to remove contact PGP key # Free user can decide to remove contact PGP key
contact.pgp_public_key = None contact.pgp_public_key = None
contact.pgp_finger_print = None contact.pgp_finger_print = None
@ -71,5 +53,5 @@ def contact_detail_route(contact_id):
) )
return render_template( return render_template(
"dashboard/contact_detail.html", contact=contact, alias=alias, pgp_form=pgp_form "dashboard/contact_detail.html", contact=contact, alias=alias
) )

View file

@ -4,7 +4,6 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app import parallel_limiter
from app.config import PADDLE_VENDOR_ID, PADDLE_COUPON_ID from app.config import PADDLE_VENDOR_ID, PADDLE_COUPON_ID
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -25,7 +24,6 @@ class CouponForm(FlaskForm):
@dashboard_bp.route("/coupon", methods=["GET", "POST"]) @dashboard_bp.route("/coupon", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock()
def coupon_route(): def coupon_route():
coupon_form = CouponForm() coupon_form = CouponForm()
@ -68,14 +66,9 @@ def coupon_route():
) )
return redirect(request.url) return redirect(request.url)
updated = ( coupon.used_by_user_id = current_user.id
Session.query(Coupon) coupon.used = True
.filter_by(code=code, used=False) Session.commit()
.update({"used_by_user_id": current_user.id, "used": True})
)
if updated != 1:
flash("Coupon is not valid", "error")
return redirect(request.url)
manual_sub: ManualSubscription = ManualSubscription.get_by( manual_sub: ManualSubscription = ManualSubscription.get_by(
user_id=current_user.id user_id=current_user.id
@ -100,7 +93,7 @@ def coupon_route():
commit=True, commit=True,
) )
flash( flash(
"Your account has been upgraded to Premium, thanks for your support!", f"Your account has been upgraded to Premium, thanks for your support!",
"success", "success",
) )

View file

@ -3,7 +3,6 @@ from flask import render_template, redirect, url_for, flash, request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app import parallel_limiter
from app.alias_suffix import ( from app.alias_suffix import (
get_alias_suffixes, get_alias_suffixes,
check_suffix_signature, check_suffix_signature,
@ -29,7 +28,6 @@ from app.models import (
@dashboard_bp.route("/custom_alias", methods=["GET", "POST"]) @dashboard_bp.route("/custom_alias", methods=["GET", "POST"])
@limiter.limit(ALIAS_LIMIT, methods=["POST"]) @limiter.limit(ALIAS_LIMIT, methods=["POST"])
@login_required @login_required
@parallel_limiter.lock(name="alias_creation")
def custom_alias(): def custom_alias():
# check if user has not exceeded the alias quota # check if user has not exceeded the alias quota
if not current_user.can_create_new_alias(): if not current_user.can_create_new_alias():
@ -120,11 +118,18 @@ def custom_alias():
email=full_alias email=full_alias
) )
custom_domain = domain_deleted_alias.domain custom_domain = domain_deleted_alias.domain
flash( if domain_deleted_alias.user_id == current_user.id:
f"You have deleted this alias before. You can restore it on " flash(
f"{custom_domain.domain} 'Deleted Alias' page", f"You have deleted this alias before. You can restore it on "
"error", f"{custom_domain.domain} 'Deleted Alias' page",
) "error",
)
else:
# should never happen as user can only choose their domains
LOG.e(
"Deleted Alias %s does not belong to user %s",
domain_deleted_alias,
)
elif DeletedAlias.get_by(email=full_alias): elif DeletedAlias.get_by(email=full_alias):
flash(general_error_msg, "error") flash(general_error_msg, "error")

View file

@ -3,7 +3,6 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app import parallel_limiter
from app.config import EMAIL_SERVERS_WITH_PRIORITY from app.config import EMAIL_SERVERS_WITH_PRIORITY
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -20,7 +19,6 @@ class NewCustomDomainForm(FlaskForm):
@dashboard_bp.route("/custom_domain", methods=["GET", "POST"]) @dashboard_bp.route("/custom_domain", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def custom_domain(): def custom_domain():
custom_domains = CustomDomain.filter_by( custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False user_id=current_user.id, is_sl_subdomain=False

View file

@ -1,7 +1,6 @@
import arrow import arrow
from flask import flash, redirect, url_for, request, render_template from flask import flash, redirect, url_for, request, render_template
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from app.config import JOB_DELETE_ACCOUNT from app.config import JOB_DELETE_ACCOUNT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
@ -10,21 +9,11 @@ from app.log import LOG
from app.models import Subscription, Job from app.models import Subscription, Job
class DeleteDirForm(FlaskForm):
pass
@dashboard_bp.route("/delete_account", methods=["GET", "POST"]) @dashboard_bp.route("/delete_account", methods=["GET", "POST"])
@login_required @login_required
@sudo_required @sudo_required
def delete_account(): def delete_account():
delete_form = DeleteDirForm()
if request.method == "POST" and request.form.get("form-name") == "delete-account": if request.method == "POST" and request.form.get("form-name") == "delete-account":
if not delete_form.validate():
flash("Invalid request", "warning")
return render_template(
"dashboard/delete_account.html", delete_form=delete_form
)
sub: Subscription = current_user.get_paddle_subscription() sub: Subscription = current_user.get_paddle_subscription()
# user who has canceled can also re-subscribe # user who has canceled can also re-subscribe
if sub and not sub.cancelled: if sub and not sub.cancelled:
@ -47,4 +36,6 @@ def delete_account():
) )
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
return render_template("dashboard/delete_account.html", delete_form=delete_form) return render_template(
"dashboard/delete_account.html",
)

View file

@ -1,15 +1,8 @@
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import ( from wtforms import StringField, validators
StringField,
validators,
SelectMultipleField,
BooleanField,
IntegerField,
)
from app import parallel_limiter
from app.config import ( from app.config import (
EMAIL_DOMAIN, EMAIL_DOMAIN,
ALIAS_DOMAINS, ALIAS_DOMAINS,
@ -28,25 +21,8 @@ class NewDirForm(FlaskForm):
) )
class ToggleDirForm(FlaskForm):
directory_id = IntegerField(validators=[validators.DataRequired()])
directory_enabled = BooleanField(validators=[])
class UpdateDirForm(FlaskForm):
directory_id = IntegerField(validators=[validators.DataRequired()])
mailbox_ids = SelectMultipleField(
validators=[validators.DataRequired()], validate_choice=False, choices=[]
)
class DeleteDirForm(FlaskForm):
directory_id = IntegerField(validators=[validators.DataRequired()])
@dashboard_bp.route("/directory", methods=["GET", "POST"]) @dashboard_bp.route("/directory", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def directory(): def directory():
dirs = ( dirs = (
Directory.filter_by(user_id=current_user.id) Directory.filter_by(user_id=current_user.id)
@ -57,68 +33,54 @@ def directory():
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
new_dir_form = NewDirForm() new_dir_form = NewDirForm()
toggle_dir_form = ToggleDirForm()
update_dir_form = UpdateDirForm()
update_dir_form.mailbox_ids.choices = [
(str(mailbox.id), str(mailbox.id)) for mailbox in mailboxes
]
delete_dir_form = DeleteDirForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "delete": if request.form.get("form-name") == "delete":
if not delete_dir_form.validate(): dir_id = request.form.get("dir-id")
flash("Invalid request", "warning") dir = Directory.get(dir_id)
return redirect(url_for("dashboard.directory"))
dir_obj = Directory.get(delete_dir_form.directory_id.data)
if not dir_obj: if not dir:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
elif dir_obj.user_id != current_user.id: elif dir.user_id != current_user.id:
flash("You cannot delete this directory", "warning") flash("You cannot delete this directory", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
name = dir_obj.name name = dir.name
Directory.delete(dir_obj.id) Directory.delete(dir_id)
Session.commit() Session.commit()
flash(f"Directory {name} has been deleted", "success") flash(f"Directory {name} has been deleted", "success")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
if request.form.get("form-name") == "toggle-directory": if request.form.get("form-name") == "toggle-directory":
if not toggle_dir_form.validate(): dir_id = request.form.get("dir-id")
flash("Invalid request", "warning") dir = Directory.get(dir_id)
return redirect(url_for("dashboard.directory"))
dir_id = toggle_dir_form.directory_id.data
dir_obj = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir or dir.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
if toggle_dir_form.directory_enabled.data: if request.form.get("dir-status") == "on":
dir_obj.disabled = False dir.disabled = False
flash(f"On-the-fly is enabled for {dir_obj.name}", "success") flash(f"On-the-fly is enabled for {dir.name}", "success")
else: else:
dir_obj.disabled = True dir.disabled = True
flash(f"On-the-fly is disabled for {dir_obj.name}", "warning") flash(f"On-the-fly is disabled for {dir.name}", "warning")
Session.commit() Session.commit()
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
elif request.form.get("form-name") == "update": elif request.form.get("form-name") == "update":
if not update_dir_form.validate(): dir_id = request.form.get("dir-id")
flash("Invalid request", "warning") dir = Directory.get(dir_id)
return redirect(url_for("dashboard.directory"))
dir_id = update_dir_form.directory_id.data
dir_obj = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir or dir.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
mailbox_ids = update_dir_form.mailbox_ids.data mailbox_ids = request.form.getlist("mailbox_ids")
# check if mailbox is not tempered with # check if mailbox is not tempered with
mailboxes = [] mailboxes = []
for mailbox_id in mailbox_ids: for mailbox_id in mailbox_ids:
@ -137,14 +99,14 @@ def directory():
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
# first remove all existing directory-mailboxes links # first remove all existing directory-mailboxes links
DirectoryMailbox.filter_by(directory_id=dir_obj.id).delete() DirectoryMailbox.filter_by(directory_id=dir.id).delete()
Session.flush() Session.flush()
for mailbox in mailboxes: for mailbox in mailboxes:
DirectoryMailbox.create(directory_id=dir_obj.id, mailbox_id=mailbox.id) DirectoryMailbox.create(directory_id=dir.id, mailbox_id=mailbox.id)
Session.commit() Session.commit()
flash(f"Directory {dir_obj.name} has been updated", "success") flash(f"Directory {dir.name} has been updated", "success")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
elif request.form.get("form-name") == "create": elif request.form.get("form-name") == "create":
@ -219,9 +181,6 @@ def directory():
return render_template( return render_template(
"dashboard/directory.html", "dashboard/directory.html",
dirs=dirs, dirs=dirs,
toggle_dir_form=toggle_dir_form,
update_dir_form=update_dir_form,
delete_dir_form=delete_dir_form,
new_dir_form=new_dir_form, new_dir_form=new_dir_form,
mailboxes=mailboxes, mailboxes=mailboxes,
EMAIL_DOMAIN=EMAIL_DOMAIN, EMAIL_DOMAIN=EMAIL_DOMAIN,

View file

@ -7,13 +7,13 @@ from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField from wtforms import StringField, validators, IntegerField
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN
from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.dns_utils import ( from app.dns_utils import (
get_mx_domains, get_mx_domains,
get_spf_domain, get_spf_domain,
get_txt_record, get_txt_record,
get_cname_record,
is_mx_equivalent, is_mx_equivalent,
) )
from app.log import LOG from app.log import LOG
@ -28,7 +28,7 @@ from app.models import (
Job, Job,
) )
from app.regex_utils import regex_match from app.regex_utils import regex_match
from app.utils import random_string, CSRFValidationForm from app.utils import random_string
@dashboard_bp.route("/domains/<int:custom_domain_id>/dns", methods=["GET", "POST"]) @dashboard_bp.route("/domains/<int:custom_domain_id>/dns", methods=["GET", "POST"])
@ -46,8 +46,8 @@ def domain_detail_dns(custom_domain_id):
spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all" spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"
domain_validator = CustomDomainValidation(EMAIL_DOMAIN) # hardcode the DKIM selector here
csrf_form = CSRFValidationForm() dkim_cname = f"dkim._domainkey.{EMAIL_DOMAIN}"
dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s" dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
@ -55,9 +55,6 @@ def domain_detail_dns(custom_domain_id):
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = [] mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "check-ownership": if request.form.get("form-name") == "check-ownership":
txt_records = get_txt_record(custom_domain.domain) txt_records = get_txt_record(custom_domain.domain)
@ -125,17 +122,23 @@ def domain_detail_dns(custom_domain_id):
spf_errors = get_txt_record(custom_domain.domain) spf_errors = get_txt_record(custom_domain.domain)
elif request.form.get("form-name") == "check-dkim": elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain) dkim_record = get_cname_record("dkim._domainkey." + custom_domain.domain)
if len(dkim_errors) == 0: if dkim_record == dkim_cname:
flash("DKIM is setup correctly.", "success") flash("DKIM is setup correctly.", "success")
custom_domain.dkim_verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
) )
) )
else: else:
dkim_ok = False custom_domain.dkim_verified = False
Session.commit()
flash("DKIM: the CNAME record is not correctly set", "warning") flash("DKIM: the CNAME record is not correctly set", "warning")
dkim_ok = False
dkim_errors = [dkim_record or "[Empty]"]
elif request.form.get("form-name") == "check-dmarc": elif request.form.get("form-name") == "check-dmarc":
txt_records = get_txt_record("_dmarc." + custom_domain.domain) txt_records = get_txt_record("_dmarc." + custom_domain.domain)
@ -161,7 +164,6 @@ def domain_detail_dns(custom_domain_id):
return render_template( return render_template(
"dashboard/domain_detail/dns.html", "dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(),
**locals(), **locals(),
) )
@ -169,7 +171,6 @@ def domain_detail_dns(custom_domain_id):
@dashboard_bp.route("/domains/<int:custom_domain_id>/info", methods=["GET", "POST"]) @dashboard_bp.route("/domains/<int:custom_domain_id>/info", methods=["GET", "POST"])
@login_required @login_required
def domain_detail(custom_domain_id): def domain_detail(custom_domain_id):
csrf_form = CSRFValidationForm()
custom_domain: CustomDomain = CustomDomain.get(custom_domain_id) custom_domain: CustomDomain = CustomDomain.get(custom_domain_id)
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
@ -178,9 +179,6 @@ def domain_detail(custom_domain_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "switch-catch-all": if request.form.get("form-name") == "switch-catch-all":
custom_domain.catch_all = not custom_domain.catch_all custom_domain.catch_all = not custom_domain.catch_all
Session.commit() Session.commit()
@ -309,16 +307,12 @@ def domain_detail(custom_domain_id):
@dashboard_bp.route("/domains/<int:custom_domain_id>/trash", methods=["GET", "POST"]) @dashboard_bp.route("/domains/<int:custom_domain_id>/trash", methods=["GET", "POST"])
@login_required @login_required
def domain_detail_trash(custom_domain_id): def domain_detail_trash(custom_domain_id):
csrf_form = CSRFValidationForm()
custom_domain = CustomDomain.get(custom_domain_id) custom_domain = CustomDomain.get(custom_domain_id)
if not custom_domain or custom_domain.user_id != current_user.id: if not custom_domain or custom_domain.user_id != current_user.id:
flash("You cannot see this page", "warning") flash("You cannot see this page", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "empty-all": if request.form.get("form-name") == "empty-all":
DomainDeletedAlias.filter_by(domain_id=custom_domain.id).delete() DomainDeletedAlias.filter_by(domain_id=custom_domain.id).delete()
Session.commit() Session.commit()
@ -362,7 +356,6 @@ def domain_detail_trash(custom_domain_id):
"dashboard/domain_detail/trash.html", "dashboard/domain_detail/trash.html",
domain_deleted_aliases=domain_deleted_aliases, domain_deleted_aliases=domain_deleted_aliases,
custom_domain=custom_domain, custom_domain=custom_domain,
csrf_form=csrf_form,
) )

View file

@ -8,7 +8,6 @@ from wtforms import PasswordField, validators
from app.config import CONNECT_WITH_PROTON from app.config import CONNECT_WITH_PROTON
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import PartnerUser from app.models import PartnerUser
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
@ -22,7 +21,6 @@ class LoginForm(FlaskForm):
@dashboard_bp.route("/enter_sudo", methods=["GET", "POST"]) @dashboard_bp.route("/enter_sudo", methods=["GET", "POST"])
@limiter.limit("3/minute")
@login_required @login_required
def enter_sudo(): def enter_sudo():
password_check_form = LoginForm() password_check_form = LoginForm()

View file

@ -78,10 +78,10 @@ def fido_setup():
) )
flash("Security key has been activated", "success") flash("Security key has been activated", "success")
recovery_codes = RecoveryCode.generate(current_user) if not RecoveryCode.filter_by(user_id=current_user.id).all():
return render_template( return redirect(url_for("dashboard.recovery_code_route"))
"dashboard/recovery_code.html", recovery_codes=recovery_codes else:
) return redirect(url_for("dashboard.fido_manage"))
# Prepare information for key registration process # Prepare information for key registration process
fido_uuid = ( fido_uuid = (

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app import alias_utils, parallel_limiter from app import alias_utils
from app.api.serializer import get_alias_infos_with_pagination_v3, get_alias_info_v3 from app.api.serializer import get_alias_infos_with_pagination_v3, get_alias_info_v3
from app.config import ALIAS_LIMIT, PAGE_LIMIT from app.config import ALIAS_LIMIT, PAGE_LIMIT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
@ -17,7 +17,6 @@ from app.models import (
EmailLog, EmailLog,
Contact, Contact,
) )
from app.utils import CSRFValidationForm
@dataclass @dataclass
@ -57,15 +56,7 @@ def get_stats(user: User) -> Stats:
methods=["POST"], methods=["POST"],
exempt_when=lambda: request.form.get("form-name") != "create-random-email", exempt_when=lambda: request.form.get("form-name") != "create-random-email",
) )
@limiter.limit(
"5/minute",
methods=["GET"],
)
@login_required @login_required
@parallel_limiter.lock(
name="alias_creation",
only_when=lambda: request.form.get("form-name") == "create-random-email",
)
def index(): def index():
query = request.args.get("query") or "" query = request.args.get("query") or ""
sort = request.args.get("sort") or "" sort = request.args.get("sort") or ""
@ -84,12 +75,8 @@ def index():
"highlight_alias_id must be a number, received %s", "highlight_alias_id must be a number, received %s",
request.args.get("highlight_alias_id"), request.args.get("highlight_alias_id"),
) )
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "create-custom-email": if request.form.get("form-name") == "create-custom-email":
if current_user.can_create_new_alias(): if current_user.can_create_new_alias():
return redirect(url_for("dashboard.custom_alias")) return redirect(url_for("dashboard.custom_alias"))
@ -154,13 +141,7 @@ def index():
flash(f"Alias {alias.email} has been disabled", "success") flash(f"Alias {alias.email} has been disabled", "success")
return redirect( return redirect(
url_for( url_for("dashboard.index", query=query, sort=sort, filter=alias_filter)
"dashboard.index",
query=query,
sort=sort,
filter=alias_filter,
page=page,
)
) )
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
@ -223,7 +204,6 @@ def index():
sort=sort, sort=sort,
filter=alias_filter, filter=alias_filter,
stats=stats, stats=stats,
csrf_form=csrf_form,
) )

View file

@ -1,16 +1,11 @@
import base64
import binascii
import json
import arrow import arrow
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from itsdangerous import TimestampSigner from itsdangerous import Signer
from wtforms import validators, IntegerField from wtforms import validators
from wtforms.fields.html5 import EmailField from wtforms.fields.html5 import EmailField
from app import parallel_limiter
from app.config import MAILBOX_SECRET, URL, JOB_DELETE_MAILBOX from app.config import MAILBOX_SECRET, URL, JOB_DELETE_MAILBOX
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -19,11 +14,10 @@ from app.email_utils import (
mailbox_already_used, mailbox_already_used,
render, render,
send_email, send_email,
is_valid_email,
) )
from app.email_validation import is_valid_email
from app.log import LOG from app.log import LOG
from app.models import Mailbox, Job from app.models import Mailbox, Job
from app.utils import CSRFValidationForm
class NewMailboxForm(FlaskForm): class NewMailboxForm(FlaskForm):
@ -32,16 +26,8 @@ class NewMailboxForm(FlaskForm):
) )
class DeleteMailboxForm(FlaskForm):
mailbox_id = IntegerField(
validators=[validators.DataRequired()],
)
transfer_mailbox_id = IntegerField()
@dashboard_bp.route("/mailbox", methods=["GET", "POST"]) @dashboard_bp.route("/mailbox", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def mailbox_route(): def mailbox_route():
mailboxes = ( mailboxes = (
Mailbox.filter_by(user_id=current_user.id) Mailbox.filter_by(user_id=current_user.id)
@ -50,57 +36,25 @@ def mailbox_route():
) )
new_mailbox_form = NewMailboxForm() new_mailbox_form = NewMailboxForm()
csrf_form = CSRFValidationForm()
delete_mailbox_form = DeleteMailboxForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "delete": if request.form.get("form-name") == "delete":
if not delete_mailbox_form.validate(): mailbox_id = request.form.get("mailbox-id")
flash("Invalid request", "warning") mailbox = Mailbox.get(mailbox_id)
return redirect(request.url)
mailbox = Mailbox.get(delete_mailbox_form.mailbox_id.data)
if not mailbox or mailbox.user_id != current_user.id: if not mailbox or mailbox.user_id != current_user.id:
flash("Invalid mailbox. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
if mailbox.id == current_user.default_mailbox_id: if mailbox.id == current_user.default_mailbox_id:
flash("You cannot delete default mailbox", "error") flash("You cannot delete default mailbox", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
transfer_mailbox_id = delete_mailbox_form.transfer_mailbox_id.data
if transfer_mailbox_id and transfer_mailbox_id > 0:
transfer_mailbox = Mailbox.get(transfer_mailbox_id)
if not transfer_mailbox or transfer_mailbox.user_id != current_user.id:
flash(
"You must transfer the aliases to a mailbox you own.", "error"
)
return redirect(url_for("dashboard.mailbox_route"))
if transfer_mailbox.id == mailbox.id:
flash(
"You can not transfer the aliases to the mailbox you want to delete.",
"error",
)
return redirect(url_for("dashboard.mailbox_route"))
if not transfer_mailbox.verified:
flash("Your new mailbox is not verified", "error")
return redirect(url_for("dashboard.mailbox_route"))
# Schedule delete account job # Schedule delete account job
LOG.w( LOG.w("schedule delete mailbox job for %s", mailbox)
f"schedule delete mailbox job for {mailbox.id} with transfer to mailbox {transfer_mailbox_id}"
)
Job.create( Job.create(
name=JOB_DELETE_MAILBOX, name=JOB_DELETE_MAILBOX,
payload={ payload={"mailbox_id": mailbox.id},
"mailbox_id": mailbox.id,
"transfer_mailbox_id": transfer_mailbox_id
if transfer_mailbox_id > 0
else None,
},
run_at=arrow.now(), run_at=arrow.now(),
commit=True, commit=True,
) )
@ -113,10 +67,7 @@ def mailbox_route():
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
if request.form.get("form-name") == "set-default": if request.form.get("form-name") == "set-default":
if not csrf_form.validate(): mailbox_id = request.form.get("mailbox-id")
flash("Invalid request", "warning")
return redirect(request.url)
mailbox_id = request.form.get("mailbox_id")
mailbox = Mailbox.get(mailbox_id) mailbox = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != current_user.id: if not mailbox or mailbox.user_id != current_user.id:
@ -168,8 +119,7 @@ def mailbox_route():
return redirect( return redirect(
url_for( url_for(
"dashboard.mailbox_detail_route", "dashboard.mailbox_detail_route", mailbox_id=new_mailbox.id
mailbox_id=new_mailbox.id,
) )
) )
@ -177,16 +127,38 @@ def mailbox_route():
"dashboard/mailbox.html", "dashboard/mailbox.html",
mailboxes=mailboxes, mailboxes=mailboxes,
new_mailbox_form=new_mailbox_form, new_mailbox_form=new_mailbox_form,
delete_mailbox_form=delete_mailbox_form,
csrf_form=csrf_form,
) )
def delete_mailbox(mailbox_id: int):
from server import create_light_app
with create_light_app().app_context():
mailbox = Mailbox.get(mailbox_id)
if not mailbox:
return
mailbox_email = mailbox.email
user = mailbox.user
Mailbox.delete(mailbox_id)
Session.commit()
LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email)
send_email(
user.email,
f"Your mailbox {mailbox_email} has been deleted",
f"""Mailbox {mailbox_email} along with its aliases are deleted successfully.
Regards,
SimpleLogin team.
""",
)
def send_verification_email(user, mailbox): def send_verification_email(user, mailbox):
s = TimestampSigner(MAILBOX_SECRET) s = Signer(MAILBOX_SECRET)
encoded_data = json.dumps([mailbox.id, mailbox.email]).encode("utf-8") mailbox_id_signed = s.sign(str(mailbox.id)).decode()
b64_data = base64.urlsafe_b64encode(encoded_data)
mailbox_id_signed = s.sign(b64_data).decode()
verification_url = ( verification_url = (
URL + "/dashboard/mailbox_verify" + f"?mailbox_id={mailbox_id_signed}" URL + "/dashboard/mailbox_verify" + f"?mailbox_id={mailbox_id_signed}"
) )
@ -210,35 +182,23 @@ def send_verification_email(user, mailbox):
@dashboard_bp.route("/mailbox_verify") @dashboard_bp.route("/mailbox_verify")
def mailbox_verify(): def mailbox_verify():
s = TimestampSigner(MAILBOX_SECRET) s = Signer(MAILBOX_SECRET)
mailbox_verify_request = request.args.get("mailbox_id") mailbox_id = request.args.get("mailbox_id")
try: try:
mailbox_raw_data = s.unsign(mailbox_verify_request, max_age=900) r_id = int(s.unsign(mailbox_id))
except Exception: except Exception:
flash("Invalid link. Please delete and re-add your mailbox", "error") flash("Invalid link. Please delete and re-add your mailbox", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
try: else:
decoded_data = base64.urlsafe_b64decode(mailbox_raw_data) mailbox = Mailbox.get(r_id)
except binascii.Error: if not mailbox:
flash("Invalid link. Please delete and re-add your mailbox", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
mailbox_data = json.loads(decoded_data)
if not isinstance(mailbox_data, list) or len(mailbox_data) != 2:
flash("Invalid link. Please delete and re-add your mailbox", "error")
return redirect(url_for("dashboard.mailbox_route"))
mailbox_id = mailbox_data[0]
mailbox = Mailbox.get(mailbox_id)
if not mailbox:
flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route"))
mailbox_email = mailbox_data[1]
if mailbox_email != mailbox.email:
flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route"))
mailbox.verified = True mailbox.verified = True
Session.commit() Session.commit()
LOG.d("Mailbox %s is verified", mailbox) LOG.d("Mailbox %s is verified", mailbox)
return render_template("dashboard/mailbox_validation.html", mailbox=mailbox) return render_template("dashboard/mailbox_validation.html", mailbox=mailbox)

View file

@ -1,10 +1,9 @@
from smtplib import SMTPRecipientsRefused from smtplib import SMTPRecipientsRefused
from email_validator import validate_email, EmailNotValidError
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from itsdangerous import TimestampSigner from itsdangerous import Signer
from wtforms import validators from wtforms import validators
from wtforms.fields.html5 import EmailField from wtforms.fields.html5 import EmailField
@ -18,7 +17,7 @@ from app.log import LOG
from app.models import Alias, AuthorizedAddress from app.models import Alias, AuthorizedAddress
from app.models import Mailbox from app.models import Mailbox
from app.pgp_utils import PGPException, load_public_key_and_check from app.pgp_utils import PGPException, load_public_key_and_check
from app.utils import sanitize_email, CSRFValidationForm from app.utils import sanitize_email
class ChangeEmailForm(FlaskForm): class ChangeEmailForm(FlaskForm):
@ -30,13 +29,12 @@ class ChangeEmailForm(FlaskForm):
@dashboard_bp.route("/mailbox/<int:mailbox_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/mailbox/<int:mailbox_id>/", methods=["GET", "POST"])
@login_required @login_required
def mailbox_detail_route(mailbox_id): def mailbox_detail_route(mailbox_id):
mailbox: Mailbox = Mailbox.get(mailbox_id) mailbox = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != current_user.id: if not mailbox or mailbox.user_id != current_user.id:
flash("You cannot see this page", "warning") flash("You cannot see this page", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
change_email_form = ChangeEmailForm() change_email_form = ChangeEmailForm()
csrf_form = CSRFValidationForm()
if mailbox.new_email: if mailbox.new_email:
pending_email = mailbox.new_email pending_email = mailbox.new_email
@ -44,9 +42,6 @@ def mailbox_detail_route(mailbox_id):
pending_email = None pending_email = None
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if ( if (
request.form.get("form-name") == "update-email" request.form.get("form-name") == "update-email"
and change_email_form.validate_on_submit() and change_email_form.validate_on_submit()
@ -99,23 +94,16 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("form-name") == "add-authorized-address": elif request.form.get("form-name") == "add-authorized-address":
address = sanitize_email(request.form.get("email")) address = sanitize_email(request.form.get("email"))
try: if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address):
validate_email( flash(f"{address} already added", "error")
address, check_deliverability=False, allow_smtputf8=False
).domain
except EmailNotValidError:
flash(f"invalid {address}", "error")
else: else:
if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address): AuthorizedAddress.create(
flash(f"{address} already added", "error") user_id=current_user.id,
else: mailbox_id=mailbox.id,
AuthorizedAddress.create( email=address,
user_id=current_user.id, commit=True,
mailbox_id=mailbox.id, )
email=address, flash(f"{address} added as authorized address", "success")
commit=True,
)
flash(f"{address} added as authorized address", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
@ -144,15 +132,6 @@ def mailbox_detail_route(mailbox_id):
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
if mailbox.is_proton():
flash(
"Enabling PGP for a Proton Mail mailbox is redundant and does not add any security benefit",
"info",
)
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)
mailbox.pgp_public_key = request.form.get("pgp") mailbox.pgp_public_key = request.form.get("pgp")
try: try:
mailbox.pgp_finger_print = load_public_key_and_check( mailbox.pgp_finger_print = load_public_key_and_check(
@ -191,16 +170,25 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("form-name") == "generic-subject": elif request.form.get("form-name") == "generic-subject":
if request.form.get("action") == "save": if request.form.get("action") == "save":
if not mailbox.pgp_enabled():
flash(
"Generic subject can only be used on PGP-enabled mailbox",
"error",
)
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)
mailbox.generic_subject = request.form.get("generic-subject") mailbox.generic_subject = request.form.get("generic-subject")
Session.commit() Session.commit()
flash("Generic subject is enabled", "success") flash("Generic subject for PGP-encrypted email is enabled", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
elif request.form.get("action") == "remove": elif request.form.get("action") == "remove":
mailbox.generic_subject = None mailbox.generic_subject = None
Session.commit() Session.commit()
flash("Generic subject is disabled", "success") flash("Generic subject for PGP-encrypted email is disabled", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
@ -210,7 +198,7 @@ def mailbox_detail_route(mailbox_id):
def verify_mailbox_change(user, mailbox, new_email): def verify_mailbox_change(user, mailbox, new_email):
s = TimestampSigner(MAILBOX_SECRET) s = Signer(MAILBOX_SECRET)
mailbox_id_signed = s.sign(str(mailbox.id)).decode() mailbox_id_signed = s.sign(str(mailbox.id)).decode()
verification_url = ( verification_url = (
f"{URL}/dashboard/mailbox/confirm_change?mailbox_id={mailbox_id_signed}" f"{URL}/dashboard/mailbox/confirm_change?mailbox_id={mailbox_id_signed}"
@ -262,11 +250,11 @@ def cancel_mailbox_change_route(mailbox_id):
@dashboard_bp.route("/mailbox/confirm_change") @dashboard_bp.route("/mailbox/confirm_change")
def mailbox_confirm_change_route(): def mailbox_confirm_change_route():
s = TimestampSigner(MAILBOX_SECRET) s = Signer(MAILBOX_SECRET)
signed_mailbox_id = request.args.get("mailbox_id") signed_mailbox_id = request.args.get("mailbox_id")
try: try:
mailbox_id = int(s.unsign(signed_mailbox_id, max_age=900)) mailbox_id = int(s.unsign(signed_mailbox_id))
except Exception: except Exception:
flash("Invalid link", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))

View file

@ -5,7 +5,6 @@ from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.models import RecoveryCode from app.models import RecoveryCode
from app.utils import CSRFValidationForm
@dashboard_bp.route("/mfa_cancel", methods=["GET", "POST"]) @dashboard_bp.route("/mfa_cancel", methods=["GET", "POST"])
@ -16,13 +15,8 @@ def mfa_cancel():
flash("you don't have MFA enabled", "warning") flash("you don't have MFA enabled", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
csrf_form = CSRFValidationForm()
# user cancels TOTP # user cancels TOTP
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
current_user.enable_otp = False current_user.enable_otp = False
current_user.otp_secret = None current_user.otp_secret = None
Session.commit() Session.commit()
@ -34,4 +28,4 @@ def mfa_cancel():
flash("TOTP is now disabled", "warning") flash("TOTP is now disabled", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
return render_template("dashboard/mfa_cancel.html", csrf_form=csrf_form) return render_template("dashboard/mfa_cancel.html")

View file

@ -8,7 +8,6 @@ from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import RecoveryCode
class OtpTokenForm(FlaskForm): class OtpTokenForm(FlaskForm):
@ -40,10 +39,8 @@ def mfa_setup():
current_user.last_otp = token current_user.last_otp = token
Session.commit() Session.commit()
flash("MFA has been activated", "success") flash("MFA has been activated", "success")
recovery_codes = RecoveryCode.generate(current_user)
return render_template( return redirect(url_for("dashboard.recovery_code_route"))
"dashboard/recovery_code.html", recovery_codes=recovery_codes
)
else: else:
flash("Incorrect token", "warning") flash("Incorrect token", "warning")

View file

@ -80,9 +80,8 @@ def pricing():
@dashboard_bp.route("/subscription_success") @dashboard_bp.route("/subscription_success")
@login_required @login_required
def subscription_success(): def subscription_success():
return render_template( flash("Thanks so much for supporting SimpleLogin!", "success")
"dashboard/thank-you.html", return redirect(url_for("dashboard.index"))
)
@dashboard_bp.route("/coinbase_checkout") @dashboard_bp.route("/coinbase_checkout")

View file

@ -0,0 +1,30 @@
from flask import render_template, flash, redirect, url_for, request
from flask_login import login_required, current_user
from app.dashboard.base import dashboard_bp
from app.log import LOG
from app.models import RecoveryCode
@dashboard_bp.route("/recovery_code", methods=["GET", "POST"])
@login_required
def recovery_code_route():
if not current_user.two_factor_authentication_enabled():
flash("you need to enable either TOTP or WebAuthn", "warning")
return redirect(url_for("dashboard.index"))
recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all()
if request.method == "GET" and not recovery_codes:
# user arrives at this page for the first time
LOG.d("%s has no recovery keys, generate", current_user)
RecoveryCode.generate(current_user)
recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all()
if request.method == "POST":
RecoveryCode.generate(current_user)
flash("New recovery codes generated", "success")
return redirect(url_for("dashboard.recovery_code_route"))
return render_template(
"dashboard/recovery_code.html", recovery_codes=recovery_codes
)

View file

@ -12,6 +12,7 @@ from flask import (
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from flask_wtf.file import FileField from flask_wtf.file import FileField
from newrelic import agent
from wtforms import StringField, validators from wtforms import StringField, validators
from wtforms.fields.html5 import EmailField from wtforms.fields.html5 import EmailField
@ -29,7 +30,6 @@ from app.email_utils import (
personal_email_already_used, personal_email_already_used,
) )
from app.errors import ProtonPartnerNotSetUp from app.errors import ProtonPartnerNotSetUp
from app.extensions import limiter
from app.image_validation import detect_image_format, ImageFormat from app.image_validation import detect_image_format, ImageFormat
from app.jobs.export_user_data_job import ExportUserDataJob from app.jobs.export_user_data_job import ExportUserDataJob
from app.log import LOG from app.log import LOG
@ -53,12 +53,8 @@ from app.models import (
PartnerSubscription, PartnerSubscription,
UnsubscribeBehaviourEnum, UnsubscribeBehaviourEnum,
) )
from app.proton.utils import get_proton_partner, perform_proton_account_unlink from app.proton.utils import get_proton_partner
from app.utils import ( from app.utils import random_string, sanitize_email
random_string,
CSRFValidationForm,
canonicalize_email,
)
class SettingForm(FlaskForm): class SettingForm(FlaskForm):
@ -105,12 +101,10 @@ def get_partner_subscription_and_name(
@dashboard_bp.route("/setting", methods=["GET", "POST"]) @dashboard_bp.route("/setting", methods=["GET", "POST"])
@login_required @login_required
@limiter.limit("5/minute", methods=["POST"])
def setting(): def setting():
form = SettingForm() form = SettingForm()
promo_form = PromoCodeForm() promo_form = PromoCodeForm()
change_email_form = ChangeEmailForm() change_email_form = ChangeEmailForm()
csrf_form = CSRFValidationForm()
email_change = EmailChange.get_by(user_id=current_user.id) email_change = EmailChange.get_by(user_id=current_user.id)
if email_change: if email_change:
@ -119,15 +113,16 @@ def setting():
pending_email = None pending_email = None
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(url_for("dashboard.setting"))
if request.form.get("form-name") == "update-email": if request.form.get("form-name") == "update-email":
if change_email_form.validate(): if change_email_form.validate():
# whether user can proceed with the email update # whether user can proceed with the email update
new_email_valid = True new_email_valid = True
new_email = canonicalize_email(change_email_form.email.data) if (
if new_email != current_user.email and not pending_email: sanitize_email(change_email_form.email.data) != current_user.email
and not pending_email
):
new_email = sanitize_email(change_email_form.email.data)
# check if this email is not already used # check if this email is not already used
if personal_email_already_used(new_email) or Alias.get_by( if personal_email_already_used(new_email) or Alias.get_by(
email=new_email email=new_email
@ -197,16 +192,6 @@ def setting():
) )
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
if current_user.profile_picture_id is not None:
current_profile_file = File.get_by(
id=current_user.profile_picture_id
)
if (
current_profile_file is not None
and current_profile_file.user_id == current_user.id
):
s3.delete(current_profile_file.path)
file_path = random_string(30) file_path = random_string(30)
file = File.create(user_id=current_user.id, path=file_path) file = File.create(user_id=current_user.id, path=file_path)
@ -411,7 +396,6 @@ def setting():
return render_template( return render_template(
"dashboard/setting.html", "dashboard/setting.html",
csrf_form=csrf_form,
form=form, form=form,
PlanEnum=PlanEnum, PlanEnum=PlanEnum,
SenderFormatEnum=SenderFormatEnum, SenderFormatEnum=SenderFormatEnum,
@ -460,13 +444,8 @@ def send_change_email_confirmation(user: User, email_change: EmailChange):
@dashboard_bp.route("/resend_email_change", methods=["GET", "POST"]) @dashboard_bp.route("/resend_email_change", methods=["GET", "POST"])
@limiter.limit("5/hour")
@login_required @login_required
def resend_email_change(): def resend_email_change():
form = CSRFValidationForm()
if not form.validate():
flash("Invalid request. Please try again", "warning")
return redirect(url_for("dashboard.setting"))
email_change = EmailChange.get_by(user_id=current_user.id) email_change = EmailChange.get_by(user_id=current_user.id)
if email_change: if email_change:
# extend email change expiration # extend email change expiration
@ -486,10 +465,6 @@ def resend_email_change():
@dashboard_bp.route("/cancel_email_change", methods=["GET", "POST"]) @dashboard_bp.route("/cancel_email_change", methods=["GET", "POST"])
@login_required @login_required
def cancel_email_change(): def cancel_email_change():
form = CSRFValidationForm()
if not form.validate():
flash("Invalid request. Please try again", "warning")
return redirect(url_for("dashboard.setting"))
email_change = EmailChange.get_by(user_id=current_user.id) email_change = EmailChange.get_by(user_id=current_user.id)
if email_change: if email_change:
EmailChange.delete(email_change.id) EmailChange.delete(email_change.id)
@ -503,14 +478,16 @@ def cancel_email_change():
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@dashboard_bp.route("/unlink_proton_account", methods=["POST"]) @dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"])
@login_required @login_required
def unlink_proton_account(): def unlink_proton_account():
csrf_form = CSRFValidationForm() proton_partner = get_proton_partner()
if not csrf_form.validate(): partner_user = PartnerUser.get_by(
flash("Invalid request", "warning") user_id=current_user.id, partner_id=proton_partner.id
return redirect(url_for("dashboard.setting")) )
if partner_user is not None:
perform_proton_account_unlink(current_user) PartnerUser.delete(partner_user.id)
Session.commit()
flash("Your Proton account has been unlinked", "success") flash("Your Proton account has been unlinked", "success")
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))

View file

@ -2,10 +2,7 @@ import re
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from wtforms import StringField, validators
from app import parallel_limiter
from app.config import MAX_NB_SUBDOMAIN from app.config import MAX_NB_SUBDOMAIN
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.errors import SubdomainInTrashError from app.errors import SubdomainInTrashError
@ -16,18 +13,8 @@ from app.models import CustomDomain, Mailbox, SLDomain
_SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}" _SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}"
class NewSubdomainForm(FlaskForm):
domain = StringField(
"domain", validators=[validators.DataRequired(), validators.Length(max=64)]
)
subdomain = StringField(
"subdomain", validators=[validators.DataRequired(), validators.Length(max=64)]
)
@dashboard_bp.route("/subdomain", methods=["GET", "POST"]) @dashboard_bp.route("/subdomain", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def subdomain_route(): def subdomain_route():
if not current_user.subdomain_is_available(): if not current_user.subdomain_is_available():
flash("Unknown error, redirect to the home page", "error") flash("Unknown error, redirect to the home page", "error")
@ -39,13 +26,9 @@ def subdomain_route():
).all() ).all()
errors = {} errors = {}
new_subdomain_form = NewSubdomainForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
if not new_subdomain_form.validate():
flash("Invalid new subdomain", "warning")
return redirect(url_for("dashboard.subdomain_route"))
if not current_user.is_premium(): if not current_user.is_premium():
flash("Only premium plan can add subdomain", "warning") flash("Only premium plan can add subdomain", "warning")
return redirect(request.url) return redirect(request.url)
@ -56,8 +39,8 @@ def subdomain_route():
) )
return redirect(request.url) return redirect(request.url)
subdomain = new_subdomain_form.subdomain.data.lower().strip() subdomain = request.form.get("subdomain").lower().strip()
domain = new_subdomain_form.domain.data.lower().strip() domain = request.form.get("domain").lower().strip()
if len(subdomain) < 3: if len(subdomain) < 3:
flash("Subdomain must have at least 3 characters", "error") flash("Subdomain must have at least 3 characters", "error")
@ -125,5 +108,4 @@ def subdomain_route():
sl_domains=sl_domains, sl_domains=sl_domains,
errors=errors, errors=errors,
subdomains=subdomains, subdomains=subdomains,
new_subdomain_form=new_subdomain_form,
) )

View file

@ -75,11 +75,12 @@ def block_contact(contact_id):
@dashboard_bp.route("/unsubscribe/encoded/<encoded_request>", methods=["GET"]) @dashboard_bp.route("/unsubscribe/encoded/<encoded_request>", methods=["GET"])
@login_required @login_required
def encoded_unsubscribe(encoded_request: str): def encoded_unsubscribe(encoded_request: str):
unsub_data = UnsubscribeHandler().handle_unsubscribe_from_request( unsub_data = UnsubscribeHandler().handle_unsubscribe_from_request(
current_user, encoded_request current_user, encoded_request
) )
if not unsub_data: if not unsub_data:
flash("Invalid unsubscribe request", "error") flash(f"Invalid unsubscribe request", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
if unsub_data.action == UnsubscribeAction.DisableAlias: if unsub_data.action == UnsubscribeAction.DisableAlias:
alias = Alias.get(unsub_data.data) alias = Alias.get(unsub_data.data)
@ -96,14 +97,14 @@ def encoded_unsubscribe(encoded_request: str):
) )
) )
if unsub_data.action == UnsubscribeAction.UnsubscribeNewsletter: if unsub_data.action == UnsubscribeAction.UnsubscribeNewsletter:
flash("You've unsubscribed from the newsletter", "success") flash(f"You've unsubscribed from the newsletter", "success")
return redirect( return redirect(
url_for( url_for(
"dashboard.index", "dashboard.index",
) )
) )
if unsub_data.action == UnsubscribeAction.OriginalUnsubscribeMailto: if unsub_data.action == UnsubscribeAction.OriginalUnsubscribeMailto:
flash("The original unsubscribe request has been forwarded", "success") flash(f"The original unsubscribe request has been forwarded", "success")
return redirect( return redirect(
url_for( url_for(
"dashboard.index", "dashboard.index",

View file

@ -3,12 +3,9 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from app import config from app.config import DB_URI
engine = create_engine(DB_URI)
engine = create_engine(
config.DB_URI, connect_args={"application_name": config.DB_CONN_NAME}
)
connection = engine.connect() connection = engine.connect()
Session = scoped_session(sessionmaker(bind=connection)) Session = scoped_session(sessionmaker(bind=connection))

View file

@ -1,3 +1 @@
from .views import index, new_client, client_detail from .views import index, new_client, client_detail
__all__ = ["index", "new_client", "client_detail"]

View file

@ -87,7 +87,7 @@ def client_detail(client_id):
) )
flash( flash(
"Thanks for submitting, we are informed and will come back to you asap!", f"Thanks for submitting, we are informed and will come back to you asap!",
"success", "success",
) )

View file

@ -1,3 +1 @@
from .views import index from .views import index
__all__ = ["index"]

View file

@ -34,7 +34,7 @@ def get_cname_record(hostname) -> Optional[str]:
def get_mx_domains(hostname) -> [(int, str)]: def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name) sorted by priority (lowest priority first) """return list of (priority, domain name).
domain name ends with a "." at the end. domain name ends with a "." at the end.
""" """
try: try:
@ -50,7 +50,7 @@ def get_mx_domains(hostname) -> [(int, str)]:
ret.append((int(parts[0]), parts[1])) ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda prio_domain: prio_domain[0]) return ret
_include_spf = "include:" _include_spf = "include:"

View file

@ -20,7 +20,6 @@ X_SPAM_STATUS = "X-Spam-Status"
LIST_UNSUBSCRIBE = "List-Unsubscribe" LIST_UNSUBSCRIBE = "List-Unsubscribe"
LIST_UNSUBSCRIBE_POST = "List-Unsubscribe-Post" LIST_UNSUBSCRIBE_POST = "List-Unsubscribe-Post"
RETURN_PATH = "Return-Path" RETURN_PATH = "Return-Path"
AUTHENTICATION_RESULTS = "Authentication-Results"
# headers used to DKIM sign in order of preference # headers used to DKIM sign in order of preference
DKIM_HEADERS = [ DKIM_HEADERS = [
@ -33,7 +32,6 @@ DKIM_HEADERS = [
SL_DIRECTION = "X-SimpleLogin-Type" SL_DIRECTION = "X-SimpleLogin-Type"
SL_EMAIL_LOG_ID = "X-SimpleLogin-EmailLog-ID" SL_EMAIL_LOG_ID = "X-SimpleLogin-EmailLog-ID"
SL_ENVELOPE_FROM = "X-SimpleLogin-Envelope-From" SL_ENVELOPE_FROM = "X-SimpleLogin-Envelope-From"
SL_ORIGINAL_FROM = "X-SimpleLogin-Original-From"
SL_ENVELOPE_TO = "X-SimpleLogin-Envelope-To" SL_ENVELOPE_TO = "X-SimpleLogin-Envelope-To"
SL_CLIENT_IP = "X-SimpleLogin-Client-IP" SL_CLIENT_IP = "X-SimpleLogin-Client-IP"

View file

@ -31,7 +31,11 @@ E402 = "421 SL E402 Encryption failed - Retry later"
# E403 = "421 SL E403 Retry later" # E403 = "421 SL E403 Retry later"
E404 = "421 SL E404 Unexpected error - Retry later" E404 = "421 SL E404 Unexpected error - Retry later"
E405 = "421 SL E405 Mailbox domain problem - Retry later" E405 = "421 SL E405 Mailbox domain problem - Retry later"
E406 = "421 SL E406 Retry later"
E407 = "421 SL E407 Retry later" E407 = "421 SL E407 Retry later"
E408 = "421 SL E408 Retry later"
E409 = "421 SL E409 Retry later"
E410 = "421 SL E410 Retry later"
# endregion # endregion
# region 5** errors # region 5** errors
@ -60,5 +64,4 @@ E522 = (
) )
E523 = "550 SL E523 Unknown error" E523 = "550 SL E523 Unknown error"
E524 = "550 SL E524 Wrong use of reverse-alias" E524 = "550 SL E524 Wrong use of reverse-alias"
E525 = "550 SL E525 Alias loop"
# endregion # endregion

View file

@ -14,7 +14,7 @@ from email.header import decode_header, Header
from email.message import Message, EmailMessage from email.message import Message, EmailMessage
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.utils import make_msgid, formatdate, formataddr from email.utils import make_msgid, formatdate
from smtplib import SMTP, SMTPException from smtplib import SMTP, SMTPException
from typing import Tuple, List, Optional, Union from typing import Tuple, List, Optional, Union
@ -34,7 +34,30 @@ from flanker.addresslib.address import EmailAddress
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from sqlalchemy import func from sqlalchemy import func
from app import config from app.config import (
ROOT_DIR,
POSTFIX_SERVER,
DKIM_SELECTOR,
DKIM_PRIVATE_KEY,
ALIAS_DOMAINS,
POSTFIX_SUBMISSION_TLS,
MAX_NB_EMAIL_FREE_PLAN,
MAX_ALERT_24H,
POSTFIX_PORT,
URL,
LANDING_PAGE_URL,
EMAIL_DOMAIN,
ALERT_DIRECTORY_DISABLED_ALIAS_CREATION,
ALERT_SPF,
ALERT_INVALID_TOTP_LOGIN,
TEMP_DIR,
ALIAS_AUTOMATIC_DISABLE,
RSPAMD_SIGN_DKIM,
NOREPLY,
VERP_PREFIX,
VERP_MESSAGE_LIFETIME,
VERP_EMAIL_SECRET,
)
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains from app.dns_utils import get_mx_domains
from app.email import headers from app.email import headers
@ -54,7 +77,6 @@ from app.models import (
IgnoreBounceSender, IgnoreBounceSender,
InvalidMailboxDomain, InvalidMailboxDomain,
VerpType, VerpType,
available_sl_email,
) )
from app.utils import ( from app.utils import (
random_string, random_string,
@ -69,31 +91,31 @@ VERP_HMAC_ALGO = "sha3-224"
def render(template_name, **kwargs) -> str: def render(template_name, **kwargs) -> str:
templates_dir = os.path.join(config.ROOT_DIR, "templates", "emails") templates_dir = os.path.join(ROOT_DIR, "templates", "emails")
env = Environment(loader=FileSystemLoader(templates_dir)) env = Environment(loader=FileSystemLoader(templates_dir))
template = env.get_template(template_name) template = env.get_template(template_name)
return template.render( return template.render(
MAX_NB_EMAIL_FREE_PLAN=config.MAX_NB_EMAIL_FREE_PLAN, MAX_NB_EMAIL_FREE_PLAN=MAX_NB_EMAIL_FREE_PLAN,
URL=config.URL, URL=URL,
LANDING_PAGE_URL=config.LANDING_PAGE_URL, LANDING_PAGE_URL=LANDING_PAGE_URL,
YEAR=arrow.now().year, YEAR=arrow.now().year,
**kwargs, **kwargs,
) )
def send_welcome_email(user): def send_welcome_email(user):
comm_email, unsubscribe_link, via_email = user.get_communication_email() to_email, unsubscribe_link, via_email = user.get_communication_email()
if not comm_email: if not to_email:
return return
# whether this email is sent to an alias # whether this email is sent to an alias
alias = comm_email if comm_email != user.email else None alias = to_email if to_email != user.email else None
send_email( send_email(
comm_email, to_email,
"Welcome to SimpleLogin", f"Welcome to SimpleLogin",
render("com/welcome.txt", user=user, alias=alias), render("com/welcome.txt", user=user, alias=alias),
render("com/welcome.html", user=user, alias=alias), render("com/welcome.html", user=user, alias=alias),
unsubscribe_link, unsubscribe_link,
@ -104,7 +126,7 @@ def send_welcome_email(user):
def send_trial_end_soon_email(user): def send_trial_end_soon_email(user):
send_email( send_email(
user.email, user.email,
"Your trial will end soon", f"Your trial will end soon",
render("transactional/trial-end.txt.jinja2", user=user), render("transactional/trial-end.txt.jinja2", user=user),
render("transactional/trial-end.html", user=user), render("transactional/trial-end.html", user=user),
ignore_smtp_error=True, ignore_smtp_error=True,
@ -114,7 +136,7 @@ def send_trial_end_soon_email(user):
def send_activation_email(email, activation_link): def send_activation_email(email, activation_link):
send_email( send_email(
email, email,
"Just one more step to join SimpleLogin", f"Just one more step to join SimpleLogin",
render( render(
"transactional/activation.txt", "transactional/activation.txt",
activation_link=activation_link, activation_link=activation_link,
@ -165,7 +187,7 @@ def send_change_email(new_email, current_email, link):
def send_invalid_totp_login_email(user, totp_type): def send_invalid_totp_login_email(user, totp_type):
send_email_with_rate_control( send_email_with_rate_control(
user, user,
config.ALERT_INVALID_TOTP_LOGIN, ALERT_INVALID_TOTP_LOGIN,
user.email, user.email,
"Unsuccessful attempt to login to your SimpleLogin account", "Unsuccessful attempt to login to your SimpleLogin account",
render( render(
@ -223,7 +245,7 @@ def send_cannot_create_directory_alias_disabled(user, alias_address, directory_n
""" """
send_email_with_rate_control( send_email_with_rate_control(
user, user,
config.ALERT_DIRECTORY_DISABLED_ALIAS_CREATION, ALERT_DIRECTORY_DISABLED_ALIAS_CREATION,
user.email, user.email,
f"Alias {alias_address} cannot be created", f"Alias {alias_address} cannot be created",
render( render(
@ -275,9 +297,8 @@ def send_email(
LOG.d("send email to %s, subject '%s'", to_email, subject) LOG.d("send email to %s, subject '%s'", to_email, subject)
from_name = from_name or config.NOREPLY from_name = from_name or NOREPLY
from_addr = from_addr or config.NOREPLY from_addr = from_addr or NOREPLY
from_domain = get_email_domain_part(from_addr)
if html: if html:
msg = MIMEMultipart("alternative") msg = MIMEMultipart("alternative")
@ -292,14 +313,13 @@ def send_email(
msg[headers.FROM] = f'"{from_name}" <{from_addr}>' msg[headers.FROM] = f'"{from_name}" <{from_addr}>'
msg[headers.TO] = to_email msg[headers.TO] = to_email
msg_id_header = make_msgid(domain=config.EMAIL_DOMAIN) msg_id_header = make_msgid(domain=EMAIL_DOMAIN)
msg[headers.MESSAGE_ID] = msg_id_header msg[headers.MESSAGE_ID] = msg_id_header
date_header = formatdate() date_header = formatdate()
msg[headers.DATE] = date_header msg[headers.DATE] = date_header
if headers.MIME_VERSION not in msg: msg[headers.MIME_VERSION] = "1.0"
msg[headers.MIME_VERSION] = "1.0"
if unsubscribe_link: if unsubscribe_link:
add_or_replace_header(msg, headers.LIST_UNSUBSCRIBE, f"<{unsubscribe_link}>") add_or_replace_header(msg, headers.LIST_UNSUBSCRIBE, f"<{unsubscribe_link}>")
@ -316,7 +336,7 @@ def send_email(
# use a different envelope sender for each transactional email (aka VERP) # use a different envelope sender for each transactional email (aka VERP)
sl_sendmail( sl_sendmail(
generate_verp_email(VerpType.transactional, transaction.id, from_domain), generate_verp_email(VerpType.transactional, transaction.id),
to_email, to_email,
msg, msg,
retries=retries, retries=retries,
@ -331,7 +351,7 @@ def send_email_with_rate_control(
subject, subject,
plaintext, plaintext,
html=None, html=None,
max_nb_alert=config.MAX_ALERT_24H, max_nb_alert=MAX_ALERT_24H,
nb_day=1, nb_day=1,
ignore_smtp_error=False, ignore_smtp_error=False,
retries=0, retries=0,
@ -428,7 +448,7 @@ def get_email_domain_part(address):
def add_dkim_signature(msg: Message, email_domain: str): def add_dkim_signature(msg: Message, email_domain: str):
if config.RSPAMD_SIGN_DKIM: if RSPAMD_SIGN_DKIM:
LOG.d("DKIM signature will be added by rspamd") LOG.d("DKIM signature will be added by rspamd")
msg[headers.SL_WANT_SIGNING] = "yes" msg[headers.SL_WANT_SIGNING] = "yes"
return return
@ -443,9 +463,9 @@ def add_dkim_signature(msg: Message, email_domain: str):
continue continue
# To investigate why some emails can't be DKIM signed. todo: remove # To investigate why some emails can't be DKIM signed. todo: remove
if config.TEMP_DIR: if TEMP_DIR:
file_name = str(uuid.uuid4()) + ".eml" file_name = str(uuid.uuid4()) + ".eml"
with open(os.path.join(config.TEMP_DIR, file_name), "wb") as f: with open(os.path.join(TEMP_DIR, file_name), "wb") as f:
f.write(msg.as_bytes()) f.write(msg.as_bytes())
LOG.w("email saved to %s", file_name) LOG.w("email saved to %s", file_name)
@ -460,12 +480,12 @@ def add_dkim_signature_with_header(
# Specify headers in "byte" form # Specify headers in "byte" form
# Generate message signature # Generate message signature
if config.DKIM_PRIVATE_KEY: if DKIM_PRIVATE_KEY:
sig = dkim.sign( sig = dkim.sign(
message_to_bytes(msg), message_to_bytes(msg),
config.DKIM_SELECTOR, DKIM_SELECTOR,
email_domain.encode(), email_domain.encode(),
config.DKIM_PRIVATE_KEY.encode(), DKIM_PRIVATE_KEY.encode(),
include_headers=dkim_headers, include_headers=dkim_headers,
) )
sig = sig.decode() sig = sig.decode()
@ -517,7 +537,7 @@ def delete_all_headers_except(msg: Message, headers: [str]):
def can_create_directory_for_address(email_address: str) -> bool: def can_create_directory_for_address(email_address: str) -> bool:
"""return True if an email ends with one of the alias domains provided by SimpleLogin""" """return True if an email ends with one of the alias domains provided by SimpleLogin"""
# not allow creating directory with premium domain # not allow creating directory with premium domain
for domain in config.ALIAS_DOMAINS: for domain in ALIAS_DOMAINS:
if email_address.endswith("@" + domain): if email_address.endswith("@" + domain):
return True return True
@ -574,7 +594,7 @@ def email_can_be_used_as_mailbox(email_address: str) -> bool:
mx_domains = get_mx_domain_list(domain) mx_domains = get_mx_domain_list(domain)
# if no MX record, email is not valid # if no MX record, email is not valid
if not config.SKIP_MX_LOOKUP_ON_CHECK and not mx_domains: if not mx_domains:
LOG.d("No MX record for domain %s", domain) LOG.d("No MX record for domain %s", domain)
return False return False
@ -768,7 +788,7 @@ def get_header_unicode(header: Union[str, Header]) -> str:
ret = "" ret = ""
for to_decoded_str, charset in decode_header(header): for to_decoded_str, charset in decode_header(header):
if charset is None: if charset is None:
if isinstance(to_decoded_str, bytes): if type(to_decoded_str) is bytes:
decoded_str = to_decoded_str.decode() decoded_str = to_decoded_str.decode()
else: else:
decoded_str = to_decoded_str decoded_str = to_decoded_str
@ -805,13 +825,13 @@ def to_bytes(msg: Message):
for generator_policy in [None, policy.SMTP, policy.SMTPUTF8]: for generator_policy in [None, policy.SMTP, policy.SMTPUTF8]:
try: try:
return msg.as_bytes(policy=generator_policy) return msg.as_bytes(policy=generator_policy)
except Exception: except:
LOG.w("as_bytes() fails with %s policy", policy, exc_info=True) LOG.w("as_bytes() fails with %s policy", policy, exc_info=True)
msg_string = msg.as_string() msg_string = msg.as_string()
try: try:
return msg_string.encode() return msg_string.encode()
except Exception: except:
LOG.w("as_string().encode() fails", exc_info=True) LOG.w("as_string().encode() fails", exc_info=True)
return msg_string.encode(errors="replace") return msg_string.encode(errors="replace")
@ -828,6 +848,19 @@ def should_add_dkim_signature(domain: str) -> bool:
return False return False
def is_valid_email(email_address: str) -> bool:
"""
Used to check whether an email address is valid
NOT run MX check.
NOT allow unicode.
"""
try:
validate_email(email_address, check_deliverability=False, allow_smtputf8=False)
return True
except EmailNotValidError:
return False
class EmailEncoding(enum.Enum): class EmailEncoding(enum.Enum):
BASE64 = "base64" BASE64 = "base64"
QUOTED = "quoted-printable" QUOTED = "quoted-printable"
@ -898,25 +931,22 @@ def decode_text(text: str, encoding: EmailEncoding = EmailEncoding.NO) -> str:
return text return text
def add_header(msg: Message, text_header, html_header=None) -> Message: def add_header(msg: Message, text_header, html_header) -> Message:
if not html_header:
html_header = text_header.replace("\n", "<br>")
content_type = msg.get_content_type().lower() content_type = msg.get_content_type().lower()
if content_type == "text/plain": if content_type == "text/plain":
encoding = get_encoding(msg) encoding = get_encoding(msg)
payload = msg.get_payload() payload = msg.get_payload()
if isinstance(payload, str): if type(payload) is str:
clone_msg = copy(msg) clone_msg = copy(msg)
new_payload = f"""{text_header} new_payload = f"""{text_header}
------------------------------ ---
{decode_text(payload, encoding)}""" {decode_text(payload, encoding)}"""
clone_msg.set_payload(encode_text(new_payload, encoding)) clone_msg.set_payload(encode_text(new_payload, encoding))
return clone_msg return clone_msg
elif content_type == "text/html": elif content_type == "text/html":
encoding = get_encoding(msg) encoding = get_encoding(msg)
payload = msg.get_payload() payload = msg.get_payload()
if isinstance(payload, str): if type(payload) is str:
new_payload = f"""<table width="100%" style="width: 100%; -premailer-width: 100%; -premailer-cellpadding: 0; new_payload = f"""<table width="100%" style="width: 100%; -premailer-width: 100%; -premailer-cellpadding: 0;
-premailer-cellspacing: 0; margin: 0; padding: 0;"> -premailer-cellspacing: 0; margin: 0; padding: 0;">
<tr> <tr>
@ -938,8 +968,6 @@ def add_header(msg: Message, text_header, html_header=None) -> Message:
for part in msg.get_payload(): for part in msg.get_payload():
if isinstance(part, Message): if isinstance(part, Message):
new_parts.append(add_header(part, text_header, html_header)) new_parts.append(add_header(part, text_header, html_header))
elif isinstance(part, str):
new_parts.append(MIMEText(part))
else: else:
new_parts.append(part) new_parts.append(part)
clone_msg = copy(msg) clone_msg = copy(msg)
@ -948,14 +976,7 @@ def add_header(msg: Message, text_header, html_header=None) -> Message:
elif content_type in ("multipart/mixed", "multipart/signed"): elif content_type in ("multipart/mixed", "multipart/signed"):
new_parts = [] new_parts = []
payload = msg.get_payload() parts = list(msg.get_payload())
if isinstance(payload, str):
# The message is badly formatted inject as new
new_parts = [MIMEText(text_header, "plain"), MIMEText(payload, "plain")]
clone_msg = copy(msg)
clone_msg.set_payload(new_parts)
return clone_msg
parts = list(payload)
LOG.d("only add header for the first part for %s", content_type) LOG.d("only add header for the first part for %s", content_type)
for ix, part in enumerate(parts): for ix, part in enumerate(parts):
if ix == 0: if ix == 0:
@ -971,11 +992,7 @@ def add_header(msg: Message, text_header, html_header=None) -> Message:
return msg return msg
def replace(msg: Union[Message, str], old, new) -> Union[Message, str]: def replace(msg: Message, old, new) -> Message:
if isinstance(msg, str):
msg = msg.replace(old, new)
return msg
content_type = msg.get_content_type() content_type = msg.get_content_type()
if ( if (
@ -995,7 +1012,7 @@ def replace(msg: Union[Message, str], old, new) -> Union[Message, str]:
if content_type in ("text/plain", "text/html"): if content_type in ("text/plain", "text/html"):
encoding = get_encoding(msg) encoding = get_encoding(msg)
payload = msg.get_payload() payload = msg.get_payload()
if isinstance(payload, str): if type(payload) is str:
if encoding == EmailEncoding.QUOTED: if encoding == EmailEncoding.QUOTED:
LOG.d("handle quoted-printable replace %s -> %s", old, new) LOG.d("handle quoted-printable replace %s -> %s", old, new)
# first decode the payload # first decode the payload
@ -1040,7 +1057,7 @@ def replace(msg: Union[Message, str], old, new) -> Union[Message, str]:
return msg return msg
def generate_reply_email(contact_email: str, alias: Alias) -> str: def generate_reply_email(contact_email: str, user: User) -> str:
""" """
generate a reply_email (aka reverse-alias), make sure it isn't used by any contact generate a reply_email (aka reverse-alias), make sure it isn't used by any contact
""" """
@ -1051,7 +1068,6 @@ def generate_reply_email(contact_email: str, alias: Alias) -> str:
include_sender_in_reverse_alias = False include_sender_in_reverse_alias = False
user = alias.user
# user has set this option explicitly # user has set this option explicitly
if user.include_sender_in_reverse_alias is not None: if user.include_sender_in_reverse_alias is not None:
include_sender_in_reverse_alias = user.include_sender_in_reverse_alias include_sender_in_reverse_alias = user.include_sender_in_reverse_alias
@ -1066,28 +1082,22 @@ def generate_reply_email(contact_email: str, alias: Alias) -> str:
contact_email = contact_email.replace(".", "_") contact_email = contact_email.replace(".", "_")
contact_email = convert_to_alphanumeric(contact_email) contact_email = convert_to_alphanumeric(contact_email)
reply_domain = config.EMAIL_DOMAIN
alias_domain = get_email_domain_part(alias.email)
sl_domain = SLDomain.get_by(domain=alias_domain)
if sl_domain and sl_domain.use_as_reverse_alias:
reply_domain = alias_domain
# not use while to avoid infinite loop # not use while to avoid infinite loop
for _ in range(1000): for _ in range(1000):
if include_sender_in_reverse_alias and contact_email: if include_sender_in_reverse_alias and contact_email:
random_length = random.randint(5, 10) random_length = random.randint(5, 10)
reply_email = ( reply_email = (
# do not use the ra+ anymore # do not use the ra+ anymore
# f"ra+{contact_email}+{random_string(random_length)}@{config.EMAIL_DOMAIN}" # f"ra+{contact_email}+{random_string(random_length)}@{EMAIL_DOMAIN}"
f"{contact_email}_{random_string(random_length)}@{reply_domain}" f"{contact_email}_{random_string(random_length)}@{EMAIL_DOMAIN}"
) )
else: else:
random_length = random.randint(20, 50) random_length = random.randint(20, 50)
# do not use the ra+ anymore # do not use the ra+ anymore
# reply_email = f"ra+{random_string(random_length)}@{config.EMAIL_DOMAIN}" # reply_email = f"ra+{random_string(random_length)}@{EMAIL_DOMAIN}"
reply_email = f"{random_string(random_length)}@{reply_domain}" reply_email = f"{random_string(random_length)}@{EMAIL_DOMAIN}"
if available_sl_email(reply_email): if not Contact.get_by(reply_email=reply_email):
return reply_email return reply_email
raise Exception("Cannot generate reply email") raise Exception("Cannot generate reply email")
@ -1098,11 +1108,31 @@ def is_reverse_alias(address: str) -> bool:
if Contact.get_by(reply_email=address): if Contact.get_by(reply_email=address):
return True return True
return address.endswith(f"@{config.EMAIL_DOMAIN}") and ( return address.endswith(f"@{EMAIL_DOMAIN}") and (
address.startswith("reply+") or address.startswith("ra+") address.startswith("reply+") or address.startswith("ra+")
) )
# allow also + and @ that are present in a reply address
_ALLOWED_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.+@"
def normalize_reply_email(reply_email: str) -> str:
"""Handle the case where reply email contains *strange* char that was wrongly generated in the past"""
if not reply_email.isascii():
reply_email = convert_to_id(reply_email)
ret = []
# drop all control characters like shift, separator, etc
for c in reply_email:
if c not in _ALLOWED_CHARS:
ret.append("_")
else:
ret.append(c)
return "".join(ret)
def should_disable(alias: Alias) -> (bool, str): def should_disable(alias: Alias) -> (bool, str):
""" """
Return whether an alias should be disabled and if yes, the reason why Return whether an alias should be disabled and if yes, the reason why
@ -1112,7 +1142,7 @@ def should_disable(alias: Alias) -> (bool, str):
LOG.w("%s cannot be disabled", alias) LOG.w("%s cannot be disabled", alias)
return False, "" return False, ""
if not config.ALIAS_AUTOMATIC_DISABLE: if not ALIAS_AUTOMATIC_DISABLE:
return False, "" return False, ""
yesterday = arrow.now().shift(days=-1) yesterday = arrow.now().shift(days=-1)
@ -1227,14 +1257,14 @@ def spf_pass(
subject = get_header_unicode(msg[headers.SUBJECT]) subject = get_header_unicode(msg[headers.SUBJECT])
send_email_with_rate_control( send_email_with_rate_control(
user, user,
config.ALERT_SPF, ALERT_SPF,
mailbox.email, mailbox.email,
f"SimpleLogin Alert: attempt to send emails from your alias {alias.email} from unknown IP Address", f"SimpleLogin Alert: attempt to send emails from your alias {alias.email} from unknown IP Address",
render( render(
"transactional/spf-fail.txt", "transactional/spf-fail.txt",
alias=alias.email, alias=alias.email,
ip=ip, ip=ip,
mailbox_url=config.URL + f"/dashboard/mailbox/{mailbox.id}#spf", mailbox_url=URL + f"/dashboard/mailbox/{mailbox.id}#spf",
to_email=contact_email, to_email=contact_email,
subject=subject, subject=subject,
time=arrow.now(), time=arrow.now(),
@ -1242,7 +1272,7 @@ def spf_pass(
render( render(
"transactional/spf-fail.html", "transactional/spf-fail.html",
ip=ip, ip=ip,
mailbox_url=config.URL + f"/dashboard/mailbox/{mailbox.id}#spf", mailbox_url=URL + f"/dashboard/mailbox/{mailbox.id}#spf",
to_email=contact_email, to_email=contact_email,
subject=subject, subject=subject,
time=arrow.now(), time=arrow.now(),
@ -1265,11 +1295,11 @@ def spf_pass(
@cached(cache=TTLCache(maxsize=2, ttl=20)) @cached(cache=TTLCache(maxsize=2, ttl=20))
def get_smtp_server(): def get_smtp_server():
LOG.d("get a smtp server") LOG.d("get a smtp server")
if config.POSTFIX_SUBMISSION_TLS: if POSTFIX_SUBMISSION_TLS:
smtp = SMTP(config.POSTFIX_SERVER, 587) smtp = SMTP(POSTFIX_SERVER, 587)
smtp.starttls() smtp.starttls()
else: else:
smtp = SMTP(config.POSTFIX_SERVER, config.POSTFIX_PORT) smtp = SMTP(POSTFIX_SERVER, POSTFIX_PORT)
return smtp return smtp
@ -1341,12 +1371,12 @@ def save_email_for_debugging(msg: Message, file_name_prefix=None) -> str:
"""Save email for debugging to temporary location """Save email for debugging to temporary location
Return the file path Return the file path
""" """
if config.TEMP_DIR: if TEMP_DIR:
file_name = str(uuid.uuid4()) + ".eml" file_name = str(uuid.uuid4()) + ".eml"
if file_name_prefix: if file_name_prefix:
file_name = "{}-{}".format(file_name_prefix, file_name) file_name = "{}-{}".format(file_name_prefix, file_name)
with open(os.path.join(config.TEMP_DIR, file_name), "wb") as f: with open(os.path.join(TEMP_DIR, file_name), "wb") as f:
f.write(msg.as_bytes()) f.write(msg.as_bytes())
LOG.d("email saved to %s", file_name) LOG.d("email saved to %s", file_name)
@ -1359,12 +1389,12 @@ def save_envelope_for_debugging(envelope: Envelope, file_name_prefix=None) -> st
"""Save envelope for debugging to temporary location """Save envelope for debugging to temporary location
Return the file path Return the file path
""" """
if config.TEMP_DIR: if TEMP_DIR:
file_name = str(uuid.uuid4()) + ".eml" file_name = str(uuid.uuid4()) + ".eml"
if file_name_prefix: if file_name_prefix:
file_name = "{}-{}".format(file_name_prefix, file_name) file_name = "{}-{}".format(file_name_prefix, file_name)
with open(os.path.join(config.TEMP_DIR, file_name), "wb") as f: with open(os.path.join(TEMP_DIR, file_name), "wb") as f:
f.write(envelope.original_content) f.write(envelope.original_content)
LOG.d("envelope saved to %s", file_name) LOG.d("envelope saved to %s", file_name)
@ -1390,15 +1420,12 @@ def generate_verp_email(
# Signing without itsdangereous because it uses base64 that includes +/= symbols and lower and upper case letters. # Signing without itsdangereous because it uses base64 that includes +/= symbols and lower and upper case letters.
# We need to encode in base32 # We need to encode in base32
payload_hmac = hmac.new( payload_hmac = hmac.new(
config.VERP_EMAIL_SECRET.encode("utf-8"), json_payload, VERP_HMAC_ALGO VERP_EMAIL_SECRET.encode("utf-8"), json_payload, VERP_HMAC_ALGO
).digest()[:8] ).digest()[:8]
encoded_payload = base64.b32encode(json_payload).rstrip(b"=").decode("utf-8") encoded_payload = base64.b32encode(json_payload).rstrip(b"=").decode("utf-8")
encoded_signature = base64.b32encode(payload_hmac).rstrip(b"=").decode("utf-8") encoded_signature = base64.b32encode(payload_hmac).rstrip(b"=").decode("utf-8")
return "{}.{}.{}@{}".format( return "{}.{}.{}@{}".format(
config.VERP_PREFIX, VERP_PREFIX, encoded_payload, encoded_signature, sender_domain or EMAIL_DOMAIN
encoded_payload,
encoded_signature,
sender_domain or config.EMAIL_DOMAIN,
).lower() ).lower()
@ -1411,7 +1438,7 @@ def get_verp_info_from_email(email: str) -> Optional[Tuple[VerpType, int]]:
return None return None
username = email[:idx] username = email[:idx]
fields = username.split(".") fields = username.split(".")
if len(fields) != 3 or fields[0] != config.VERP_PREFIX: if len(fields) != 3 or fields[0] != VERP_PREFIX:
return None return None
try: try:
padding = (8 - (len(fields[1]) % 8)) % 8 padding = (8 - (len(fields[1]) % 8)) % 8
@ -1423,7 +1450,7 @@ def get_verp_info_from_email(email: str) -> Optional[Tuple[VerpType, int]]:
except binascii.Error: except binascii.Error:
return None return None
expected_signature = hmac.new( expected_signature = hmac.new(
config.VERP_EMAIL_SECRET.encode("utf-8"), payload, VERP_HMAC_ALGO VERP_EMAIL_SECRET.encode("utf-8"), payload, VERP_HMAC_ALGO
).digest()[:8] ).digest()[:8]
if expected_signature != signature: if expected_signature != signature:
return None return None
@ -1431,13 +1458,6 @@ def get_verp_info_from_email(email: str) -> Optional[Tuple[VerpType, int]]:
# verp type, object_id, time # verp type, object_id, time
if len(data) != 3: if len(data) != 3:
return None return None
if data[2] > (time.time() + config.VERP_MESSAGE_LIFETIME - VERP_TIME_START) / 60: if data[2] > (time.time() + VERP_MESSAGE_LIFETIME - VERP_TIME_START) / 60:
return None return None
return VerpType(data[0]), data[1] return VerpType(data[0]), data[1]
def sl_formataddr(name_address_tuple: Tuple[str, str]):
"""Same as formataddr but use utf-8 encoding by default and always return str (and never Header)"""
name, addr = name_address_tuple
# formataddr can return Header, make sure to convert to str
return str(formataddr((name, Header(addr, "utf-8"))))

View file

@ -1,38 +0,0 @@
from email_validator import (
validate_email,
EmailNotValidError,
)
from app.utils import convert_to_id
# allow also + and @ that are present in a reply address
_ALLOWED_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.+@"
def is_valid_email(email_address: str) -> bool:
"""
Used to check whether an email address is valid
NOT run MX check.
NOT allow unicode.
"""
try:
validate_email(email_address, check_deliverability=False, allow_smtputf8=False)
return True
except EmailNotValidError:
return False
def normalize_reply_email(reply_email: str) -> str:
"""Handle the case where reply email contains *strange* char that was wrongly generated in the past"""
if not reply_email.isascii():
reply_email = convert_to_id(reply_email)
ret = []
# drop all control characters like shift, separator, etc
for c in reply_email:
if c not in _ALLOWED_CHARS:
ret.append("_")
else:
ret.append(c)
return "".join(ret)

View file

@ -71,7 +71,7 @@ class ErrContactErrorUpgradeNeeded(SLException):
"""raised when user cannot create a contact because the plan doesn't allow it""" """raised when user cannot create a contact because the plan doesn't allow it"""
def error_for_user(self) -> str: def error_for_user(self) -> str:
return "Please upgrade to premium to create reverse-alias" return f"Please upgrade to premium to create reverse-alias"
class ErrAddressInvalid(SLException): class ErrAddressInvalid(SLException):
@ -84,14 +84,6 @@ class ErrAddressInvalid(SLException):
return f"{self.address} is not a valid email address" return f"{self.address} is not a valid email address"
class InvalidContactEmailError(SLException):
def __init__(self, website_email: str): # noqa: F821
self.website_email = website_email
def error_for_user(self) -> str:
return f"Cannot create contact with invalid email {self.website_email}"
class ErrContactAlreadyExists(SLException): class ErrContactAlreadyExists(SLException):
"""raised when a contact already exists""" """raised when a contact already exists"""
@ -116,15 +108,3 @@ class AccountAlreadyLinkedToAnotherPartnerException(LinkException):
class AccountAlreadyLinkedToAnotherUserException(LinkException): class AccountAlreadyLinkedToAnotherUserException(LinkException):
def __init__(self): def __init__(self):
super().__init__("This account is linked to another user") super().__init__("This account is linked to another user")
class AccountIsUsingAliasAsEmail(LinkException):
def __init__(self):
super().__init__("Your account has an alias as it's email address")
class ProtonAccountNotVerified(LinkException):
def __init__(self):
super().__init__(
"The Proton account you are trying to use has not been verified"
)

View file

@ -9,7 +9,6 @@ class LoginEvent:
failed = 1 failed = 1
disabled_login = 2 disabled_login = 2
not_activated = 3 not_activated = 3
scheduled_to_be_deleted = 4
class Source(EnumE): class Source(EnumE):
web = 0 web = 0

View file

@ -1,31 +1,12 @@
from flask_limiter import Limiter from flask_limiter import Limiter
from flask_limiter.util import get_remote_address from flask_limiter.util import get_remote_address
from flask_login import current_user, LoginManager from flask_login import LoginManager
from app import config
login_manager = LoginManager() login_manager = LoginManager()
login_manager.session_protection = "strong" login_manager.session_protection = "strong"
# We want to rate limit based on:
# - If the user is not logged in: request source IP
# - If the user is logged in: user_id
def __key_func():
if current_user.is_authenticated:
return f"userid:{current_user.id}"
else:
ip_addr = get_remote_address()
return f"ip:{ip_addr}"
# Setup rate limit facility # Setup rate limit facility
limiter = Limiter(key_func=__key_func) limiter = Limiter(key_func=get_remote_address)
@limiter.request_filter
def disable_rate_limit():
return config.DISABLE_RATE_LIMIT
# @limiter.request_filter # @limiter.request_filter

View file

@ -5,7 +5,7 @@ from typing import Optional, Tuple
from aiosmtpd.handlers import Message from aiosmtpd.handlers import Message
from aiosmtpd.smtp import Envelope from aiosmtpd.smtp import Envelope
from app import s3, config from app import s3
from app.config import ( from app.config import (
DMARC_CHECK_ENABLED, DMARC_CHECK_ENABLED,
ALERT_QUARANTINE_DMARC, ALERT_QUARANTINE_DMARC,
@ -34,37 +34,6 @@ def apply_dmarc_policy_for_forward_phase(
from_header = get_header_unicode(msg[headers.FROM]) from_header = get_header_unicode(msg[headers.FROM])
warning_plain_text = """This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
More info on https://simplelogin.io/docs/getting-started/anti-phishing/
"""
warning_html = """
<p style="color:red">
This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
More info on <a href="https://simplelogin.io/docs/getting-started/anti-phishing/">anti-phishing measure</a>
</p>
"""
# do not quarantine an email if fails DMARC but has a small rspamd score
if (
config.MIN_RSPAMD_SCORE_FOR_FAILED_DMARC is not None
and spam_result.rspamd_score < config.MIN_RSPAMD_SCORE_FOR_FAILED_DMARC
and spam_result.dmarc
in (
DmarcCheckResult.quarantine,
DmarcCheckResult.reject,
)
):
LOG.w(
f"email fails DMARC but has a small rspamd score, from contact {contact.email} to alias {alias.email}."
f"mail_from:{envelope.mail_from}, from_header: {from_header}"
)
changed_msg = add_header(
msg,
warning_plain_text,
warning_html,
)
return changed_msg, None
if spam_result.dmarc == DmarcCheckResult.soft_fail: if spam_result.dmarc == DmarcCheckResult.soft_fail:
LOG.w( LOG.w(
f"dmarc forward: soft_fail from contact {contact.email} to alias {alias.email}." f"dmarc forward: soft_fail from contact {contact.email} to alias {alias.email}."
@ -72,8 +41,15 @@ More info on https://simplelogin.io/docs/getting-started/anti-phishing/
) )
changed_msg = add_header( changed_msg = add_header(
msg, msg,
warning_plain_text, f"""This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
warning_html, More info on https://simplelogin.io/docs/getting-started/anti-phishing/
""",
f"""
<p style="color:red">
This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
More info on <a href="https://simplelogin.io/docs/getting-started/anti-phishing/">anti-phishing measure</a>
</p>
""",
) )
return changed_msg, None return changed_msg, None
@ -157,7 +133,6 @@ def apply_dmarc_policy_for_reply_phase(
DmarcCheckResult.soft_fail, DmarcCheckResult.soft_fail,
): ):
return None return None
LOG.w( LOG.w(
f"dmarc reply: Put email from {alias_from.email} to {contact_recipient} into quarantine. {spam_result.event_data()}, " f"dmarc reply: Put email from {alias_from.email} to {contact_recipient} into quarantine. {spam_result.event_data()}, "
f"mail_from:{envelope.mail_from}, from_header: {msg[headers.FROM]}" f"mail_from:{envelope.mail_from}, from_header: {msg[headers.FROM]}"

View file

@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from mailbox import Message from mailbox import Message
from typing import Optional, Union from typing import Optional
from app import s3 from app import s3
from app.config import ( from app.config import (
@ -189,7 +189,7 @@ def handle_yahoo_complaint(message: Message) -> bool:
return handle_complaint(message, ProviderComplaintYahoo()) return handle_complaint(message, ProviderComplaintYahoo())
def find_alias_with_address(address: str) -> Optional[Union[Alias, DomainDeletedAlias]]: def find_alias_with_address(address: str) -> Optional[Alias]:
return Alias.get_by(email=address) or DomainDeletedAlias.get_by(email=address) return Alias.get_by(email=address) or DomainDeletedAlias.get_by(email=address)
@ -221,7 +221,7 @@ def handle_complaint(message: Message, origin: ProviderComplaintOrigin) -> bool:
return True return True
if is_deleted_alias(msg_info.sender_address): if is_deleted_alias(msg_info.sender_address):
LOG.i("Complaint is for deleted alias. Do nothing") LOG.i(f"Complaint is for deleted alias. Do nothing")
return True return True
contact = Contact.get_by(reply_email=msg_info.sender_address) contact = Contact.get_by(reply_email=msg_info.sender_address)
@ -231,7 +231,7 @@ def handle_complaint(message: Message, origin: ProviderComplaintOrigin) -> bool:
alias = find_alias_with_address(msg_info.rcpt_address) alias = find_alias_with_address(msg_info.rcpt_address)
if is_deleted_alias(msg_info.rcpt_address): if is_deleted_alias(msg_info.rcpt_address):
LOG.i("Complaint is for deleted alias. Do nothing") LOG.i(f"Complaint is for deleted alias. Do nothing")
return True return True
if not alias: if not alias:
@ -245,22 +245,16 @@ def handle_complaint(message: Message, origin: ProviderComplaintOrigin) -> bool:
def report_complaint_to_user_in_reply_phase( def report_complaint_to_user_in_reply_phase(
alias: Union[Alias, DomainDeletedAlias], alias: Alias,
to_address: str, to_address: str,
origin: ProviderComplaintOrigin, origin: ProviderComplaintOrigin,
msg_info: OriginalMessageInformation, msg_info: OriginalMessageInformation,
): ):
capitalized_name = origin.name().capitalize() capitalized_name = origin.name().capitalize()
mailbox_email = msg_info.mailbox_address
if not mailbox_email:
if type(alias) is Alias:
mailbox_email = alias.mailbox.email
else:
mailbox_email = alias.domain.mailboxes[0].email
send_email_with_rate_control( send_email_with_rate_control(
alias.user, alias.user,
f"{ALERT_COMPLAINT_REPLY_PHASE}_{origin.name()}", f"{ALERT_COMPLAINT_REPLY_PHASE}_{origin.name()}",
mailbox_email, msg_info.mailbox_address or alias.mailbox.email,
f"Abuse report from {capitalized_name}", f"Abuse report from {capitalized_name}",
render( render(
"transactional/provider-complaint-reply-phase.txt.jinja2", "transactional/provider-complaint-reply-phase.txt.jinja2",
@ -299,19 +293,11 @@ def report_complaint_to_user_in_transactional_phase(
def report_complaint_to_user_in_forward_phase( def report_complaint_to_user_in_forward_phase(
alias: Union[Alias, DomainDeletedAlias], alias: Alias, origin: ProviderComplaintOrigin, msg_info: OriginalMessageInformation
origin: ProviderComplaintOrigin,
msg_info: OriginalMessageInformation,
): ):
capitalized_name = origin.name().capitalize() capitalized_name = origin.name().capitalize()
user = alias.user user = alias.user
mailbox_email = msg_info.mailbox_address or alias.mailbox.email
mailbox_email = msg_info.mailbox_address
if not mailbox_email:
if type(alias) is Alias:
mailbox_email = alias.mailbox.email
else:
mailbox_email = alias.domain.mailboxes[0].email
send_email_with_rate_control( send_email_with_rate_control(
user, user,
f"{ALERT_COMPLAINT_FORWARD_PHASE}_{origin.name()}", f"{ALERT_COMPLAINT_FORWARD_PHASE}_{origin.name()}",

View file

@ -4,7 +4,6 @@ from typing import Dict, Optional
import newrelic.agent import newrelic.agent
from app.email import headers from app.email import headers
from app.log import LOG
from app.models import EnumE, Phase from app.models import EnumE, Phase
from email.message import Message from email.message import Message
@ -56,7 +55,6 @@ class SpamdResult:
self.phase: Phase = phase self.phase: Phase = phase
self.dmarc: DmarcCheckResult = DmarcCheckResult.not_available self.dmarc: DmarcCheckResult = DmarcCheckResult.not_available
self.spf: SPFCheckResult = SPFCheckResult.not_available self.spf: SPFCheckResult = SPFCheckResult.not_available
self.rspamd_score = -1
def set_dmarc_result(self, dmarc_result: DmarcCheckResult): def set_dmarc_result(self, dmarc_result: DmarcCheckResult):
self.dmarc = dmarc_result self.dmarc = dmarc_result
@ -87,7 +85,6 @@ class SpamdResult:
spam_entries = [ spam_entries = [
entry.strip() for entry in str(spam_result_header[-1]).split("\n") entry.strip() for entry in str(spam_result_header[-1]).split("\n")
] ]
for entry_pos in range(len(spam_entries)): for entry_pos in range(len(spam_entries)):
sep = spam_entries[entry_pos].find("(") sep = spam_entries[entry_pos].find("(")
if sep > -1: if sep > -1:
@ -104,17 +101,6 @@ class SpamdResult:
spamd_result.set_spf_result(spf_result) spamd_result.set_spf_result(spf_result)
break break
# parse the rspamd score
try:
score_line = spam_entries[0] # e.g. "default: False [2.30 / 13.00];"
spamd_result.rspamd_score = float(
score_line[(score_line.find("[") + 1) : score_line.find("]")]
.split("/")[0]
.strip()
)
except (IndexError, ValueError):
LOG.e("cannot parse rspamd score")
cls._store_in_message(spamd_result, msg) cls._store_in_message(spamd_result, msg)
return spamd_result return spamd_result

View file

@ -42,11 +42,9 @@ class UnsubscribeLink:
class UnsubscribeEncoder: class UnsubscribeEncoder:
@staticmethod @staticmethod
def encode( def encode(
action: UnsubscribeAction, action: UnsubscribeAction, data: Union[int, UnsubscribeOriginalData]
data: Union[int, UnsubscribeOriginalData],
force_web: bool = False,
) -> UnsubscribeLink: ) -> UnsubscribeLink:
if config.UNSUBSCRIBER and not force_web: if config.UNSUBSCRIBER:
return UnsubscribeLink(UnsubscribeEncoder.encode_mailto(action, data), True) return UnsubscribeLink(UnsubscribeEncoder.encode_mailto(action, data), True)
return UnsubscribeLink(UnsubscribeEncoder.encode_url(action, data), False) return UnsubscribeLink(UnsubscribeEncoder.encode_url(action, data), False)
@ -54,8 +52,9 @@ class UnsubscribeEncoder:
def encode_subject( def encode_subject(
cls, action: UnsubscribeAction, data: Union[int, UnsubscribeOriginalData] cls, action: UnsubscribeAction, data: Union[int, UnsubscribeOriginalData]
) -> str: ) -> str:
if action != UnsubscribeAction.OriginalUnsubscribeMailto and not isinstance( if (
data, int action != UnsubscribeAction.OriginalUnsubscribeMailto
and type(data) is not int
): ):
raise ValueError(f"Data has to be an int for an action of type {action}") raise ValueError(f"Data has to be an int for an action of type {action}")
if action == UnsubscribeAction.OriginalUnsubscribeMailto: if action == UnsubscribeAction.OriginalUnsubscribeMailto:
@ -73,8 +72,8 @@ class UnsubscribeEncoder:
) )
signed_data = cls._get_signer().sign(serialized_data).decode("utf-8") signed_data = cls._get_signer().sign(serialized_data).decode("utf-8")
encoded_request = f"{UNSUB_PREFIX}.{signed_data}" encoded_request = f"{UNSUB_PREFIX}.{signed_data}"
if len(encoded_request) > 512: if len(encoded_request) > 256:
LOG.w("Encoded request is longer than 512 chars") LOG.e("Encoded request is longer than 256 chars")
return encoded_request return encoded_request
@staticmethod @staticmethod

View file

@ -1,5 +1,4 @@
import urllib import urllib
from email.header import Header
from email.message import Message from email.message import Message
from app.email import headers from app.email import headers
@ -10,7 +9,6 @@ from app.handler.unsubscribe_encoder import (
UnsubscribeData, UnsubscribeData,
UnsubscribeOriginalData, UnsubscribeOriginalData,
) )
from app.log import LOG
from app.models import Alias, Contact, UnsubscribeBehaviourEnum from app.models import Alias, Contact, UnsubscribeBehaviourEnum
@ -32,10 +30,7 @@ class UnsubscribeGenerator:
""" """
unsubscribe_data = message[headers.LIST_UNSUBSCRIBE] unsubscribe_data = message[headers.LIST_UNSUBSCRIBE]
if not unsubscribe_data: if not unsubscribe_data:
LOG.info("Email has no unsubscribe header")
return message return message
if isinstance(unsubscribe_data, Header):
unsubscribe_data = str(unsubscribe_data.encode())
raw_methods = [method.strip() for method in unsubscribe_data.split(",")] raw_methods = [method.strip() for method in unsubscribe_data.split(",")]
mailto_unsubs = None mailto_unsubs = None
other_unsubs = [] other_unsubs = []
@ -49,9 +44,7 @@ class UnsubscribeGenerator:
if url_data.scheme == "mailto": if url_data.scheme == "mailto":
query_data = urllib.parse.parse_qs(url_data.query) query_data = urllib.parse.parse_qs(url_data.query)
mailto_unsubs = (url_data.path, query_data.get("subject", [""])[0]) mailto_unsubs = (url_data.path, query_data.get("subject", [""])[0])
LOG.debug(f"Unsub is mailto to {mailto_unsubs}")
else: else:
LOG.debug(f"Unsub has {url_data.scheme} scheme")
other_unsubs.append(method) other_unsubs.append(method)
# If there are non mailto unsubscribe methods, use those in the header # If there are non mailto unsubscribe methods, use those in the header
if other_unsubs: if other_unsubs:
@ -63,19 +56,18 @@ class UnsubscribeGenerator:
add_or_replace_header( add_or_replace_header(
message, headers.LIST_UNSUBSCRIBE_POST, "List-Unsubscribe=One-Click" message, headers.LIST_UNSUBSCRIBE_POST, "List-Unsubscribe=One-Click"
) )
LOG.debug(f"Adding click unsub methods to header {other_unsubs}")
return message return message
elif not mailto_unsubs: if not mailto_unsubs:
LOG.debug("No unsubs. Deleting all unsub headers") message = delete_header(message, headers.LIST_UNSUBSCRIBE)
delete_header(message, headers.LIST_UNSUBSCRIBE) message = delete_header(message, headers.LIST_UNSUBSCRIBE_POST)
delete_header(message, headers.LIST_UNSUBSCRIBE_POST)
return message return message
unsub_data = UnsubscribeData( return self._add_unsubscribe_header(
UnsubscribeAction.OriginalUnsubscribeMailto, message,
UnsubscribeOriginalData(alias.id, mailto_unsubs[0], mailto_unsubs[1]), UnsubscribeData(
UnsubscribeAction.OriginalUnsubscribeMailto,
UnsubscribeOriginalData(alias.id, mailto_unsubs[0], mailto_unsubs[1]),
),
) )
LOG.debug(f"Adding unsub data {unsub_data}")
return self._add_unsubscribe_header(message, unsub_data)
def _add_unsubscribe_header( def _add_unsubscribe_header(
self, message: Message, unsub: UnsubscribeData self, message: Message, unsub: UnsubscribeData

View file

@ -49,7 +49,7 @@ class UnsubscribeHandler:
return status.E507 return status.E507
mailbox = Mailbox.get_by(email=envelope.mail_from) mailbox = Mailbox.get_by(email=envelope.mail_from)
if not mailbox: if not mailbox:
LOG.w("Unknown mailbox %s", envelope.mail_from) LOG.w("Unknown mailbox %s", msg[headers.SUBJECT])
return status.E507 return status.E507
if unsub_data.action == UnsubscribeAction.DisableAlias: if unsub_data.action == UnsubscribeAction.DisableAlias:

View file

@ -15,7 +15,7 @@ from app.models import (
Mailbox, Mailbox,
User, User,
) )
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
from .log import LOG from .log import LOG
@ -30,7 +30,7 @@ def handle_batch_import(batch_import: BatchImport):
LOG.d("Download file %s from %s", batch_import.file, file_url) LOG.d("Download file %s from %s", batch_import.file, file_url)
r = requests.get(file_url) r = requests.get(file_url)
lines = [line.decode("utf-8") for line in r.iter_lines()] lines = [line.decode() for line in r.iter_lines()]
import_from_csv(batch_import, user, lines) import_from_csv(batch_import, user, lines)
@ -69,7 +69,7 @@ def import_from_csv(batch_import: BatchImport, user: User, lines):
if "mailboxes" in row: if "mailboxes" in row:
for mailbox_email in row["mailboxes"].split(): for mailbox_email in row["mailboxes"].split():
mailbox_email = canonicalize_email(mailbox_email) mailbox_email = sanitize_email(mailbox_email)
mailbox = Mailbox.get_by(email=mailbox_email) mailbox = Mailbox.get_by(email=mailbox_email)
if not mailbox or not mailbox.verified or mailbox.user_id != user.id: if not mailbox or not mailbox.verified or mailbox.user_id != user.id:

View file

@ -1,4 +1,2 @@
from .integrations import set_enable_proton_cookie from .integrations import set_enable_proton_cookie
from .exit_sudo import exit_sudo_mode from .exit_sudo import exit_sudo_mode
__all__ = ["set_enable_proton_cookie", "exit_sudo_mode"]

View file

@ -14,12 +14,7 @@ import sqlalchemy
from app import config from app import config
from app.db import Session from app.db import Session
from app.email import headers from app.email import headers
from app.email_utils import ( from app.email_utils import generate_verp_email, render, add_dkim_signature
generate_verp_email,
render,
add_dkim_signature,
get_email_domain_part,
)
from app.mail_sender import sl_sendmail from app.mail_sender import sl_sendmail
from app.models import ( from app.models import (
Alias, Alias,
@ -39,8 +34,9 @@ from app.models import (
class ExportUserDataJob: class ExportUserDataJob:
REMOVE_FIELDS = { REMOVE_FIELDS = {
"User": ("otp_secret", "password"), "User": ("otp_secret",),
"Alias": ("ts_vector", "transfer_token", "hibp_last_check"), "Alias": ("ts_vector", "transfer_token", "hibp_last_check"),
"CustomDomain": ("ownership_txt_token",), "CustomDomain": ("ownership_txt_token",),
} }
@ -151,11 +147,7 @@ class ExportUserDataJob:
transaction = TransactionalEmail.create(email=to_email, commit=True) transaction = TransactionalEmail.create(email=to_email, commit=True)
sl_sendmail( sl_sendmail(
generate_verp_email( generate_verp_email(VerpType.transactional, transaction.id),
VerpType.transactional,
transaction.id,
get_email_domain_part(config.NOREPLY),
),
to_email, to_email,
msg, msg,
ignore_smtp_error=False, ignore_smtp_error=False,

View file

@ -6,8 +6,8 @@ import os
import time import time
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from email.message import Message
from functools import wraps from functools import wraps
from mailbox import Message
from smtplib import SMTP, SMTPException from smtplib import SMTP, SMTPException
from typing import Optional, Dict, List, Callable from typing import Optional, Dict, List, Callable
@ -17,13 +17,11 @@ from attr import dataclass
from app import config from app import config
from app.email import headers from app.email import headers
from app.log import LOG from app.log import LOG
from app.message_utils import message_to_bytes, message_format_base64_parts from app.message_utils import message_to_bytes
@dataclass @dataclass
class SendRequest: class SendRequest:
SAVE_EXTENSION = "sendrequest"
envelope_from: str envelope_from: str
envelope_to: str envelope_to: str
msg: Message msg: Message
@ -31,7 +29,6 @@ class SendRequest:
rcpt_options: Dict = {} rcpt_options: Dict = {}
is_forward: bool = False is_forward: bool = False
ignore_smtp_errors: bool = False ignore_smtp_errors: bool = False
retries: int = 0
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
if not config.SAVE_UNSENT_DIR: if not config.SAVE_UNSENT_DIR:
@ -45,7 +42,6 @@ class SendRequest:
"mail_options": self.mail_options, "mail_options": self.mail_options,
"rcpt_options": self.rcpt_options, "rcpt_options": self.rcpt_options,
"is_forward": self.is_forward, "is_forward": self.is_forward,
"retries": self.retries,
} }
return json.dumps(data).encode("utf-8") return json.dumps(data).encode("utf-8")
@ -66,33 +62,8 @@ class SendRequest:
mail_options=decoded_data["mail_options"], mail_options=decoded_data["mail_options"],
rcpt_options=decoded_data["rcpt_options"], rcpt_options=decoded_data["rcpt_options"],
is_forward=decoded_data["is_forward"], is_forward=decoded_data["is_forward"],
retries=decoded_data.get("retries", 1),
) )
def save_request_to_unsent_dir(self, prefix: str = "DeliveryFail"):
file_name = (
f"{prefix}-{int(time.time())}-{uuid.uuid4()}.{SendRequest.SAVE_EXTENSION}"
)
file_path = os.path.join(config.SAVE_UNSENT_DIR, file_name)
self.save_request_to_file(file_path)
@staticmethod
def save_request_to_failed_dir(self, prefix: str = "DeliveryRetryFail"):
file_name = (
f"{prefix}-{int(time.time())}-{uuid.uuid4()}.{SendRequest.SAVE_EXTENSION}"
)
dir_name = os.path.join(config.SAVE_UNSENT_DIR, "failed")
if not os.path.isdir(dir_name):
os.makedirs(dir_name)
file_path = os.path.join(dir_name, file_name)
self.save_request_to_file(file_path)
def save_request_to_file(self, file_path: str):
file_contents = self.to_bytes()
with open(file_path, "wb") as fd:
fd.write(file_contents)
LOG.i(f"Saved unsent message {file_path}")
class MailSender: class MailSender:
def __init__(self): def __init__(self):
@ -124,7 +95,7 @@ class MailSender:
def enable_background_pool(self, max_workers=10): def enable_background_pool(self, max_workers=10):
self._pool = ThreadPoolExecutor(max_workers=max_workers) self._pool = ThreadPoolExecutor(max_workers=max_workers)
def send(self, send_request: SendRequest, retries: int = 2) -> bool: def send(self, send_request: SendRequest, retries: int = 2):
"""replace smtp.sendmail""" """replace smtp.sendmail"""
if self._store_emails: if self._store_emails:
self._emails_sent.append(send_request) self._emails_sent.append(send_request)
@ -135,21 +106,21 @@ class MailSender:
send_request.msg[headers.FROM], send_request.msg[headers.FROM],
send_request.msg[headers.TO], send_request.msg[headers.TO],
) )
return True return
if not self._pool: if not self._pool:
return self._send_to_smtp(send_request, retries) self._send_to_smtp(send_request, retries)
else: else:
self._pool.submit(self._send_to_smtp, (send_request, retries)) self._pool.submit(self._send_to_smtp, (send_request, retries))
return True
def _send_to_smtp(self, send_request: SendRequest, retries: int) -> bool: def _send_to_smtp(self, send_request: SendRequest, retries: int):
try: try:
start = time.time() start = time.time()
with SMTP( if config.POSTFIX_SUBMISSION_TLS:
config.POSTFIX_SERVER, smtp_port = 587
config.POSTFIX_PORT, else:
timeout=config.POSTFIX_TIMEOUT, smtp_port = config.POSTFIX_PORT
) as smtp:
with SMTP(config.POSTFIX_SERVER, smtp_port) as smtp:
if config.POSTFIX_SUBMISSION_TLS: if config.POSTFIX_SUBMISSION_TLS:
smtp.starttls() smtp.starttls()
@ -180,94 +151,35 @@ class MailSender:
newrelic.agent.record_custom_metric( newrelic.agent.record_custom_metric(
"Custom/smtp_sending_time", time.time() - start "Custom/smtp_sending_time", time.time() - start
) )
return True
except ( except (
SMTPException, SMTPException,
ConnectionRefusedError, ConnectionRefusedError,
TimeoutError, TimeoutError,
) as e: ) as e:
if retries > 0: if retries > 0:
time.sleep(0.3 * retries) time.sleep(0.3 * send_request.retries)
return self._send_to_smtp(send_request, retries - 1) self._send_to_smtp(send_request, retries - 1)
else: else:
if send_request.ignore_smtp_errors: if send_request.ignore_smtp_errors:
LOG.e(f"Ignore smtp error {e}") LOG.e(f"Ignore smtp error {e}")
return False return
LOG.e( LOG.e(
f"Could not send message to smtp server {config.POSTFIX_SERVER}:{config.POSTFIX_PORT}" f"Could not send message to smtp server {config.POSTFIX_SERVER}:{smtp_port}"
) )
if config.SAVE_UNSENT_DIR: self._save_request_to_unsent_dir(send_request)
send_request.save_request_to_unsent_dir()
return False def _save_request_to_unsent_dir(self, send_request: SendRequest):
file_name = f"DeliveryFail-{int(time.time())}-{uuid.uuid4()}.eml"
file_path = os.path.join(config.SAVE_UNSENT_DIR, file_name)
file_contents = send_request.to_bytes()
with open(file_path, "wb") as fd:
fd.write(file_contents)
LOG.i(f"Saved unsent message {file_path}")
mail_sender = MailSender() mail_sender = MailSender()
def save_request_to_failed_dir(exception_name: str, send_request: SendRequest):
file_name = f"{exception_name}-{int(time.time())}-{uuid.uuid4()}.{SendRequest.SAVE_EXTENSION}"
failed_file_dir = os.path.join(config.SAVE_UNSENT_DIR, "failed")
try:
os.makedirs(failed_file_dir)
except FileExistsError:
pass
file_path = os.path.join(failed_file_dir, file_name)
file_contents = send_request.to_bytes()
with open(file_path, "wb") as fd:
fd.write(file_contents)
return file_path
def load_unsent_mails_from_fs_and_resend():
if not config.SAVE_UNSENT_DIR:
return
for filename in os.listdir(config.SAVE_UNSENT_DIR):
(_, extension) = os.path.splitext(filename)
if extension[1:] != SendRequest.SAVE_EXTENSION:
LOG.i(f"Skipping {filename} does not have the proper extension")
continue
full_file_path = os.path.join(config.SAVE_UNSENT_DIR, filename)
if not os.path.isfile(full_file_path):
LOG.i(f"Skipping {filename} as it's not a file")
continue
LOG.i(f"Trying to re-deliver email {filename}")
try:
send_request = SendRequest.load_from_file(full_file_path)
send_request.retries += 1
except Exception as e:
LOG.e(f"Cannot load {filename}. Error {e}")
continue
try:
send_request.ignore_smtp_errors = True
if mail_sender.send(send_request, 2):
os.unlink(full_file_path)
newrelic.agent.record_custom_event(
"DeliverUnsentEmail", {"delivered": "true"}
)
else:
if send_request.retries > 2:
os.unlink(full_file_path)
send_request.save_request_to_failed_dir()
else:
send_request.save_request_to_file(full_file_path)
newrelic.agent.record_custom_event(
"DeliverUnsentEmail", {"delivered": "false"}
)
except Exception as e:
# Unlink original file to avoid re-doing the same
os.unlink(full_file_path)
LOG.e(
"email sending failed with error:%s "
"envelope %s -> %s, mail %s -> %s saved to %s",
e,
send_request.envelope_from,
send_request.envelope_to,
send_request.msg[headers.FROM],
send_request.msg[headers.TO],
save_request_to_failed_dir(e.__class__.__name__, send_request),
)
def sl_sendmail( def sl_sendmail(
envelope_from: str, envelope_from: str,
envelope_to: str, envelope_to: str,
@ -281,7 +193,7 @@ def sl_sendmail(
send_request = SendRequest( send_request = SendRequest(
envelope_from, envelope_from,
envelope_to, envelope_to,
message_format_base64_parts(msg), msg,
mail_options, mail_options,
rcpt_options, rcpt_options,
is_forward, is_forward,

View file

@ -1,42 +1,21 @@
import re
from email import policy from email import policy
from email.message import Message from email.message import Message
from app.email import headers
from app.log import LOG from app.log import LOG
# Spam assassin might flag as spam with a different line length
BASE64_LINELENGTH = 76
def message_to_bytes(msg: Message) -> bytes: def message_to_bytes(msg: Message) -> bytes:
"""replace Message.as_bytes() method by trying different policies""" """replace Message.as_bytes() method by trying different policies"""
for generator_policy in [None, policy.SMTP, policy.SMTPUTF8]: for generator_policy in [None, policy.SMTP, policy.SMTPUTF8]:
try: try:
return msg.as_bytes(policy=generator_policy) return msg.as_bytes(policy=generator_policy)
except Exception: except:
LOG.w("as_bytes() fails with %s policy", policy, exc_info=True) LOG.w("as_bytes() fails with %s policy", policy, exc_info=True)
msg_string = msg.as_string() msg_string = msg.as_string()
try: try:
return msg_string.encode() return msg_string.encode()
except Exception: except:
LOG.w("as_string().encode() fails", exc_info=True) LOG.w("as_string().encode() fails", exc_info=True)
return msg_string.encode(errors="replace") return msg_string.encode(errors="replace")
def message_format_base64_parts(msg: Message) -> Message:
for part in msg.walk():
if part.get(
headers.CONTENT_TRANSFER_ENCODING
) == "base64" and part.get_content_type() in ("text/plain", "text/html"):
# Remove line breaks
body = re.sub("[\r\n]", "", part.get_payload())
# Split in 80 column lines
chunks = [
body[i : i + BASE64_LINELENGTH]
for i in range(0, len(body), BASE64_LINELENGTH)
]
part.set_payload("\r\n".join(chunks))
return msg

View file

@ -1,14 +1,13 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
import dataclasses
import enum import enum
import hashlib import hashlib
import hmac import hmac
import os import os
import random import random
import secrets
import uuid import uuid
from email.utils import formataddr
from typing import List, Tuple, Optional, Union from typing import List, Tuple, Optional, Union
import arrow import arrow
@ -19,7 +18,7 @@ from flanker.addresslib import address
from flask import url_for from flask import url_for
from flask_login import UserMixin from flask_login import UserMixin
from jinja2 import FileSystemLoader, Environment from jinja2 import FileSystemLoader, Environment
from sqlalchemy import orm, or_ from sqlalchemy import orm
from sqlalchemy import text, desc, CheckConstraint, Index, Column from sqlalchemy import text, desc, CheckConstraint, Index, Column
from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -27,11 +26,9 @@ from sqlalchemy.orm import deferred
from sqlalchemy.sql import and_ from sqlalchemy.sql import and_
from sqlalchemy_utils import ArrowType from sqlalchemy_utils import ArrowType
from app import config
from app import s3 from app import s3
from app import config
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains
from app.errors import ( from app.errors import (
AliasInTrashError, AliasInTrashError,
DirectoryInTrashError, DirectoryInTrashError,
@ -47,6 +44,7 @@ from app.utils import (
random_string, random_string,
random_words, random_words,
sanitize_email, sanitize_email,
random_word,
) )
Base = declarative_base() Base = declarative_base()
@ -233,8 +231,6 @@ class AuditLogActionEnum(EnumE):
logged_as_user = 6 logged_as_user = 6
extend_subscription = 7 extend_subscription = 7
download_provider_complaint = 8 download_provider_complaint = 8
disable_user = 9
enable_user = 10
class Phase(EnumE): class Phase(EnumE):
@ -276,13 +272,6 @@ class IntEnumType(sa.types.TypeDecorator):
return self._enum_type(enum_value) return self._enum_type(enum_value)
@dataclasses.dataclass
class AliasOptions:
show_sl_domains: bool = True
show_partner_domains: Optional[Partner] = None
show_partner_premium: Optional[bool] = None
class Hibp(Base, ModelMixin): class Hibp(Base, ModelMixin):
__tablename__ = "hibp" __tablename__ = "hibp"
name = sa.Column(sa.String(), nullable=False, unique=True, index=True) name = sa.Column(sa.String(), nullable=False, unique=True, index=True)
@ -301,9 +290,7 @@ class HibpNotifiedAlias(Base, ModelMixin):
""" """
__tablename__ = "hibp_notified_alias" __tablename__ = "hibp_notified_alias"
alias_id = sa.Column( alias_id = sa.Column(sa.ForeignKey("alias.id", ondelete="cascade"), nullable=False)
sa.ForeignKey("alias.id", ondelete="cascade"), nullable=False, index=True
)
user_id = sa.Column(sa.ForeignKey("users.id", ondelete="cascade"), nullable=False) user_id = sa.Column(sa.ForeignKey("users.id", ondelete="cascade"), nullable=False)
notified_at = sa.Column(ArrowType, default=arrow.utcnow, nullable=False) notified_at = sa.Column(ArrowType, default=arrow.utcnow, nullable=False)
@ -344,7 +331,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
sa.Boolean, default=True, nullable=False, server_default="1" sa.Boolean, default=True, nullable=False, server_default="1"
) )
activated = sa.Column(sa.Boolean, default=False, nullable=False, index=True) activated = sa.Column(sa.Boolean, default=False, nullable=False)
# an account can be disabled if having harmful behavior # an account can be disabled if having harmful behavior
disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0")
@ -414,10 +401,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
) )
referral_id = sa.Column( referral_id = sa.Column(
sa.ForeignKey("referral.id", ondelete="SET NULL"), sa.ForeignKey("referral.id", ondelete="SET NULL"), nullable=True, default=None
nullable=True,
default=None,
index=True,
) )
referral = orm.relationship("Referral", foreign_keys=[referral_id]) referral = orm.relationship("Referral", foreign_keys=[referral_id])
@ -434,15 +418,12 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
# newsletter is sent to this address # newsletter is sent to this address
newsletter_alias_id = sa.Column( newsletter_alias_id = sa.Column(
sa.ForeignKey("alias.id", ondelete="SET NULL"), sa.ForeignKey("alias.id", ondelete="SET NULL"), nullable=True, default=None
nullable=True,
default=None,
index=True,
) )
# whether to include the sender address in reverse-alias # whether to include the sender address in reverse-alias
include_sender_in_reverse_alias = sa.Column( include_sender_in_reverse_alias = sa.Column(
sa.Boolean, default=True, nullable=False, server_default="0" sa.Boolean, default=False, nullable=False, server_default="0"
) )
# whether to use random string or random word as suffix # whether to use random string or random word as suffix
@ -451,7 +432,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
random_alias_suffix = sa.Column( random_alias_suffix = sa.Column(
sa.Integer, sa.Integer,
nullable=False, nullable=False,
default=AliasSuffixEnum.word.value, default=AliasSuffixEnum.random_string.value,
server_default=str(AliasSuffixEnum.random_string.value), server_default=str(AliasSuffixEnum.random_string.value),
) )
@ -520,8 +501,9 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
server_default=BlockBehaviourEnum.return_2xx.name, server_default=BlockBehaviourEnum.return_2xx.name,
) )
# to keep existing behavior, the server default is TRUE whereas for new user, the default value is FALSE
include_header_email_header = sa.Column( include_header_email_header = sa.Column(
sa.Boolean, default=True, nullable=False, server_default="1" sa.Boolean, default=False, nullable=False, server_default="1"
) )
# bitwise flags. Allow for future expansion # bitwise flags. Allow for future expansion
@ -535,21 +517,11 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
# Keep original unsub behaviour # Keep original unsub behaviour
unsub_behaviour = sa.Column( unsub_behaviour = sa.Column(
IntEnumType(UnsubscribeBehaviourEnum), IntEnumType(UnsubscribeBehaviourEnum),
default=UnsubscribeBehaviourEnum.PreserveOriginal, default=UnsubscribeBehaviourEnum.DisableAlias,
server_default=str(UnsubscribeBehaviourEnum.DisableAlias.value), server_default=str(UnsubscribeBehaviourEnum.DisableAlias.value),
nullable=False, nullable=False,
) )
# Trigger hard deletion of the account at this time
delete_on = sa.Column(ArrowType, default=None)
__table_args__ = (
sa.Index(
"ix_users_activated_trial_end_lifetime", activated, trial_end, lifetime
),
sa.Index("ix_users_delete_on", delete_on),
)
@property @property
def directory_quota(self): def directory_quota(self):
return min( return min(
@ -584,8 +556,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
@classmethod @classmethod
def create(cls, email, name="", password=None, from_partner=False, **kwargs): def create(cls, email, name="", password=None, from_partner=False, **kwargs):
email = sanitize_email(email) user: User = super(User, cls).create(email=email, name=name, **kwargs)
user: User = super(User, cls).create(email=email, name=name[:100], **kwargs)
if password: if password:
user.set_password(password) user.set_password(password)
@ -596,6 +567,19 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
Session.flush() Session.flush()
user.default_mailbox_id = mb.id user.default_mailbox_id = mb.id
# create a first alias mail to show user how to use when they login
alias = Alias.create_new(
user,
prefix="simplelogin-newsletter",
mailbox_id=mb.id,
note="This is your first alias. It's used to receive SimpleLogin communications "
"like new features announcements, newsletters.",
)
Session.flush()
user.newsletter_alias_id = alias.id
Session.flush()
# generate an alternative_id if needed # generate an alternative_id if needed
if "alternative_id" not in kwargs: if "alternative_id" not in kwargs:
user.alternative_id = str(uuid.uuid4()) user.alternative_id = str(uuid.uuid4())
@ -614,19 +598,6 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
Session.flush() Session.flush()
return user return user
# create a first alias mail to show user how to use when they login
alias = Alias.create_new(
user,
prefix="simplelogin-newsletter",
mailbox_id=mb.id,
note="This is your first alias. It's used to receive SimpleLogin communications "
"like new features announcements, newsletters.",
)
Session.flush()
user.newsletter_alias_id = alias.id
Session.flush()
if config.DISABLE_ONBOARDING: if config.DISABLE_ONBOARDING:
LOG.d("Disable onboarding emails") LOG.d("Disable onboarding emails")
return user return user
@ -652,7 +623,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
return user return user
def get_active_subscription( def get_active_subscription(
self, include_partner_subscription: bool = True self,
) -> Optional[ ) -> Optional[
Union[ Union[
Subscription Subscription
@ -680,40 +651,19 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
if coinbase_subscription and coinbase_subscription.is_active(): if coinbase_subscription and coinbase_subscription.is_active():
return coinbase_subscription return coinbase_subscription
if include_partner_subscription: partner_sub: PartnerSubscription = PartnerSubscription.find_by_user_id(self.id)
partner_sub: PartnerSubscription = PartnerSubscription.find_by_user_id( if partner_sub and partner_sub.is_active():
self.id return partner_sub
)
if partner_sub and partner_sub.is_active():
return partner_sub
return None return None
def get_active_subscription_end(
self, include_partner_subscription: bool = True
) -> Optional[arrow.Arrow]:
sub = self.get_active_subscription(
include_partner_subscription=include_partner_subscription
)
if isinstance(sub, Subscription):
return arrow.get(sub.next_bill_date)
if isinstance(sub, AppleSubscription):
return sub.expires_date
if isinstance(sub, ManualSubscription):
return sub.end_at
if isinstance(sub, CoinbaseSubscription):
return sub.end_at
return None
# region Billing # region Billing
def lifetime_or_active_subscription( def lifetime_or_active_subscription(self) -> bool:
self, include_partner_subscription: bool = True
) -> bool:
"""True if user has lifetime licence or active subscription""" """True if user has lifetime licence or active subscription"""
if self.lifetime: if self.lifetime:
return True return True
return self.get_active_subscription(include_partner_subscription) is not None return self.get_active_subscription() is not None
def is_paid(self) -> bool: def is_paid(self) -> bool:
"""same as _lifetime_or_active_subscription but not include free manual subscription""" """same as _lifetime_or_active_subscription but not include free manual subscription"""
@ -742,14 +692,14 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
return True return True
def is_premium(self, include_partner_subscription: bool = True) -> bool: def is_premium(self) -> bool:
""" """
user is premium if they: user is premium if they:
- have a lifetime deal or - have a lifetime deal or
- in trial period or - in trial period or
- active subscription - active subscription
""" """
if self.lifetime_or_active_subscription(include_partner_subscription): if self.lifetime_or_active_subscription():
return True return True
if self.trial_end and arrow.now() < self.trial_end: if self.trial_end and arrow.now() < self.trial_end:
@ -769,11 +719,11 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
if sub: if sub:
if sub.cancelled: if sub.cancelled:
channels.append( channels.append(
f"""Cancelled Paddle Subscription <a href="https://vendors.paddle.com/subscriptions/customers/manage/{sub.subscription_id}">{sub.subscription_id}</a> {sub.plan_name()} ends at {sub.next_bill_date}""" f"Cancelled Paddle Subscription {sub.subscription_id} {sub.plan_name()} ends at {sub.next_bill_date}"
) )
else: else:
channels.append( channels.append(
f"""Active Paddle Subscription <a href="https://vendors.paddle.com/subscriptions/customers/manage/{sub.subscription_id}">{sub.subscription_id}</a> {sub.plan_name()}, renews at {sub.next_bill_date}""" f"Active Paddle Subscription {sub.subscription_id} {sub.plan_name()}, renews at {sub.next_bill_date}"
) )
apple_sub: AppleSubscription = AppleSubscription.get_by(user_id=self.id) apple_sub: AppleSubscription = AppleSubscription.get_by(user_id=self.id)
@ -838,17 +788,6 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
< self.max_alias_for_free_account() < self.max_alias_for_free_account()
) )
def can_send_or_receive(self) -> bool:
if self.disabled:
LOG.i(f"User {self} is disabled. Cannot receive or send emails")
return False
if self.delete_on is not None:
LOG.i(
f"User {self} is scheduled to be deleted. Cannot receive or send emails"
)
return False
return True
def profile_picture_url(self): def profile_picture_url(self):
if self.profile_picture_id: if self.profile_picture_id:
return self.profile_picture.get_url() return self.profile_picture.get_url()
@ -927,16 +866,14 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def custom_domains(self): def custom_domains(self):
return CustomDomain.filter_by(user_id=self.id, verified=True).all() return CustomDomain.filter_by(user_id=self.id, verified=True).all()
def available_domains_for_random_alias( def available_domains_for_random_alias(self) -> List[Tuple[bool, str]]:
self, alias_options: Optional[AliasOptions] = None
) -> List[Tuple[bool, str]]:
"""Return available domains for user to create random aliases """Return available domains for user to create random aliases
Each result record contains: Each result record contains:
- whether the domain belongs to SimpleLogin - whether the domain belongs to SimpleLogin
- the domain - the domain
""" """
res = [] res = []
for domain in self.available_sl_domains(alias_options=alias_options): for domain in self.available_sl_domains():
res.append((True, domain)) res.append((True, domain))
for custom_domain in self.verified_custom_domains(): for custom_domain in self.verified_custom_domains():
@ -1006,7 +943,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
return alias.email, unsub.link, unsub.via_email return alias.email, unsub.link, unsub.via_email
# alias disabled -> user doesn't want to receive newsletter # alias disabled -> user doesn't want to receive newsletter
else: else:
return None, "", False return None, None, False
else: else:
# do not handle http POST unsubscribe # do not handle http POST unsubscribe
if config.UNSUBSCRIBER: if config.UNSUBSCRIBER:
@ -1019,67 +956,32 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
True, True,
) )
return None, "", False return None, None, False
def available_sl_domains( def available_sl_domains(self) -> [str]:
self, alias_options: Optional[AliasOptions] = None
) -> [str]:
""" """
Return all SimpleLogin domains that user can use when creating a new alias, including: Return all SimpleLogin domains that user can use when creating a new alias, including:
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN) - SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN) - SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
""" """
return [ return [sl_domain.domain for sl_domain in self.get_sl_domains()]
sl_domain.domain
for sl_domain in self.get_sl_domains(alias_options=alias_options)
]
def get_sl_domains( def get_sl_domains(self) -> List["SLDomain"]:
self, alias_options: Optional[AliasOptions] = None query = SLDomain.filter_by(hidden=False).order_by(SLDomain.order)
) -> list["SLDomain"]:
if alias_options is None:
alias_options = AliasOptions()
top_conds = [SLDomain.hidden == False] # noqa: E712
or_conds = [] # noqa:E711
if self.default_alias_public_domain_id is not None:
default_domain_conds = [SLDomain.id == self.default_alias_public_domain_id]
if not self.is_premium():
default_domain_conds.append(
SLDomain.premium_only == False # noqa: E712
)
or_conds.append(and_(*default_domain_conds).self_group())
if alias_options.show_partner_domains is not None:
partner_user = PartnerUser.filter_by(
user_id=self.id, partner_id=alias_options.show_partner_domains.id
).first()
if partner_user is not None:
partner_domain_cond = [SLDomain.partner_id == partner_user.partner_id]
if alias_options.show_partner_premium is None:
alias_options.show_partner_premium = self.is_premium()
if not alias_options.show_partner_premium:
partner_domain_cond.append(
SLDomain.premium_only == False # noqa: E712
)
or_conds.append(and_(*partner_domain_cond).self_group())
if alias_options.show_sl_domains:
sl_conds = [SLDomain.partner_id == None] # noqa: E711
if not self.is_premium():
sl_conds.append(SLDomain.premium_only == False) # noqa: E712
or_conds.append(and_(*sl_conds).self_group())
top_conds.append(or_(*or_conds))
query = Session.query(SLDomain).filter(*top_conds).order_by(SLDomain.order)
return query.all()
def available_alias_domains( if self.is_premium():
self, alias_options: Optional[AliasOptions] = None return query.all()
) -> [str]: else:
return query.filter_by(premium_only=False).all()
def available_alias_domains(self) -> [str]:
"""return all domains that user can use when creating a new alias, including: """return all domains that user can use when creating a new alias, including:
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN) - SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN) - SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
- Verified custom domains - Verified custom domains
""" """
domains = self.available_sl_domains(alias_options=alias_options) domains = self.available_sl_domains()
for custom_domain in self.verified_custom_domains(): for custom_domain in self.verified_custom_domains():
domains.append(custom_domain.domain) domains.append(custom_domain.domain)
@ -1097,21 +999,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
> 0 > 0
) )
def get_random_alias_suffix(self, custom_domain: Optional["CustomDomain"] = None): def get_random_alias_suffix(self):
"""Get random suffix for an alias based on user's preference. """Get random suffix for an alias based on user's preference.
Use a shorter suffix in case of custom domain
Returns: Returns:
str: the random suffix generated str: the random suffix generated
""" """
if self.random_alias_suffix == AliasSuffixEnum.random_string.value: if self.random_alias_suffix == AliasSuffixEnum.random_string.value:
return random_string(config.ALIAS_RANDOM_SUFFIX_LENGTH, include_digits=True) return random_string(config.ALIAS_RANDOM_SUFFIX_LENGTH, include_digits=True)
return random_word()
if custom_domain is None:
return random_words(1, 3)
return random_words(1)
def __repr__(self): def __repr__(self):
return f"<User {self.id} {self.name} {self.email}>" return f"<User {self.id} {self.name} {self.email}>"
@ -1356,48 +1253,34 @@ class OauthToken(Base, ModelMixin):
return self.expired < arrow.now() return self.expired < arrow.now()
def available_sl_email(email: str) -> bool: def generate_email(
if (
Alias.get_by(email=email)
or Contact.get_by(reply_email=email)
or DeletedAlias.get_by(email=email)
):
return False
return True
def generate_random_alias_email(
scheme: int = AliasGeneratorEnum.word.value, scheme: int = AliasGeneratorEnum.word.value,
in_hex: bool = False, in_hex: bool = False,
alias_domain: str = config.FIRST_ALIAS_DOMAIN, alias_domain=config.FIRST_ALIAS_DOMAIN,
retries: int = 10,
) -> str: ) -> str:
"""generate an email address that does not exist before """generate an email address that does not exist before
:param alias_domain: the domain used to generate the alias. :param alias_domain: the domain used to generate the alias.
:param scheme: int, value of AliasGeneratorEnum, indicate how the email is generated :param scheme: int, value of AliasGeneratorEnum, indicate how the email is generated
:param retries: int, How many times we can try to generate an alias in case of collision
:type in_hex: bool, if the generate scheme is uuid, is hex favorable? :type in_hex: bool, if the generate scheme is uuid, is hex favorable?
""" """
if retries <= 0:
raise Exception("Cannot generate alias after many retries")
if scheme == AliasGeneratorEnum.uuid.value: if scheme == AliasGeneratorEnum.uuid.value:
name = uuid.uuid4().hex if in_hex else uuid.uuid4().__str__() name = uuid.uuid4().hex if in_hex else uuid.uuid4().__str__()
random_email = name + "@" + alias_domain random_email = name + "@" + alias_domain
else: else:
random_email = random_words(2, 3) + "@" + alias_domain random_email = random_words() + "@" + alias_domain
random_email = random_email.lower().strip() random_email = random_email.lower().strip()
# check that the client does not exist yet # check that the client does not exist yet
if available_sl_email(random_email): if not Alias.get_by(email=random_email) and not DeletedAlias.get_by(
email=random_email
):
LOG.d("generate email %s", random_email) LOG.d("generate email %s", random_email)
return random_email return random_email
# Rerun the function # Rerun the function
LOG.w("email %s already exists, generate a new email", random_email) LOG.w("email %s already exists, generate a new email", random_email)
return generate_random_alias_email( return generate_email(scheme=scheme, in_hex=in_hex)
scheme=scheme, in_hex=in_hex, retries=retries - 1
)
class Alias(Base, ModelMixin): class Alias(Base, ModelMixin):
@ -1479,7 +1362,7 @@ class Alias(Base, ModelMixin):
) )
# have I been pwned # have I been pwned
hibp_last_check = sa.Column(ArrowType, default=None, index=True) hibp_last_check = sa.Column(ArrowType, default=None)
hibp_breaches = orm.relationship("Hibp", secondary="alias_hibp") hibp_breaches = orm.relationship("Hibp", secondary="alias_hibp")
# to use Postgres full text search. Only applied on "note" column for now # to use Postgres full text search. Only applied on "note" column for now
@ -1574,7 +1457,6 @@ class Alias(Base, ModelMixin):
new_alias.custom_domain_id = custom_domain.id new_alias.custom_domain_id = custom_domain.id
Session.add(new_alias) Session.add(new_alias)
DailyMetric.get_or_create_today_metric().nb_alias += 1
if commit: if commit:
Session.commit() Session.commit()
@ -1596,7 +1478,7 @@ class Alias(Base, ModelMixin):
suffix = user.get_random_alias_suffix() suffix = user.get_random_alias_suffix()
email = f"{prefix}.{suffix}@{config.FIRST_ALIAS_DOMAIN}" email = f"{prefix}.{suffix}@{config.FIRST_ALIAS_DOMAIN}"
if available_sl_email(email): if not cls.get_by(email=email) and not DeletedAlias.get_by(email=email):
break break
return Alias.create( return Alias.create(
@ -1625,7 +1507,7 @@ class Alias(Base, ModelMixin):
if user.default_alias_custom_domain_id: if user.default_alias_custom_domain_id:
custom_domain = CustomDomain.get(user.default_alias_custom_domain_id) custom_domain = CustomDomain.get(user.default_alias_custom_domain_id)
random_email = generate_random_alias_email( random_email = generate_email(
scheme=scheme, in_hex=in_hex, alias_domain=custom_domain.domain scheme=scheme, in_hex=in_hex, alias_domain=custom_domain.domain
) )
elif user.default_alias_public_domain_id: elif user.default_alias_public_domain_id:
@ -1633,12 +1515,12 @@ class Alias(Base, ModelMixin):
if sl_domain.premium_only and not user.is_premium(): if sl_domain.premium_only and not user.is_premium():
LOG.w("%s not premium, cannot use %s", user, sl_domain) LOG.w("%s not premium, cannot use %s", user, sl_domain)
else: else:
random_email = generate_random_alias_email( random_email = generate_email(
scheme=scheme, in_hex=in_hex, alias_domain=sl_domain.domain scheme=scheme, in_hex=in_hex, alias_domain=sl_domain.domain
) )
if not random_email: if not random_email:
random_email = generate_random_alias_email(scheme=scheme, in_hex=in_hex) random_email = generate_email(scheme=scheme, in_hex=in_hex)
alias = Alias.create( alias = Alias.create(
user_id=user.id, user_id=user.id,
@ -1672,9 +1554,7 @@ class ClientUser(Base, ModelMixin):
client_id = sa.Column(sa.ForeignKey(Client.id, ondelete="cascade"), nullable=False) client_id = sa.Column(sa.ForeignKey(Client.id, ondelete="cascade"), nullable=False)
# Null means client has access to user original email # Null means client has access to user original email
alias_id = sa.Column( alias_id = sa.Column(sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=True)
sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=True, index=True
)
# user can decide to send to client another name # user can decide to send to client another name
name = sa.Column( name = sa.Column(
@ -1758,8 +1638,6 @@ class Contact(Base, ModelMixin):
Store configuration of sender (website-email) and alias. Store configuration of sender (website-email) and alias.
""" """
MAX_NAME_LENGTH = 512
__tablename__ = "contact" __tablename__ = "contact"
__table_args__ = ( __table_args__ = (
@ -1793,7 +1671,7 @@ class Contact(Base, ModelMixin):
is_cc = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") is_cc = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0")
pgp_public_key = sa.Column(sa.Text, nullable=True) pgp_public_key = sa.Column(sa.Text, nullable=True)
pgp_finger_print = sa.Column(sa.String(512), nullable=True, index=True) pgp_finger_print = sa.Column(sa.String(512), nullable=True)
alias = orm.relationship(Alias, backref="contacts") alias = orm.relationship(Alias, backref="contacts")
user = orm.relationship(User) user = orm.relationship(User)
@ -1928,9 +1806,7 @@ class Contact(Base, ModelMixin):
else formatted_email else formatted_email
) )
from app.email_utils import sl_formataddr new_addr = formataddr((new_name, self.reply_email)).strip()
new_addr = sl_formataddr((new_name, self.reply_email)).strip()
return new_addr.strip() return new_addr.strip()
def last_reply(self) -> "EmailLog": def last_reply(self) -> "EmailLog":
@ -1947,7 +1823,6 @@ class Contact(Base, ModelMixin):
class EmailLog(Base, ModelMixin): class EmailLog(Base, ModelMixin):
__tablename__ = "email_log" __tablename__ = "email_log"
__table_args__ = (Index("ix_email_log_created_at", "created_at"),)
user_id = sa.Column( user_id = sa.Column(
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
@ -2205,9 +2080,7 @@ class AliasUsedOn(Base, ModelMixin):
sa.UniqueConstraint("alias_id", "hostname", name="uq_alias_used"), sa.UniqueConstraint("alias_id", "hostname", name="uq_alias_used"),
) )
alias_id = sa.Column( alias_id = sa.Column(sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False)
sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True
)
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False)
alias = orm.relationship(Alias) alias = orm.relationship(Alias)
@ -2326,7 +2199,6 @@ class CustomDomain(Base, ModelMixin):
@classmethod @classmethod
def create(cls, **kwargs): def create(cls, **kwargs):
domain = kwargs.get("domain") domain = kwargs.get("domain")
kwargs["domain"] = domain.replace("\n", "")
if DeletedSubdomain.get_by(domain=domain): if DeletedSubdomain.get_by(domain=domain):
raise SubdomainInTrashError raise SubdomainInTrashError
@ -2594,28 +2466,6 @@ class Mailbox(Base, ModelMixin):
+ Alias.filter_by(mailbox_id=self.id).count() + Alias.filter_by(mailbox_id=self.id).count()
) )
def is_proton(self) -> bool:
if (
self.email.endswith("@proton.me")
or self.email.endswith("@protonmail.com")
or self.email.endswith("@protonmail.ch")
or self.email.endswith("@proton.ch")
or self.email.endswith("@pm.me")
):
return True
from app.email_utils import get_email_local_part
mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email))
# Proton is the first domain
if mx_domains and mx_domains[0][1] in (
"mail.protonmail.ch.",
"mailsec.protonmail.ch.",
):
return True
return False
@classmethod @classmethod
def delete(cls, obj_id): def delete(cls, obj_id):
mailbox: Mailbox = cls.get(obj_id) mailbox: Mailbox = cls.get(obj_id)
@ -2648,12 +2498,6 @@ class Mailbox(Base, ModelMixin):
return ret return ret
@classmethod
def create(cls, **kw):
if "email" in kw:
kw["email"] = sanitize_email(kw["email"])
return super().create(**kw)
def __repr__(self): def __repr__(self):
return f"<Mailbox {self.id} {self.email}>" return f"<Mailbox {self.id} {self.email}>"
@ -2837,21 +2681,12 @@ class RecoveryCode(Base, ModelMixin):
__table_args__ = (sa.UniqueConstraint("user_id", "code", name="uq_recovery_code"),) __table_args__ = (sa.UniqueConstraint("user_id", "code", name="uq_recovery_code"),)
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False)
code = sa.Column(sa.String(64), nullable=False) code = sa.Column(sa.String(16), nullable=False)
used = sa.Column(sa.Boolean, nullable=False, default=False) used = sa.Column(sa.Boolean, nullable=False, default=False)
used_at = sa.Column(ArrowType, nullable=True, default=None) used_at = sa.Column(ArrowType, nullable=True, default=None)
user = orm.relationship(User) user = orm.relationship(User)
@classmethod
def _hash_code(cls, code: str) -> str:
code_hmac = hmac.new(
config.RECOVERY_CODE_HMAC_SECRET.encode("utf-8"),
code.encode("utf-8"),
"sha3_224",
)
return base64.urlsafe_b64encode(code_hmac.digest()).decode("utf-8").rstrip("=")
@classmethod @classmethod
def generate(cls, user): def generate(cls, user):
"""generate recovery codes for user""" """generate recovery codes for user"""
@ -2860,27 +2695,14 @@ class RecoveryCode(Base, ModelMixin):
Session.flush() Session.flush()
nb_code = 0 nb_code = 0
raw_codes = []
while nb_code < _NB_RECOVERY_CODE: while nb_code < _NB_RECOVERY_CODE:
raw_code = random_string(_RECOVERY_CODE_LENGTH) code = random_string(_RECOVERY_CODE_LENGTH)
encoded_code = cls._hash_code(raw_code) if not cls.get_by(user_id=user.id, code=code):
if not cls.get_by(user_id=user.id, code=encoded_code): cls.create(user_id=user.id, code=code)
cls.create(user_id=user.id, code=encoded_code)
raw_codes.append(raw_code)
nb_code += 1 nb_code += 1
LOG.d("Create recovery codes for %s", user) LOG.d("Create recovery codes for %s", user)
Session.commit() Session.commit()
return raw_codes
@classmethod
def find_by_user_code(cls, user: User, code: str):
hashed_code = cls._hash_code(code)
# TODO: Only return hashed codes once there aren't unhashed codes in the db.
found_code = cls.get_by(user_id=user.id, code=hashed_code)
if found_code:
return found_code
return cls.get_by(user_id=user.id, code=code)
@classmethod @classmethod
def empty(cls, user): def empty(cls, user):
@ -2913,31 +2735,6 @@ class Notification(Base, ModelMixin):
) )
class Partner(Base, ModelMixin):
__tablename__ = "partner"
name = sa.Column(sa.String(128), unique=True, nullable=False)
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
@staticmethod
def find_by_token(token: str) -> Optional[Partner]:
hmaced = PartnerApiToken.hmac_token(token)
res = (
Session.query(Partner, PartnerApiToken)
.filter(
and_(
PartnerApiToken.token == hmaced,
Partner.id == PartnerApiToken.partner_id,
)
)
.first()
)
if res:
partner, partner_api_token = res
return partner
return None
class SLDomain(Base, ModelMixin): class SLDomain(Base, ModelMixin):
"""SimpleLogin domains""" """SimpleLogin domains"""
@ -2955,23 +2752,12 @@ class SLDomain(Base, ModelMixin):
sa.Boolean, nullable=False, default=False, server_default="0" sa.Boolean, nullable=False, default=False, server_default="0"
) )
partner_id = sa.Column(
sa.ForeignKey(Partner.id, ondelete="cascade"),
nullable=True,
default=None,
server_default="NULL",
)
# if enabled, do not show this domain when user creates a custom alias # if enabled, do not show this domain when user creates a custom alias
hidden = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") hidden = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0")
# the order in which the domains are shown when user creates a custom alias # the order in which the domains are shown when user creates a custom alias
order = sa.Column(sa.Integer, nullable=False, default=0, server_default="0") order = sa.Column(sa.Integer, nullable=False, default=0, server_default="0")
use_as_reverse_alias = sa.Column(
sa.Boolean, nullable=False, default=False, server_default="0"
)
def __repr__(self): def __repr__(self):
return f"<SLDomain {self.domain} {'Premium' if self.premium_only else 'Free'}" return f"<SLDomain {self.domain} {'Premium' if self.premium_only else 'Free'}"
@ -2992,8 +2778,6 @@ class Monitoring(Base, ModelMixin):
active_queue = sa.Column(sa.Integer, nullable=False) active_queue = sa.Column(sa.Integer, nullable=False)
deferred_queue = sa.Column(sa.Integer, nullable=False) deferred_queue = sa.Column(sa.Integer, nullable=False)
__table_args__ = (Index("ix_monitoring_created_at", "created_at"),)
class BatchImport(Base, ModelMixin): class BatchImport(Base, ModelMixin):
__tablename__ = "batch_import" __tablename__ = "batch_import"
@ -3084,34 +2868,6 @@ class Metric2(Base, ModelMixin):
nb_app = sa.Column(sa.Float, nullable=True) nb_app = sa.Column(sa.Float, nullable=True)
class DailyMetric(Base, ModelMixin):
"""
For storing daily event-based metrics.
The difference between DailyEventMetric and Metric2 is Metric2 stores the total
whereas DailyEventMetric is reset for a new day
"""
__tablename__ = "daily_metric"
date = sa.Column(sa.Date, nullable=False, unique=True)
# users who sign up via web without using "Login with Proton"
nb_new_web_non_proton_user = sa.Column(
sa.Integer, nullable=False, server_default="0", default=0
)
nb_alias = sa.Column(sa.Integer, nullable=False, server_default="0", default=0)
@staticmethod
def get_or_create_today_metric() -> DailyMetric:
today = arrow.utcnow().date()
daily_metric = DailyMetric.get_by(date=today)
if not daily_metric:
daily_metric = DailyMetric.create(
date=today, nb_new_web_non_proton_user=0, nb_alias=0
)
return daily_metric
class Bounce(Base, ModelMixin): class Bounce(Base, ModelMixin):
"""Record all bounces. Deleted after 7 days""" """Record all bounces. Deleted after 7 days"""
@ -3119,8 +2875,6 @@ class Bounce(Base, ModelMixin):
email = sa.Column(sa.String(256), nullable=False, index=True) email = sa.Column(sa.String(256), nullable=False, index=True)
info = sa.Column(sa.Text, nullable=True) info = sa.Column(sa.Text, nullable=True)
__table_args__ = (sa.Index("ix_bounce_created_at", "created_at"),)
class TransactionalEmail(Base, ModelMixin): class TransactionalEmail(Base, ModelMixin):
"""Storing all email addresses that receive transactional emails, including account email and mailboxes. """Storing all email addresses that receive transactional emails, including account email and mailboxes.
@ -3130,8 +2884,6 @@ class TransactionalEmail(Base, ModelMixin):
__tablename__ = "transactional_email" __tablename__ = "transactional_email"
email = sa.Column(sa.String(256), nullable=False, unique=False) email = sa.Column(sa.String(256), nullable=False, unique=False)
__table_args__ = (sa.Index("ix_transactional_email_created_at", "created_at"),)
class Payout(Base, ModelMixin): class Payout(Base, ModelMixin):
"""Referral payouts""" """Referral payouts"""
@ -3184,7 +2936,7 @@ class MessageIDMatching(Base, ModelMixin):
# to track what email_log that has created this matching # to track what email_log that has created this matching
email_log_id = sa.Column( email_log_id = sa.Column(
sa.ForeignKey("email_log.id", ondelete="cascade"), nullable=True, index=True sa.ForeignKey("email_log.id", ondelete="cascade"), nullable=True
) )
email_log = orm.relationship("EmailLog") email_log = orm.relationship("EmailLog")
@ -3373,26 +3125,6 @@ class AdminAuditLog(Base):
data={}, data={},
) )
@classmethod
def disable_user(cls, admin_user_id: int, user_id: int):
cls.create(
admin_user_id=admin_user_id,
action=AuditLogActionEnum.disable_user.value,
model="User",
model_id=user_id,
data={},
)
@classmethod
def enable_user(cls, admin_user_id: int, user_id: int):
cls.create(
admin_user_id=admin_user_id,
action=AuditLogActionEnum.enable_user.value,
model="User",
model_id=user_id,
data={},
)
class ProviderComplaintState(EnumE): class ProviderComplaintState(EnumE):
new = 0 new = 0
@ -3418,6 +3150,31 @@ class ProviderComplaint(Base, ModelMixin):
refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id]) refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id])
class Partner(Base, ModelMixin):
__tablename__ = "partner"
name = sa.Column(sa.String(128), unique=True, nullable=False)
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
@staticmethod
def find_by_token(token: str) -> Optional[Partner]:
hmaced = PartnerApiToken.hmac_token(token)
res = (
Session.query(Partner, PartnerApiToken)
.filter(
and_(
PartnerApiToken.token == hmaced,
Partner.id == PartnerApiToken.partner_id,
)
)
.first()
)
if res:
partner, partner_api_token = res
return partner
return None
class PartnerApiToken(Base, ModelMixin): class PartnerApiToken(Base, ModelMixin):
__tablename__ = "partner_api_token" __tablename__ = "partner_api_token"
@ -3487,7 +3244,7 @@ class PartnerSubscription(Base, ModelMixin):
) )
# when the partner subscription ends # when the partner subscription ends
end_at = sa.Column(ArrowType, nullable=False, index=True) end_at = sa.Column(ArrowType, nullable=False)
partner_user = orm.relationship(PartnerUser) partner_user = orm.relationship(PartnerUser)
@ -3517,7 +3274,7 @@ class PartnerSubscription(Base, ModelMixin):
class Newsletter(Base, ModelMixin): class Newsletter(Base, ModelMixin):
__tablename__ = "newsletter" __tablename__ = "newsletter"
subject = sa.Column(sa.String(), nullable=False, index=True) subject = sa.Column(sa.String(), nullable=False, unique=True, index=True)
html = sa.Column(sa.Text) html = sa.Column(sa.Text)
plain_text = sa.Column(sa.Text) plain_text = sa.Column(sa.Text)
@ -3539,19 +3296,3 @@ class NewsletterUser(Base, ModelMixin):
user = orm.relationship(User) user = orm.relationship(User)
newsletter = orm.relationship(Newsletter) newsletter = orm.relationship(Newsletter)
class ApiToCookieToken(Base, ModelMixin):
__tablename__ = "api_cookie_token"
code = sa.Column(sa.String(128), unique=True, nullable=False)
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False)
api_key_id = sa.Column(sa.ForeignKey(ApiKey.id, ondelete="cascade"), nullable=False)
user = orm.relationship(User)
api_key = orm.relationship(ApiKey)
@classmethod
def create(cls, **kwargs):
code = secrets.token_urlsafe(32)
return super().create(code=code, **kwargs)

View file

@ -1,3 +1 @@
from . import views from . import views
__all__ = ["views"]

View file

@ -4,9 +4,8 @@ from jinja2 import Environment, FileSystemLoader
from app.config import ROOT_DIR, URL from app.config import ROOT_DIR, URL
from app.email_utils import send_email from app.email_utils import send_email
from app.handler.unsubscribe_encoder import UnsubscribeEncoder, UnsubscribeAction
from app.log import LOG from app.log import LOG
from app.models import NewsletterUser, Alias from app.models import NewsletterUser
def send_newsletter_to_user(newsletter, user) -> (bool, str): def send_newsletter_to_user(newsletter, user) -> (bool, str):
@ -17,25 +16,12 @@ def send_newsletter_to_user(newsletter, user) -> (bool, str):
html_template = env.from_string(newsletter.html) html_template = env.from_string(newsletter.html)
text_template = env.from_string(newsletter.plain_text) text_template = env.from_string(newsletter.plain_text)
comm_email, unsubscribe_link, via_email = user.get_communication_email() to_email, unsubscribe_link, via_email = user.get_communication_email()
if not comm_email: if not to_email:
return False, f"{user} not subscribed to newsletter" return False, f"{user} not subscribed to newsletter"
comm_alias = Alias.get_by(email=comm_email)
comm_alias_id = -1
if comm_alias:
comm_alias_id = comm_alias.id
unsubscribe_oneclick = unsubscribe_link
if via_email and comm_alias_id > -1:
unsubscribe_oneclick = UnsubscribeEncoder.encode(
UnsubscribeAction.DisableAlias,
comm_alias_id,
force_web=True,
).link
send_email( send_email(
comm_email, to_email,
newsletter.subject, newsletter.subject,
text_template.render( text_template.render(
user=user, user=user,
@ -44,10 +30,7 @@ def send_newsletter_to_user(newsletter, user) -> (bool, str):
html_template.render( html_template.render(
user=user, user=user,
URL=URL, URL=URL,
unsubscribe_oneclick=unsubscribe_oneclick,
), ),
unsubscribe_link=unsubscribe_link,
unsubscribe_via_email=via_email,
) )
NewsletterUser.create(newsletter_id=newsletter.id, user_id=user.id, commit=True) NewsletterUser.create(newsletter_id=newsletter.id, user_id=user.id, commit=True)

View file

@ -1,3 +1 @@
from .views import authorize, token, user_info from .views import authorize, token, user_info
__all__ = ["authorize", "token", "user_info"]

View file

@ -64,7 +64,7 @@ def _split_arg(arg_input: Union[str, list]) -> Set[str]:
- the response_type/scope passed as a list ?scope=scope_1&scope=scope_2 - the response_type/scope passed as a list ?scope=scope_1&scope=scope_2
""" """
res = set() res = set()
if isinstance(arg_input, str): if type(arg_input) is str:
if " " in arg_input: if " " in arg_input:
for x in arg_input.split(" "): for x in arg_input.split(" "):
if x: if x:

View file

@ -5,11 +5,3 @@ from .views import (
account_activated, account_activated,
extension_redirect, extension_redirect,
) )
__all__ = [
"index",
"final",
"setup_done",
"account_activated",
"extension_redirect",
]

View file

@ -1,32 +0,0 @@
import arrow
from app.db import Session
from app.email_utils import send_email, render
from app.log import LOG
from app.models import Subscription
from app import paddle_utils
def failed_payment(sub: Subscription, subscription_id: str):
LOG.w(
"Subscription failed payment %s for %s (sub %s)",
subscription_id,
sub.user,
sub.id,
)
sub.cancelled = True
Session.commit()
user = sub.user
paddle_utils.cancel_subscription(subscription_id)
send_email(
user.email,
"SimpleLogin - your subscription has failed to be renewed",
render(
"transactional/subscription-cancel.txt",
end_date=arrow.arrow.datetime.utcnow(),
),
)

View file

@ -1,73 +0,0 @@
import uuid
from datetime import timedelta
from functools import wraps
from typing import Callable, Any, Optional
from flask import request
from flask_login import current_user
from limits.storage import RedisStorage
from werkzeug import exceptions
lock_redis: Optional[RedisStorage] = None
def set_redis_concurrent_lock(redis: RedisStorage):
global lock_redis
lock_redis = redis
class _InnerLock:
def __init__(
self,
lock_suffix: Optional[str] = None,
max_wait_secs: int = 5,
only_when: Optional[Callable[..., bool]] = None,
):
self.lock_suffix = lock_suffix
self.max_wait_secs = max_wait_secs
self.only_when = only_when
def acquire_lock(self, lock_name: str, lock_value: str):
if not lock_redis.storage.set(
lock_name, lock_value, ex=timedelta(seconds=self.max_wait_secs), nx=True
):
raise exceptions.TooManyRequests()
def release_lock(self, lock_name: str, lock_value: str):
current_lock_value = lock_redis.storage.get(lock_name)
if current_lock_value == lock_value.encode("utf-8"):
lock_redis.storage.delete(lock_name)
def __call__(self, f: Callable[..., Any]):
if self.lock_suffix is None:
lock_suffix = f.__name__
else:
lock_suffix = self.lock_suffix
@wraps(f)
def decorated(*args, **kwargs):
if self.only_when and not self.only_when():
return f(*args, **kwargs)
if not lock_redis:
return f(*args, **kwargs)
lock_value = str(uuid.uuid4())[:10]
if "id" in dir(current_user):
lock_name = f"cl:{current_user.id}:{lock_suffix}"
else:
lock_name = f"cl:{request.remote_addr}:{lock_suffix}"
self.acquire_lock(lock_name, lock_value)
try:
return f(*args, **kwargs)
finally:
self.release_lock(lock_name, lock_value)
return decorated
def lock(
name: Optional[str] = None,
max_wait_secs: int = 5,
only_when: Optional[Callable[..., bool]] = None,
):
return _InnerLock(name, max_wait_secs, only_when)

View file

@ -5,11 +5,3 @@ from .views import (
provider1_callback, provider1_callback,
provider2_callback, provider2_callback,
) )
__all__ = [
"index",
"phone_reservation",
"twilio_callback",
"provider1_callback",
"provider2_callback",
]

View file

@ -64,9 +64,7 @@ class ProtonCallbackHandler:
) )
def handle_link( def handle_link(
self, self, current_user: Optional[User], partner: Partner
current_user: Optional[User],
partner: Partner,
) -> ProtonCallbackResult: ) -> ProtonCallbackResult:
if current_user is None: if current_user is None:
raise Exception("Cannot link account with current_user being None") raise Exception("Cannot link account with current_user being None")

View file

@ -7,12 +7,11 @@ from typing import Optional
from app.account_linking import SLPlan, SLPlanType from app.account_linking import SLPlan, SLPlanType
from app.config import PROTON_EXTRA_HEADER_NAME, PROTON_EXTRA_HEADER_VALUE from app.config import PROTON_EXTRA_HEADER_NAME, PROTON_EXTRA_HEADER_VALUE
from app.errors import ProtonAccountNotVerified
from app.log import LOG from app.log import LOG
_APP_VERSION = "OauthClient_1.0.0" _APP_VERSION = "OauthClient_1.0.0"
PROTON_ERROR_CODE_HV_NEEDED = 9001 PROTON_ERROR_CODE_NOT_EXISTS = 2501
PLAN_FREE = 1 PLAN_FREE = 1
PLAN_PREMIUM = 2 PLAN_PREMIUM = 2
@ -58,15 +57,6 @@ def convert_access_token(access_token_response: str) -> AccessCredentials:
) )
def handle_response_not_ok(status: int, body: dict, text: str) -> Exception:
if status == HTTPStatus.UNPROCESSABLE_ENTITY:
res_code = body.get("Code")
if res_code == PROTON_ERROR_CODE_HV_NEEDED:
return ProtonAccountNotVerified()
return Exception(f"Unexpected status code. Wanted 200 and got {status}: " + text)
class ProtonClient(ABC): class ProtonClient(ABC):
@abstractmethod @abstractmethod
def get_user(self) -> Optional[UserInformation]: def get_user(self) -> Optional[UserInformation]:
@ -134,11 +124,11 @@ class HttpProtonClient(ProtonClient):
@staticmethod @staticmethod
def __validate_response(res: Response) -> dict: def __validate_response(res: Response) -> dict:
status = res.status_code status = res.status_code
as_json = res.json()
if status != HTTPStatus.OK: if status != HTTPStatus.OK:
raise HttpProtonClient.__handle_response_not_ok( raise Exception(
status=status, body=as_json, text=res.text f"Unexpected status code. Wanted 200 and got {status}: " + res.text
) )
as_json = res.json()
res_code = as_json.get("Code") res_code = as_json.get("Code")
if not res_code or res_code != 1000: if not res_code or res_code != 1000:
raise Exception( raise Exception(

View file

@ -1,9 +1,8 @@
from newrelic import agent
from typing import Optional from typing import Optional
from app.db import Session from app.db import Session
from app.errors import ProtonPartnerNotSetUp from app.errors import ProtonPartnerNotSetUp
from app.models import Partner, PartnerUser, User from app.models import Partner
PROTON_PARTNER_NAME = "Proton" PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None _PROTON_PARTNER: Optional[Partner] = None
@ -22,14 +21,3 @@ def get_proton_partner() -> Partner:
def is_proton_partner(partner: Partner) -> bool: def is_proton_partner(partner: Partner) -> bool:
return partner.name == PROTON_PARTNER_NAME return partner.name == PROTON_PARTNER_NAME
def perform_proton_account_unlink(current_user: User):
proton_partner = get_proton_partner()
partner_user = PartnerUser.get_by(
user_id=current_user.id, partner_id=proton_partner.id
)
if partner_user is not None:
PartnerUser.delete(partner_user.id)
Session.commit()
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})

View file

@ -1,22 +0,0 @@
import flask
import limits.storage
from app.parallel_limiter import set_redis_concurrent_lock
from app.session import RedisSessionStore
def initialize_redis_services(app: flask.Flask, redis_url: str):
if redis_url.startswith("redis://") or redis_url.startswith("rediss://"):
storage = limits.storage.RedisStorage(redis_url)
app.session_interface = RedisSessionStore(storage.storage, storage.storage, app)
set_redis_concurrent_lock(storage)
elif redis_url.startswith("redis+sentinel://"):
storage = limits.storage.RedisSentinelStorage(redis_url)
app.session_interface = RedisSessionStore(
storage.storage, storage.storage_slave, app
)
set_redis_concurrent_lock(storage)
else:
raise RuntimeError(
f"Tried to set_redis_session with an invalid redis url: ${redis_url}"
)

Some files were not shown because too many files have changed in this diff Show more