Coverage for zombie_nomnom_api/rest_app/authentication.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-07 04:25 +0000

1from functools import cache 

2from fastapi.security import HTTPBearer 

3import jwt 

4from jwt.exceptions import PyJWKClientError, DecodeError 

5from zombie_nomnom_api import configs 

6 

7 

# Module-level FastAPI ``HTTPBearer`` security scheme instance.
# NOTE(review): presumably injected via ``Depends`` in route handlers to read
# the bearer token from the request — confirm at the call sites.
token_auth_scheme: HTTPBearer = HTTPBearer()

9 

10 

def create_error_json(message: str) -> dict[str, str]:
    """Build the standard error result used throughout this module.

    Args:
        message: Human-readable description of the failure.

    Returns:
        A new dict ``{"status": "error", "message": message}``.
    """
    error_payload: dict[str, str] = dict(status="error")
    error_payload["message"] = message
    return error_payload

13 

14 

class VerifyToken:
    """Verify JWTs against the JWKS published by the configured OAuth domain.

    Verification failures are reported as ``{"status": "error", "message": ...}``
    dicts (built by ``create_error_json``) instead of being raised, so callers
    must inspect the value returned by :meth:`verify`.
    """

    def __init__(self, configs=configs):
        """
        Args:
            configs (zombie_nomnom_api.Configs): The Configs that hold the domain for the oauth server.
                The default is the imported ``configs`` module object, bound when
                this ``def`` runs.

        Attributes:
            config (zombie_nomnom_api.Configs): The Configs that hold the domain for the oauth server.
            jwks_client (jwt.PyJWKClient): A jwt client that is used to verify the tokens.
        """
        self.config = configs

        jwks_url = f"https://{self.config.oauth_domain}/.well-known/jwks.json"
        self.jwks_client = jwt.PyJWKClient(jwks_url)

    def verify(
        self,
        token: str,
        permissions: list | None = None,
        scopes: list | str | None = None,
    ) -> dict:
        """Decode and validate ``token``, then enforce required scopes/permissions.

        Args:
            token: The encoded JWT string.
            permissions: Values that must all be present in the token's
                ``permissions`` list claim. Skipped when ``None`` or empty.
            scopes: Scopes that must all be present in the token's
                space-delimited ``scope`` claim. Accepts a single
                whitespace-separated string or any iterable of scope strings.
                Skipped when ``None`` or empty.

        Returns:
            The decoded payload dict on success; otherwise an error dict of the
            form ``{"status": "error", "message": ...}``.
        """
        # Resolve the signing key from the JWKS using the token's 'kid' header.
        try:
            signing_key = self.jwks_client.get_signing_key_from_jwt(token).key
        except (PyJWKClientError, DecodeError) as error:
            return create_error_json(str(error))

        try:
            payload = jwt.decode(
                token,
                signing_key,
                algorithms=self.config.oauth_algorithms,
                audience=self.config.oauth_audience,
                issuer=self.config.oauth_issuer,
            )
        except Exception as error:
            # Deliberately broad: any decode/validation failure (signature,
            # expiry, audience, issuer, ...) is reported as an error result.
            return create_error_json(str(error))

        if scopes:
            # Whitespace split tolerates leading/trailing/repeated spaces that
            # split(" ") would turn into bogus "" scopes.
            required_scopes = (
                scopes.split() if isinstance(scopes, str) else list(scopes)
            )
            result = self._check_claims(payload, "scope", str, required_scopes)
            if result.get("status") == "error":
                return result

        if permissions:
            result = self._check_claims(payload, "permissions", list, permissions)
            if result.get("status") == "error":
                return result

        return payload

    def _check_claims(
        self, payload: dict, claim_name: str, claim_type: type, expected_value: list
    ) -> dict:
        """Check that every value in ``expected_value`` is granted by a claim.

        Args:
            payload: The decoded JWT payload.
            claim_name: Name of the claim to inspect (e.g. ``"scope"``).
            claim_type: Type the claim must have to be considered present.
            expected_value: Values that must all appear in the claim.

        Returns:
            ``{"status": "success", "status_code": 200}`` when all values are
            present, otherwise an error dict from ``create_error_json``.
        """
        payload_claim = payload.get(claim_name)
        if payload_claim is None or not isinstance(payload_claim, claim_type):
            return create_error_json(
                f"User does not have the required '{claim_name}' claim."
            )

        if claim_name == "scope":
            # The 'scope' claim is a single space-delimited string.
            payload_claim = payload_claim.split()

        if not all(value in payload_claim for value in expected_value):
            return create_error_json(
                f"User does not have the required '{claim_name}' claim."
            )
        return {"status": "success", "status_code": 200}

87 

88 

@cache
def get_verifier() -> VerifyToken:
    """Return the shared ``VerifyToken`` instance.

    ``functools.cache`` memoizes this zero-argument call, so the verifier is
    constructed on the first call and the same object is returned afterwards.
    """
    shared_verifier = VerifyToken()
    return shared_verifier