models.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. import enum
  2. import hashlib
  3. import arrow
  4. import bcrypt
  5. import stripe
  6. from arrow import Arrow
  7. from flask_login import UserMixin
  8. from sqlalchemy_utils import ArrowType
  9. from app import s3
  10. from app.config import URL, MAX_NB_EMAIL_FREE_PLAN, EMAIL_DOMAIN
  11. from app.extensions import db
  12. from app.log import LOG
  13. from app.oauth_models import ScopeE
  14. from app.utils import convert_to_id, random_string
  class ModelMixin(object):
      """Shared columns and session helpers for the models in this module.

      Adds an auto-increment integer primary key plus ``created_at`` /
      ``updated_at`` timestamps, and thin wrappers around ``db.session`` for
      looking up, creating, saving and deleting rows. None of the helpers
      commit or flush; they only stage work on ``db.session``.
      """

      id = db.Column(db.Integer, primary_key=True, autoincrement=True)
      # arrow.utcnow is passed as a callable for the insert default.
      created_at = db.Column(ArrowType, default=arrow.utcnow, nullable=False)
      # No insert default; arrow.utcnow is the onupdate callable.
      updated_at = db.Column(ArrowType, default=None, onupdate=arrow.utcnow)

      # Column names left out of __repr__ output.
      _repr_hide = ["created_at", "updated_at"]

      # NOTE(review): get/get_by/filter_by below read ``cls.query`` as an
      # attribute rather than calling it. With only this classmethod in play
      # that would be a bound method, so they presumably rely on a base class
      # earlier in the MRO (e.g. db.Model) supplying ``query`` - confirm.
      @classmethod
      def query(cls):
          """Return a session query over ``cls``."""
          return db.session.query(cls)

      @classmethod
      def get(cls, id):
          """Return the row whose primary key is ``id`` via ``cls.query.get``."""
          return cls.query.get(id)

      @classmethod
      def get_by(cls, **kw):
          """Return the first row matching the ``filter_by`` keywords ``kw``."""
          return cls.query.filter_by(**kw).first()

      @classmethod
      def filter_by(cls, **kw):
          """Return the query filtered by the keywords ``kw`` (not evaluated)."""
          return cls.query.filter_by(**kw)

      @classmethod
      def get_or_create(cls, **kw):
          """Return the first row matching ``kw``, or stage a new one.

          When nothing matches, a new instance is built from ``kw`` and added
          to ``db.session``.
          """
          found = cls.get_by(**kw)
          if found:
              return found
          instance = cls(**kw)
          db.session.add(instance)
          return instance

      @classmethod
      def create(cls, **kw):
          """Build an instance from ``kw``, add it to the session, return it."""
          instance = cls(**kw)
          db.session.add(instance)
          return instance

      def save(self):
          """Add this instance to ``db.session``."""
          db.session.add(self)

      def delete(self):
          """Mark this instance for deletion in ``db.session``."""
          db.session.delete(self)

      def __repr__(self):
          """Render ``ClassName(col=value, ...)`` skipping ``_repr_hide`` columns."""
          shown = [
              f"{column}={getattr(self, column)!r}"
              for column in self.__table__.c.keys()
              if column not in self._repr_hide
          ]
          joined = ", ".join(shown)
          return f"{self.__class__.__name__}({joined})"
  class File(db.Model, ModelMixin):
      """A stored file, identified by its unique ``path``."""

      path = db.Column(db.String(128), unique=True, nullable=False)

      def get_url(self):
          """Return the URL that ``s3.get_url`` yields for this file's path."""
          url = s3.get_url(self.path)
          return url
  class PlanEnum(enum.Enum):
      """Subscription plans a user can be on.

      The member names are what ``User.plan`` stores through ``db.Enum``;
      ``User.is_premium`` treats ``monthly`` and ``yearly`` as premium.
      """

      free = 0
      trial = 1
      monthly = 2
      yearly = 3
  class User(db.Model, ModelMixin, UserMixin):
      """Account stored in the ``users`` table.

      Holds the login credentials (bcrypt salt and hash), the subscription
      plan, Stripe identifiers, an optional profile picture and the promo codes
      the user has already used.
      """

      __tablename__ = "users"

      email = db.Column(db.String(128), unique=True, nullable=False)
      salt = db.Column(db.String(128), nullable=False)
      password = db.Column(db.String(128), nullable=False)
      name = db.Column(db.String(128), nullable=False)
      is_admin = db.Column(db.Boolean, nullable=False, default=False)
      activated = db.Column(db.Boolean, default=False, nullable=False)

      plan = db.Column(
          db.Enum(PlanEnum),
          nullable=False,
          default=PlanEnum.free,
          server_default=PlanEnum.free.name,
      )

      # only relevant for trial period
      plan_expiration = db.Column(ArrowType)

      stripe_customer_id = db.Column(db.String(128), unique=True)
      stripe_card_token = db.Column(db.String(128), unique=True)
      stripe_subscription_id = db.Column(db.String(128), unique=True)

      profile_picture_id = db.Column(db.ForeignKey(File.id), nullable=True)

      is_developer = db.Column(db.Boolean, nullable=False, server_default="0")

      # contain the list of promo codes user has used. Promo codes are separated by ","
      promo_codes = db.Column(db.Text, nullable=True)

      profile_picture = db.relationship(File)

      def should_upgrade(self):
          """User is invited to upgrade if they are in free plan or their trial ends soon.

          "Soon" means the trial expiration is earlier than one week from now.
          """
          if self.plan == PlanEnum.free:
              return True
          # NOTE(review): a trial user without plan_expiration would fail on
          # this comparison - confirm trials always set it.
          elif self.plan == PlanEnum.trial and self.plan_expiration < arrow.now().shift(
              weeks=1
          ):
              return True
          return False

      def is_premium(self):
          """Return True when the user is on the monthly or yearly plan."""
          return self.plan in (PlanEnum.monthly, PlanEnum.yearly)

      def can_create_new_email(self):
          """Return True when the user may create another generated email.

          Premium users and users on a not-yet-expired trial always may; anyone
          else only while they own fewer than ``MAX_NB_EMAIL_FREE_PLAN``
          generated emails.
          """
          if self.is_premium():
              return True
          # plan not expired yet
          # NOTE(review): same plan_expiration-is-None caveat as should_upgrade.
          elif self.plan == PlanEnum.trial and self.plan_expiration > arrow.now():
              return True
          else:  # free or trial expired
              return GenEmail.filter_by(user_id=self.id).count() < MAX_NB_EMAIL_FREE_PLAN

      def set_password(self, password):
          """Hash ``password`` with a fresh bcrypt salt and store salt and hash."""
          salt = bcrypt.gensalt()
          password_hash = bcrypt.hashpw(password.encode(), salt).decode()
          self.salt = salt.decode()
          self.password = password_hash

      def check_password(self, password) -> bool:
          """Return True when ``password`` matches the stored bcrypt hash.

          Uses ``bcrypt.checkpw`` so the hash comparison is constant-time
          instead of a plain ``==`` on the bytes. The stored hash already embeds
          the salt written by ``set_password``, so the result is unchanged.
          """
          return bcrypt.checkpw(password.encode(), self.password.encode())

      def profile_picture_url(self):
          """Return the uploaded picture URL, or a Gravatar URL as fallback."""
          if self.profile_picture_id:
              return self.profile_picture.get_url()
          else:  # use gravatar
              # Gravatar hashes the trimmed, lower-cased address; hashing the raw
              # value would miss the avatar for mixed-case emails.
              normalized_email = self.email.strip().lower()
              hash_email = hashlib.md5(normalized_email.encode("utf-8")).hexdigest()
              return f"https://www.gravatar.com/avatar/{hash_email}"

      def plan_current_period_end(self) -> Arrow:
          """Return the end of the current Stripe billing period.

          Retrieves the subscription from Stripe and converts its
          ``current_period_end`` value with ``arrow.get``. Logs an error and
          returns None when the user has no ``stripe_subscription_id``.
          """
          if not self.stripe_subscription_id:
              LOG.error(
                  "plan_current_period_end should not be called with empty stripe_subscription_id"
              )
              return None

          current_period_end_ts = stripe.Subscription.retrieve(
              self.stripe_subscription_id
          )["current_period_end"]

          return arrow.get(current_period_end_ts)

      def get_promo_codes(self) -> list:
          """Return the used promo codes as a list of strings (empty if none)."""
          if not self.promo_codes:
              return []
          return self.promo_codes.split(",")

      def save_new_promo_code(self, promo_code):
          """Append ``promo_code`` to the comma-separated ``promo_codes`` column."""
          current_promo_codes = self.get_promo_codes()
          current_promo_codes.append(promo_code)
          self.promo_codes = ",".join(current_promo_codes)
  class ActivationCode(db.Model, ModelMixin):
      """For activate user account"""

      user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
      code = db.Column(db.String(128), unique=True, nullable=False)

      user = db.relationship(User)

      # the activation code is valid for 1h
      # The default must be a callable so it is evaluated for each new row;
      # a plain arrow.now().shift(hours=1) value would be computed once when
      # the module is imported and shared by every code created afterwards.
      expired = db.Column(ArrowType, default=lambda: arrow.now().shift(hours=1))
  class ResetPasswordCode(db.Model, ModelMixin):
      """For resetting password"""

      user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
      code = db.Column(db.String(128), unique=True, nullable=False)

      user = db.relationship(User)

      # the reset-password code is valid for 1h
      # The default must be a callable so it is evaluated for each new row;
      # a plain arrow.now().shift(hours=1) value would be computed once when
      # the module is imported and shared by every code created afterwards.
      expired = db.Column(
          ArrowType, default=lambda: arrow.now().shift(hours=1), nullable=False
      )
  class Partner(db.Model, ModelMixin):
      """A partnership application: contact details plus free-form notes."""

      email = db.Column(db.String(128))
      name = db.Column(db.String(128))
      website = db.Column(db.String(1024))
      additional_information = db.Column(db.Text)

      # Set to the applying user when the application comes from an
      # authenticated user; left NULL otherwise.
      user_id = db.Column(
          db.ForeignKey(User.id, ondelete="cascade"),
          nullable=True,
      )
  160. # <<< OAUTH models >>>
  # Association table linking clients to scopes; it is the ``secondary`` table
  # of ``Client.scopes``. Both columns are part of the primary key and cascade
  # on delete of the referenced row.
  client_scope = db.Table(
      "client_scope",
      db.Column(
          "client_id",
          db.Integer,
          db.ForeignKey("client.id", ondelete="cascade"),
          primary_key=True,
          nullable=False,
      ),
      db.Column(
          "scope_id",
          db.Integer,
          db.ForeignKey("scope.id", ondelete="cascade"),
          primary_key=True,
          nullable=False,
      ),
  )
  def generate_oauth_client_id(client_name) -> str:
      """Return an ``oauth_client_id`` for ``client_name`` that no Client uses yet.

      Each candidate is ``convert_to_id(client_name)``, a dash, and a
      ``random_string()`` suffix. Candidates are drawn until one is not found
      among existing clients.
      """
      while True:
          candidate = convert_to_id(client_name) + "-" + random_string()

          # A hit means the id is taken: warn and draw another candidate.
          if Client.get_by(oauth_client_id=candidate):
              LOG.warning(
                  "client_id %s already exists, generate a new client_id", candidate
              )
              continue

          LOG.debug("generate oauth_client_id %s", candidate)
          return candidate
  class Client(db.Model, ModelMixin):
      """An OAuth client (application) registered by a user."""

      oauth_client_id = db.Column(db.String(128), unique=True, nullable=False)
      oauth_client_secret = db.Column(db.String(128), nullable=False)

      name = db.Column(db.String(128), nullable=False)
      home_url = db.Column(db.String(1024))
      published = db.Column(db.Boolean, default=False, nullable=False)

      # user who created this client
      user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
      icon_id = db.Column(db.ForeignKey(File.id), nullable=True)

      scopes = db.relationship("Scope", secondary=client_scope, lazy="subquery")
      icon = db.relationship(File)

      def nb_user(self):
          """Return the number of ClientUser rows attached to this client."""
          return ClientUser.filter_by(client_id=self.id).count()

      @classmethod
      def create_new(cls, name, user_id) -> "Client":
          """Create a client named ``name`` owned by ``user_id``.

          Generates the client id and a 40-character secret, stages the client
          through ``Client.create`` and attaches the name and email scopes.
          """
          client = Client.create(
              name=name,
              oauth_client_id=generate_oauth_client_id(name),
              oauth_client_secret=random_string(40),
              user_id=user_id,
          )

          # Default scopes, appended in this order: name, then email.
          for default_scope in (ScopeE.NAME, ScopeE.EMAIL):
              client.scopes.append(Scope.get_by(name=default_scope.value))

          return client

      def get_icon_url(self):
          """Return the icon file's URL, or the default static icon URL."""
          if not self.icon_id:
              return URL + "/static/default-icon.svg"
          return self.icon.get_url()
  class RedirectUri(db.Model, ModelMixin):
      """Valid redirect uris for a client"""

      client_id = db.Column(
          db.ForeignKey(Client.id, ondelete="cascade"),
          nullable=False,
      )
      uri = db.Column(db.String(1024), nullable=False)

      # Exposes the reverse collection on Client as ``redirect_uris``.
      client = db.relationship(Client, backref="redirect_uris")
  class AuthorizationCode(db.Model, ModelMixin):
      """A unique code tied to one client and one user, with scope and redirect uri."""

      code = db.Column(db.String(128), unique=True, nullable=False)
      client_id = db.Column(
          db.ForeignKey(Client.id, ondelete="cascade"),
          nullable=False,
      )
      user_id = db.Column(
          db.ForeignKey(User.id, ondelete="cascade"),
          nullable=False,
      )

      scope = db.Column(db.String(128))
      redirect_uri = db.Column(db.String(1024))

      # lazy=False: both relationships are loaded together with the code row.
      user = db.relationship(User, lazy=False)
      client = db.relationship(Client, lazy=False)
  class OauthToken(db.Model, ModelMixin):
      """A unique access token tied to one client and one user."""

      access_token = db.Column(db.String(128), unique=True)
      client_id = db.Column(
          db.ForeignKey(Client.id, ondelete="cascade"),
          nullable=False,
      )
      user_id = db.Column(
          db.ForeignKey(User.id, ondelete="cascade"),
          nullable=False,
      )

      scope = db.Column(db.String(128))
      redirect_uri = db.Column(db.String(1024))

      user = db.relationship(User)
      client = db.relationship(Client)
  class Scope(db.Model, ModelMixin):
      """A scope identified by a unique name.

      ``Client.create_new`` looks scopes up by the ``ScopeE`` values.
      """

      name = db.Column(db.String(128), unique=True, nullable=False)
  def generate_email() -> str:
      """Return a random address on ``EMAIL_DOMAIN`` that no GenEmail uses yet.

      The local part is a 40-character ``random_string``. Candidates are drawn
      until one is not found among existing generated emails.
      """
      while True:
          candidate = random_string(40) + "@" + EMAIL_DOMAIN

          # A hit means the address is taken: warn and draw another candidate.
          if GenEmail.get_by(email=candidate):
              LOG.warning("email %s already exists, generate a new email", candidate)
              continue

          LOG.debug("generate email %s", candidate)
          return candidate
  class GenEmail(db.Model, ModelMixin):
      """Generated email"""

      user_id = db.Column(
          db.ForeignKey(User.id, ondelete="cascade"),
          nullable=False,
      )
      email = db.Column(db.String(128), unique=True, nullable=False)
      enabled = db.Column(db.Boolean(), default=True, nullable=False)

      @classmethod
      def create_new_gen_email(cls, user_id):
          """Stage a GenEmail for ``user_id`` using a freshly generated address."""
          address = generate_email()
          new_gen_email = GenEmail.create(user_id=user_id, email=address)
          return new_gen_email

      def __repr__(self):
          """Short form showing only the id and the address."""
          return f"<GenEmail {self.id} {self.email}>"
  class ClientUser(db.Model, ModelMixin):
      """Link between a user and a client, unique per (user_id, client_id)."""

      __table_args__ = (
          db.UniqueConstraint("user_id", "client_id", name="uq_client_user"),
      )

      user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False)
      client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False)

      # Null means client has access to user original email
      gen_email_id = db.Column(
          db.ForeignKey(GenEmail.id, ondelete="cascade"), nullable=True
      )

      gen_email = db.relationship(GenEmail, backref="client_users")

      user = db.relationship(User)
      client = db.relationship(Client)

      def get_email(self):
          """Return the generated email if one is linked, else the user's email."""
          if self.gen_email_id:
              return self.gen_email.email
          return self.user.email

      def get_user_info(self) -> dict:
          """Return the user info this client is allowed to see.

          The dict always has ``id``, ``client`` and ``email_verified``; a key
          named after each name/email scope of the client is added with the
          matching value. The email is the generated one when it exists.
          """
          info = {"id": self.id, "client": self.client.name, "email_verified": True}

          name_key = ScopeE.NAME.value
          email_key = ScopeE.EMAIL.value

          for scope in self.client.scopes:
              if scope.name == name_key:
                  info[name_key] = self.user.name
                  continue

              if scope.name != email_key:
                  # Other scopes add nothing to the result.
                  continue

              if not self.gen_email_id:
                  # Use user original email
                  info[email_key] = self.user.email
                  continue

              # Use generated email
              LOG.debug(
                  "Use gen email for user %s, client %s", self.user, self.client
              )
              info[email_key] = self.gen_email.email

          return info