Coverage for zombie_nomnom_api/rest_app/authentication.py: 100%
47 statements
coverage.py v7.6.9, created at 2024-12-07 04:25 +0000
1from functools import cache
2from fastapi.security import HTTPBearer
3import jwt
4from jwt.exceptions import PyJWKClientError, DecodeError
5from zombie_nomnom_api import configs
# Module-level FastAPI HTTP bearer security scheme instance.
token_auth_scheme: HTTPBearer = HTTPBearer()
def create_error_json(message: str) -> dict[str, str]:
    """Build the error dict returned by the token checks in this module.

    Args:
        message: Human-readable description of the failure.

    Returns:
        A dict of the form ``{"status": "error", "message": message}``.
    """
    return dict(status="error", message=message)
class VerifyToken:
    """Verify bearer JWTs against the OAuth server's JWKS and check their claims.

    Failures are reported by *returning* an error dict of the form
    ``{"status": "error", "message": ...}`` rather than by raising, so callers
    must inspect the value returned by :meth:`verify`.
    """

    def __init__(self, configs=configs):
        """
        Args:
            configs (zombie_nomnom_api.Configs): The Configs that hold the domain for the oauth server.
                The default is the module-level ``configs`` object, bound once
                when this class is defined.

        Attributes:
            config (zombie_nomnom_api.Configs): The Configs that hold the domain for the oauth server.
            jwks_client (jwt.PyJWKClient): A jwt client that is used to verify the tokens.
        """
        self.config = configs

        # Signing keys are fetched from the OAuth domain's well-known JWKS path.
        jwks_url = f"https://{self.config.oauth_domain}/.well-known/jwks.json"
        self.jwks_client = jwt.PyJWKClient(jwks_url)

    def verify(
        self,
        token: str,
        permissions: list | None = None,
        scopes: list | str | None = None,
    ) -> dict:
        """Decode *token* and optionally check its scope and permission claims.

        Args:
            token: The raw encoded JWT.
            permissions: Values that must all appear in the token's
                ``permissions`` claim, which must be a list. Skipped when empty.
            scopes: Scopes that must all appear in the token's space-separated
                ``scope`` claim. Accepts a whitespace-separated string or any
                iterable of scope names. Skipped when empty.

        Returns:
            The decoded token payload on success, otherwise an error dict built
            by ``create_error_json``.
        """
        # Look up the signing key matching the 'kid' in the token header.
        try:
            signing_key = self.jwks_client.get_signing_key_from_jwt(token).key
        except (PyJWKClientError, DecodeError) as error:
            return create_error_json(str(error))

        try:
            payload = jwt.decode(
                token,
                signing_key,
                algorithms=self.config.oauth_algorithms,
                audience=self.config.oauth_audience,
                issuer=self.config.oauth_issuer,
            )
        except Exception as error:
            # Deliberately broad: any failure while decoding/validating an
            # untrusted token is reported as an error dict instead of
            # propagating out of the verifier.
            return create_error_json(str(error))

        if scopes:
            # split() (not split(" ")) so repeated or trailing whitespace does
            # not turn into empty, unsatisfiable scope names.
            required_scopes = (
                scopes.split() if isinstance(scopes, str) else list(scopes)
            )
            result = self._check_claims(payload, "scope", str, required_scopes)
            if result.get("status") == "error":
                return result

        if permissions:
            result = self._check_claims(payload, "permissions", list, permissions)
            if result.get("status") == "error":
                return result

        return payload

    def _check_claims(
        self, payload: dict, claim_name: str, claim_type: type, expected_value: list
    ) -> dict:
        """Check that every value in *expected_value* is present in a claim.

        Args:
            payload: The decoded JWT payload.
            claim_name: Name of the claim to inspect. The ``"scope"`` claim is
                treated as a whitespace-separated string of scope names.
            claim_type: Type the claim value must have; any other type fails.
            expected_value: Values that must all be present in the claim.

        Returns:
            ``{"status": "success", "status_code": 200}`` when every expected
            value is present, otherwise an error dict from ``create_error_json``.
        """
        payload_claim = payload.get(claim_name)
        if payload_claim is None or not isinstance(payload_claim, claim_type):
            return create_error_json(
                f"User does not have the required '{claim_name}' claim."
            )

        if claim_name == "scope":
            # split() drops empty tokens produced by repeated/trailing spaces.
            payload_claim = payload_claim.split()

        if any(value not in payload_claim for value in expected_value):
            return create_error_json(
                f"User does not have the required '{claim_name}' claim."
            )

        return {"status": "success", "status_code": 200}
@cache
def get_verifier() -> VerifyToken:
    """Return the shared ``VerifyToken`` instance.

    ``functools.cache`` memoizes this zero-argument call, so the verifier (and
    the JWKS client it creates) is constructed only on the first call.
    """
    verifier = VerifyToken()
    return verifier