Browse Source

add get_response_types_from_str, response_types_to_str

Son NK 6 years ago
parent
commit
b2d4ac8e65
2 changed files with 40 additions and 2 deletions
  1. 14 1
      app/oauth_models.py
  2. 26 1
      tests/test_oauth_models.py

+ 14 - 1
app/oauth_models.py

@@ -26,7 +26,20 @@ def get_scopes(request: flask.Request) -> Set[Scope]:
 def get_response_types(request: flask.Request) -> Set[ResponseType]:
     response_type_strs = _split_arg(request.args.getlist("response_type"))
 
-    return set([ResponseType(r) for r in response_type_strs])
+    return set([ResponseType(r) for r in response_type_strs if r])
+
+
+def get_response_types_from_str(response_type_str) -> Set[ResponseType]:
+    response_type_strs = _split_arg(response_type_str)
+
+    return set([ResponseType(r) for r in response_type_strs if r])
+
+
+def response_types_to_str(response_types: [ResponseType]) -> str:
+    """return a string representing a list of response type, for ex
+    *code*, *id_token,token*,...
+    """
+    return ",".join([r.value for r in response_types])
 
 
 def _split_arg(arg_input: Union[str, list]) -> Set[str]:

+ 26 - 1
tests/test_oauth_models.py

@@ -1,7 +1,14 @@
 import flask
 import pytest
 
-from app.oauth_models import get_scopes, Scope, get_response_types, ResponseType
+from app.oauth_models import (
+    get_scopes,
+    Scope,
+    get_response_types,
+    ResponseType,
+    response_types_to_str,
+    get_response_types_from_str,
+)
 
 
 def test_get_scopes(flask_app):
@@ -52,3 +59,21 @@ def test_get_response_types(flask_app):
     with flask_app.test_request_context("/?response_type=abcd"):
         with pytest.raises(ValueError):
             get_response_types(flask.request)
+
+
+def test_response_types_to_str():
+    assert response_types_to_str([]) == ""
+    assert response_types_to_str([ResponseType.CODE]) == "code"
+    assert (
+        response_types_to_str([ResponseType.CODE, ResponseType.ID_TOKEN])
+        == "code,id_token"
+    )
+
+
+def test_get_response_types_from_str():
+    assert get_response_types_from_str("") == set()
+    assert get_response_types_from_str("token") == {ResponseType.TOKEN}
+    assert get_response_types_from_str("token id_token") == {
+        ResponseType.TOKEN,
+        ResponseType.ID_TOKEN,
+    }