123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405 |
- import enum
- import hashlib
- import arrow
- import bcrypt
- import stripe
- from arrow import Arrow
- from flask_login import UserMixin
- from sqlalchemy_utils import ArrowType
- from app import s3
- from app.config import URL, MAX_NB_EMAIL_FREE_PLAN, EMAIL_DOMAIN
- from app.extensions import db
- from app.log import LOG
- from app.oauth_models import ScopeE
- from app.utils import convert_to_id, random_string
- class ModelMixin(object):
- id = db.Column(db.Integer, primary_key=True, autoincrement=True)
- created_at = db.Column(ArrowType, default=arrow.utcnow, nullable=False)
- updated_at = db.Column(ArrowType, default=None, onupdate=arrow.utcnow)
- _repr_hide = ["created_at", "updated_at"]
- @classmethod
- def query(cls):
- return db.session.query(cls)
- @classmethod
- def get(cls, id):
- return cls.query.get(id)
- @classmethod
- def get_by(cls, **kw):
- return cls.query.filter_by(**kw).first()
- @classmethod
- def filter_by(cls, **kw):
- return cls.query.filter_by(**kw)
- @classmethod
- def get_or_create(cls, **kw):
- r = cls.get_by(**kw)
- if not r:
- r = cls(**kw)
- db.session.add(r)
- return r
- @classmethod
- def create(cls, **kw):
- r = cls(**kw)
- db.session.add(r)
- return r
- def save(self):
- db.session.add(self)
- def delete(self):
- db.session.delete(self)
- def __repr__(self):
- values = ", ".join(
- "%s=%r" % (n, getattr(self, n))
- for n in self.__table__.c.keys()
- if n not in self._repr_hide
- )
- return "%s(%s)" % (self.__class__.__name__, values)
- class File(db.Model, ModelMixin):
- path = db.Column(db.String(128), unique=True, nullable=False)
- def get_url(self):
- return s3.get_url(self.path)
- class PlanEnum(enum.Enum):
- free = 0
- trial = 1
- monthly = 2
- yearly = 3
- class User(db.Model, ModelMixin, UserMixin):
- __tablename__ = "users"
- email = db.Column(db.String(128), unique=True, nullable=False)
- salt = db.Column(db.String(128), nullable=False)
- password = db.Column(db.String(128), nullable=False)
- name = db.Column(db.String(128), nullable=False)
- is_admin = db.Column(db.Boolean, nullable=False, default=False)
- activated = db.Column(db.Boolean, default=False, nullable=False)
- plan = db.Column(
- db.Enum(PlanEnum),
- nullable=False,
- default=PlanEnum.free,
- server_default=PlanEnum.free.name,
- )
- # only relevant for trial period
- plan_expiration = db.Column(ArrowType)
- stripe_customer_id = db.Column(db.String(128), unique=True)
- stripe_card_token = db.Column(db.String(128), unique=True)
- stripe_subscription_id = db.Column(db.String(128), unique=True)
- profile_picture_id = db.Column(db.ForeignKey(File.id), nullable=True)
- is_developer = db.Column(db.Boolean, nullable=False, server_default="0")
- # contain the list of promo codes user has used. Promo codes are separated by ","
- promo_codes = db.Column(db.Text, nullable=True)
- profile_picture = db.relationship(File)
- def should_upgrade(self):
- """User is invited to upgrade if they are in free plan or their trial ends soon"""
- if self.plan == PlanEnum.free:
- return True
- elif self.plan == PlanEnum.trial and self.plan_expiration < arrow.now().shift(
- weeks=1
- ):
- return True
- return False
- def is_premium(self):
- return self.plan in (PlanEnum.monthly, PlanEnum.yearly)
- def can_create_new_email(self):
- if self.is_premium():
- return True
- # plan not expired yet
- elif self.plan == PlanEnum.trial and self.plan_expiration > arrow.now():
- return True
- else: # free or trial expired
- return GenEmail.filter_by(user_id=self.id).count() < MAX_NB_EMAIL_FREE_PLAN
- def set_password(self, password):
- salt = bcrypt.gensalt()
- password_hash = bcrypt.hashpw(password.encode(), salt).decode()
- self.salt = salt.decode()
- self.password = password_hash
- def check_password(self, password) -> bool:
- password_hash = bcrypt.hashpw(password.encode(), self.salt.encode())
- return self.password.encode() == password_hash
- def profile_picture_url(self):
- if self.profile_picture_id:
- return self.profile_picture.get_url()
- else: # use gravatar
- hash_email = hashlib.md5(self.email.encode("utf-8")).hexdigest()
- return f"https://www.gravatar.com/avatar/{hash_email}"
- def plan_current_period_end(self) -> Arrow:
- if not self.stripe_subscription_id:
- LOG.error(
- "plan_current_period_end should not be called with empty stripe_subscription_id"
- )
- return None
- current_period_end_ts = stripe.Subscription.retrieve(
- self.stripe_subscription_id
- )["current_period_end"]
- return arrow.get(current_period_end_ts)
- def get_promo_codes(self) -> [str]:
- if not self.promo_codes:
- return []
- return self.promo_codes.split(",")
- def save_new_promo_code(self, promo_code):
- current_promo_codes = self.get_promo_codes()
- current_promo_codes.append(promo_code)
- self.promo_codes = ",".join(current_promo_codes)
- class ActivationCode(db.Model, ModelMixin):
- """For activate user account"""
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
- code = db.Column(db.String(128), unique=True, nullable=False)
- user = db.relationship(User)
- # the activation code is valid for 1h
- expired = db.Column(ArrowType, default=arrow.now().shift(hours=1))
- class ResetPasswordCode(db.Model, ModelMixin):
- """For resetting password"""
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
- code = db.Column(db.String(128), unique=True, nullable=False)
- user = db.relationship(User)
- # the activation code is valid for 1h
- expired = db.Column(ArrowType, default=arrow.now().shift(hours=1), nullable=False)
- class Partner(db.Model, ModelMixin):
- email = db.Column(db.String(128))
- name = db.Column(db.String(128))
- website = db.Column(db.String(1024))
- additional_information = db.Column(db.Text)
- # If apply from a authenticated user, set user_id to the user who has applied for partnership
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=True)
- # <<< OAUTH models >>>
- client_scope = db.Table(
- "client_scope",
- db.Column(
- "client_id",
- db.Integer,
- db.ForeignKey("client.id", ondelete="cascade"),
- primary_key=True,
- nullable=False,
- ),
- db.Column(
- "scope_id",
- db.Integer,
- db.ForeignKey("scope.id", ondelete="cascade"),
- primary_key=True,
- nullable=False,
- ),
- )
- def generate_oauth_client_id(client_name) -> str:
- oauth_client_id = convert_to_id(client_name) + "-" + random_string()
- # check that the client does not exist yet
- if not Client.get_by(oauth_client_id=oauth_client_id):
- LOG.debug("generate oauth_client_id %s", oauth_client_id)
- return oauth_client_id
- # Rerun the function
- LOG.warning(
- "client_id %s already exists, generate a new client_id", oauth_client_id
- )
- return generate_oauth_client_id(client_name)
- class Client(db.Model, ModelMixin):
- oauth_client_id = db.Column(db.String(128), unique=True, nullable=False)
- oauth_client_secret = db.Column(db.String(128), nullable=False)
- name = db.Column(db.String(128), nullable=False)
- home_url = db.Column(db.String(1024))
- published = db.Column(db.Boolean, default=False, nullable=False)
- # user who created this client
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
- icon_id = db.Column(db.ForeignKey(File.id), nullable=True)
- scopes = db.relationship("Scope", secondary=client_scope, lazy="subquery")
- icon = db.relationship(File)
- def nb_user(self):
- return ClientUser.filter_by(client_id=self.id).count()
- @classmethod
- def create_new(cls, name, user_id) -> "Client":
- # generate a client-id
- oauth_client_id = generate_oauth_client_id(name)
- oauth_client_secret = random_string(40)
- client = Client.create(
- name=name,
- oauth_client_id=oauth_client_id,
- oauth_client_secret=oauth_client_secret,
- user_id=user_id,
- )
- # By default, add email and name scope
- client.scopes.append(Scope.get_by(name=ScopeE.NAME.value))
- client.scopes.append(Scope.get_by(name=ScopeE.EMAIL.value))
- return client
- def get_icon_url(self):
- if self.icon_id:
- return self.icon.get_url()
- else:
- return URL + "/static/default-icon.svg"
- class RedirectUri(db.Model, ModelMixin):
- """Valid redirect uris for a client"""
- client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False)
- uri = db.Column(db.String(1024), nullable=False)
- client = db.relationship(Client, backref="redirect_uris")
- class AuthorizationCode(db.Model, ModelMixin):
- code = db.Column(db.String(128), unique=True, nullable=False)
- client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False)
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
- scope = db.Column(db.String(128))
- redirect_uri = db.Column(db.String(1024))
- user = db.relationship(User, lazy=False)
- client = db.relationship(Client, lazy=False)
- class OauthToken(db.Model, ModelMixin):
- access_token = db.Column(db.String(128), unique=True)
- client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False)
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
- scope = db.Column(db.String(128))
- redirect_uri = db.Column(db.String(1024))
- user = db.relationship(User)
- client = db.relationship(Client)
- class Scope(db.Model, ModelMixin):
- name = db.Column(db.String(128), unique=True, nullable=False)
- def generate_email() -> str:
- """generate an email address that does not exist before"""
- random_email = random_string(40) + "@" + EMAIL_DOMAIN
- # check that the client does not exist yet
- if not GenEmail.get_by(email=random_email):
- LOG.debug("generate email %s", random_email)
- return random_email
- # Rerun the function
- LOG.warning("email %s already exists, generate a new email", random_email)
- return generate_email()
- class GenEmail(db.Model, ModelMixin):
- """Generated email"""
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
- email = db.Column(db.String(128), unique=True, nullable=False)
- enabled = db.Column(db.Boolean(), default=True, nullable=False)
- @classmethod
- def create_new_gen_email(cls, user_id):
- random_email = generate_email()
- return GenEmail.create(user_id=user_id, email=random_email)
- def __repr__(self):
- return f"<GenEmail {self.id} {self.email}>"
- class ClientUser(db.Model, ModelMixin):
- __table_args__ = (
- db.UniqueConstraint("user_id", "client_id", name="uq_client_user"),
- )
- user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
- client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False)
- # Null means client has access to user original email
- gen_email_id = db.Column(
- db.ForeignKey(GenEmail.id, ondelete="cascade"), nullable=True
- )
- gen_email = db.relationship(GenEmail, backref="client_users")
- user = db.relationship(User)
- client = db.relationship(Client)
- def get_email(self):
- return self.gen_email.email if self.gen_email_id else self.user.email
- def get_user_info(self) -> dict:
- """return user info according to client scope
- Return dict with key being scope name
- """
- res = {"id": self.id, "client": self.client.name, "email_verified": True}
- for scope in self.client.scopes:
- if scope.name == ScopeE.NAME.value:
- res[ScopeE.NAME.value] = self.user.name
- elif scope.name == ScopeE.EMAIL.value:
- # Use generated email
- if self.gen_email_id:
- LOG.debug(
- "Use gen email for user %s, client %s", self.user, self.client
- )
- res[ScopeE.EMAIL.value] = self.gen_email.email
- # Use user original email
- else:
- res[ScopeE.EMAIL.value] = self.user.email
- return res
|