jose_utils.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import base64
  2. import hashlib
  3. from typing import Optional
  4. import arrow
  5. from jwcrypto import jwk, jwt
  6. from app.config import OPENID_PRIVATE_KEY_PATH, URL
  7. from app.log import LOG
  8. from app.models import ClientUser
# Load the key once, at import time, from the PEM file located at
# OPENID_PRIVATE_KEY_PATH. The resulting module-level ``_key`` is what the
# token helpers in this module sign with and verify against.
with open(OPENID_PRIVATE_KEY_PATH, "rb") as f:
    _key = jwk.JWK.from_pem(f.read())
  11. def get_jwk_key() -> dict:
  12. return _key._public_params()
  13. def make_id_token(
  14. client_user: ClientUser,
  15. nonce: Optional[str] = None,
  16. access_token: Optional[str] = None,
  17. code: Optional[str] = None,
  18. ):
  19. """Make id_token for OpenID Connect
  20. According to RFC 7519, these claims are mandatory:
  21. - iss
  22. - sub
  23. - aud
  24. - exp
  25. - iat
  26. """
  27. claims = {
  28. "iss": URL,
  29. "sub": str(client_user.id),
  30. "aud": client_user.client.oauth_client_id,
  31. "exp": arrow.now().shift(hours=1).timestamp,
  32. "iat": arrow.now().timestamp,
  33. "auth_time": arrow.now().timestamp,
  34. }
  35. if nonce:
  36. claims["nonce"] = nonce
  37. if access_token:
  38. claims["at_hash"] = id_token_hash(access_token)
  39. if code:
  40. claims["c_hash"] = id_token_hash(code)
  41. claims = {**claims, **client_user.get_user_info()}
  42. jwt_token = jwt.JWT(
  43. header={"alg": "RS256", "kid": _key._public_params()["kid"]}, claims=claims
  44. )
  45. jwt_token.make_signed_token(_key)
  46. return jwt_token.serialize()
  47. def verify_id_token(id_token) -> bool:
  48. try:
  49. jwt.JWT(key=_key, jwt=id_token)
  50. except Exception:
  51. LOG.exception("id token not verified")
  52. return False
  53. else:
  54. return True
  55. def decode_id_token(id_token) -> jwt.JWT:
  56. return jwt.JWT(key=_key, jwt=id_token)
  57. def id_token_hash(value, hashfunc=hashlib.sha256):
  58. """
  59. Inspired from oauthlib
  60. """
  61. digest = hashfunc(value.encode()).digest()
  62. left_most = len(digest) // 2
  63. return base64.urlsafe_b64encode(digest[:left_most]).decode().rstrip("=")