Source code for synapse.lib.oauth

import os
import copy
import heapq
import asyncio
import logging

import aiohttp

import synapse.exc as s_exc
import synapse.common as s_common

import synapse.lib.coro as s_coro
import synapse.lib.nexus as s_nexus
import synapse.lib.schemas as s_schemas
import synapse.lib.lmdbslab as s_lmdbslab

logger = logging.getLogger(__name__)

KEY_LEN = 32          # length of a provider/user iden in a key
REFRESH_WINDOW = 0.5  # refresh in REFRESH_WINDOW * expires_in
DEFAULT_TIMEOUT = 10  # secs

def normOAuthTokenData(issued_at, data):
    '''
    Normalize timestamps to be in epoch millis and set expires_at/refresh_at.

    The raw token response is validated with
    s_schemas.reqValidOauth2TokenResponse() before any value is read.

    Args:
        issued_at: Time the token request was issued; the computed
            expires_at / refresh_at values are offsets from this value.
        data (dict): The token response returned by the provider.

    Returns:
        dict: The access_token, expires_in, expires_at, refresh_at and
        refresh_token values. refresh_token is None when the response did
        not include one.
    '''
    s_schemas.reqValidOauth2TokenResponse(data)

    lifetime = data['expires_in']

    retn = {
        'access_token': data['access_token'],
        'expires_in': lifetime,
    }

    # expires_in is scaled by 1000 before being added to issued_at
    retn['expires_at'] = issued_at + lifetime * 1000

    # schedule the refresh part way (REFRESH_WINDOW) through the token lifetime
    retn['refresh_at'] = issued_at + (lifetime * REFRESH_WINDOW) * 1000

    retn['refresh_token'] = data.get('refresh_token')

    return retn
# Name of the environment variable holding the path to the federated token file.
az_tfile_envar = 'AZURE_FEDERATED_TOKEN_FILE'

def _getAzureTokenFile() -> tuple[bool, str]:
    '''
    Read the client assertion from the file named by AZURE_FEDERATED_TOKEN_FILE.

    Any failure to locate or read the file is reported through the return
    value rather than raised, so callers can rely on the (ok, valu) contract.

    Returns:
        tuple: (True, <file contents>) on success, otherwise
        (False, <error message>).
    '''
    fp = os.getenv(az_tfile_envar, None)
    if fp is None:
        return False, f'{az_tfile_envar} environment variable is not set.'

    # Attempt the read directly instead of checking os.path.exists() first:
    # the file may vanish between a check and the open, or the path may
    # exist but not be a readable file (directory, bad permissions).
    try:
        with open(fp, 'r') as fd:
            assertion = fd.read()
    except (FileNotFoundError, NotADirectoryError):
        return False, f'{az_tfile_envar} file does not exist {fp}'
    except OSError as e:
        return False, f'{az_tfile_envar} file could not be read {fp}: {e}'

    return True, assertion

# Name of the environment variable holding the client id.
az_clientid_envar = 'AZURE_CLIENT_ID'

def _getAzureClientId() -> tuple[bool, str]:
    '''
    Get the client id from the AZURE_CLIENT_ID environment variable.

    Returns:
        tuple: (True, <client id>) when the variable is set to a non-empty
        string, otherwise (False, <error message>).
    '''
    valu = os.getenv(az_clientid_envar, None)
    if valu is None:
        return False, f'{az_clientid_envar} environment variable is not set.'

    if not valu:
        return False, f'{az_clientid_envar} is set to an empty string.'

    return True, valu
class OAuthMixin(s_nexus.Pusher):
    '''
    Mixin for Cells to organize and execute OAuth token refreshes.

    Provider configurations are stored keyed by provider iden and client
    token data is stored keyed by <provideriden><useriden>. Clients holding
    a refresh token are kept in a heap ordered by refresh_at which is
    serviced by a background refresh loop.
    '''

    async def _initOAuthManager(self):
        '''
        Initialize the OAuth V2 storage, the refresh schedule and the refresh loop.
        '''
        slab = self.slab

        self._oauth_clients = s_lmdbslab.SlabDict(slab, db=slab.initdb('oauth:v2:clients'))  # key=<provider><user>
        self._oauth_providers = s_lmdbslab.SlabDict(slab, db=slab.initdb('oauth:v2:providers'))  # key=<provider>

        # (provideriden, useriden) -> refresh_at for every item currently in the heap
        self._oauth_sched_map = {}
        # heap of (refresh_at, provideriden, useriden)
        self._oauth_sched_heap = []
        self._oauth_sched_wake = asyncio.Event()

        # wake the loop on fini so it can observe isfini and exit
        self.onfini(self._oauth_sched_wake.set)

        self._oauth_actviden = self.addActiveCoro(self._runOAuthRefreshLoop)

        # For testing
        self._oauth_sched_ran = asyncio.Event()
        self._oauth_sched_empty = asyncio.Event()

    async def _runOAuthRefreshLoop(self):
        '''
        Rebuild the refresh schedule from the stored clients and run the refresh loop.
        '''
        self._oauth_sched_map.clear()
        self._oauth_sched_heap.clear()
        self._oauth_sched_wake.clear()

        for provideriden, useriden, clientconf in self.listOAuthClients():
            self._scheduleOAuthItem(provideriden, useriden, clientconf)

        await self._oauthRefreshLoop()

    def _scheduleOAuthItem(self, provideriden, useriden, clientconf):
        '''
        Add (or move) a client in the refresh schedule.

        Nothing is scheduled when this cell is not active, when the client
        configuration carries an error, or when it has no refresh token.

        Args:
            provideriden (str): The provider iden.
            useriden (str): The user iden.
            clientconf (dict): The client token data.
        '''
        if not self.isactive:
            return

        if clientconf.get('error'):
            return

        if not clientconf.get('refresh_token'):
            logger.warning(f'OAuth V2 client missing token to schedule refresh provider={provideriden} user={useriden}')
            return

        refresh_at = clientconf['refresh_at']
        newitem = (refresh_at, provideriden, useriden)

        old_refresh_at = self._oauth_sched_map.get(newitem[1:])
        if old_refresh_at == refresh_at:
            # already scheduled for exactly this time
            return

        if old_refresh_at is not None:
            # there's an old item for this client in the refresh queue to remove
            self._oauth_sched_heap.remove((old_refresh_at, *newitem[1:]))
            heapq.heapify(self._oauth_sched_heap)

        self._oauth_sched_map[newitem[1:]] = refresh_at
        heapq.heappush(self._oauth_sched_heap, newitem)

        if self._oauth_sched_heap[0] == newitem:
            # the new item is at the front of the line so wake up the loop if its waiting
            self._oauth_sched_wake.set()

    async def _oauthRefreshLoop(self):
        '''
        Service the refresh schedule until the cell is fini.

        The loop waits until the item at the front of the heap is due (or
        until it is woken because the schedule changed), refreshes the
        token for that client and stores the result. A failed refresh
        stores the error data as the client configuration.
        '''
        while not self.isfini:

            while self._oauth_sched_heap:

                # peek at the next item; refresh_at minus now is divided by 1000 to get the wait timeout
                refresh_at = self._oauth_sched_heap[0][0]
                refresh_in = int(max(0, refresh_at - s_common.now()) / 1000)

                if await s_coro.event_wait(self._oauth_sched_wake, timeout=refresh_in):
                    # the schedule changed while waiting so re-evaluate the front item
                    self._oauth_sched_wake.clear()
                    continue

                if self.isfini:  # pragma: no cover
                    break

                _, provideriden, useriden = heapq.heappop(self._oauth_sched_heap)
                self._oauth_sched_map.pop((provideriden, useriden), None)

                logger.debug(f'Refreshing OAuth V2 token for provider={provideriden} user={useriden}')

                providerconf = self._getOAuthProvider(provideriden)
                if providerconf is None:
                    logger.debug(f'OAuth V2 provider does not exist for provider={provideriden}')
                    continue

                user = self.auth.user(useriden)
                if user is None:
                    await self._setOAuthTokenData(provideriden, useriden, {'error': 'User does not exist'})
                    continue

                if user.isLocked():
                    await self._setOAuthTokenData(provideriden, useriden, {'error': 'User is locked'})
                    continue

                clientconf = self._oauth_clients.get(provideriden + useriden)
                if clientconf is None:
                    logger.debug(f'OAuth V2 client does not exist for provider={provideriden} user={useriden}')
                    continue

                ok, data = await self._refreshOAuthAccessToken(providerconf, clientconf, useriden)
                if not ok:
                    logger.warning(f'OAuth V2 token refresh failed provider={provideriden} user={useriden} data={data}')

                # on failure this stores the error data which stops further refreshes for the client
                await self._setOAuthTokenData(provideriden, useriden, data)

                self._oauth_sched_ran.set()

            self._oauth_sched_empty.set()
            await s_coro.event_wait(self._oauth_sched_wake)
            self._oauth_sched_wake.clear()
            self._oauth_sched_ran.clear()

    async def _getOAuthAccessToken(self, providerconf, useriden, authcode, code_verifier=None):
        '''
        Exchange an authorization code for token data.

        Args:
            providerconf (dict): The provider configuration.
            useriden (str): The user iden the token is requested for.
            authcode (str): The authorization code.
            code_verifier (str): Optional code verifier to include in the request.

        Returns:
            tuple: (True, <normalized token data>) or (False, <error data>).
        '''
        ok, data = await self._getAuthData(providerconf, useriden)
        if not ok:
            return ok, data

        token_uri = providerconf['token_uri']
        ssl_verify = providerconf['ssl_verify']

        auth, formdata = self._unpackAuthData(data)
        formdata.add_field('grant_type', 'authorization_code')
        formdata.add_field('scope', providerconf['scope'])
        formdata.add_field('redirect_uri', providerconf['redirect_uri'])
        formdata.add_field('code', authcode)
        if code_verifier is not None:
            formdata.add_field('code_verifier', code_verifier)

        return await self._fetchOAuthToken(token_uri, auth, formdata, ssl_verify=ssl_verify)

    async def _refreshOAuthAccessToken(self, providerconf, clientconf, useriden):
        '''
        Use a client's refresh token to request new token data.

        Args:
            providerconf (dict): The provider configuration.
            clientconf (dict): The client token data containing the refresh_token.
            useriden (str): The user iden the token is refreshed for.

        Returns:
            tuple: (True, <normalized token data>) or (False, <error data>).
        '''
        ok, data = await self._getAuthData(providerconf, useriden)
        if not ok:
            return ok, data

        token_uri = providerconf['token_uri']
        ssl_verify = providerconf['ssl_verify']
        refresh_token = clientconf['refresh_token']

        auth, formdata = self._unpackAuthData(data)
        formdata.add_field('grant_type', 'refresh_token')
        formdata.add_field('refresh_token', refresh_token)

        ok, data = await self._fetchOAuthToken(token_uri, auth, formdata, ssl_verify=ssl_verify, retries=3)
        if ok and not data.get('refresh_token'):
            # if a refresh_token is not provided in the response persist the existing token
            data['refresh_token'] = refresh_token

        return ok, data

    async def _getAuthData(self, providerconf, useriden):
        '''
        Build the client authentication data for a token request.

        For auth_scheme=basic the result carries the login/password used for
        HTTP basic auth. For auth_scheme=client_assertion the assertion is
        obtained either from a Storm callback (cortex:callstorm) or from the
        Azure workload identity environment (msft:azure:workloadidentity) and
        is returned as form fields.

        Args:
            providerconf (dict): The provider configuration.
            useriden (str): The user iden used when executing a Storm callback.

        Returns:
            tuple: (True, {'auth': ..., 'formdata': ...}) on success. On
            failure (False, <dict with an 'error' key>), except when a Storm
            callback reports failure, in which case its own (ok, info)
            result is returned unchanged.
        '''
        isok = False
        ret = {}

        auth_scheme = providerconf['auth_scheme']
        if auth_scheme == 'basic':
            ret['auth'] = {'login': providerconf['client_id'], 'password': providerconf['client_secret']}
            ret['formdata'] = {}
            isok = True

        elif auth_scheme == 'client_assertion':
            assertion = None
            client_id = providerconf.get('client_id', None)
            client_assertion = providerconf['client_assertion']

            if (info := client_assertion.get('cortex:callstorm')):
                opts = {
                    'view': info['view'],
                    'vars': info.get('vars', {}),
                    'user': useriden,
                }
                try:
                    ok, info = await self.callStorm(info['query'], opts=opts)
                except Exception as e:
                    isok = False
                    ret['error'] = f'Error executing callStorm: {e}'
                else:
                    if not ok:
                        return ok, info
                    assertion = info.get('token')

            elif (info := client_assertion.get('msft:azure:workloadidentity')):
                ok, valu = _getAzureTokenFile()
                if not ok:
                    return ok, {'error': valu}
                assertion = valu

                if info.get('client_id'):
                    ok, valu = _getAzureClientId()
                    if not ok:
                        return ok, {'error': valu}
                    client_id = valu

            else:
                isok = False
                ret['error'] = f'Unknown client_assertions data: {client_assertion}'

            if assertion:
                formdata = {}
                if client_id is not None:
                    # only send a client_id when one is actually known; a None value
                    # would otherwise be passed through as a form field value
                    formdata['client_id'] = client_id
                formdata['client_assertion'] = assertion
                formdata['client_assertion_type'] = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer'
                ret['formdata'] = formdata
                isok = True

            elif 'error' not in ret:
                # the assertion source ran but produced an empty token; callers
                # rely on an error message being present whenever isok is False
                isok = False
                ret['error'] = 'Client assertion did not provide a token.'

        else:
            isok = False
            ret['error'] = f'Unknown authorization scheme: {auth_scheme}'

        return isok, ret

    @staticmethod
    def _unpackAuthData(data: dict) -> tuple[aiohttp.BasicAuth | None, aiohttp.FormData]:
        '''
        Convert the data produced by _getAuthData() into aiohttp request objects.

        Args:
            data (dict): Dictionary with optional 'auth' and 'formdata' keys.

        Returns:
            tuple: (auth, formdata) where auth is an aiohttp.BasicAuth when
            'auth' is present and non-empty (otherwise the falsy value found
            under 'auth', or None), and formdata is an aiohttp.FormData
            populated with the 'formdata' items.
        '''
        auth = data.get('auth', None)  # type: dict | None
        if auth:
            auth = aiohttp.BasicAuth(auth.get('login'), password=auth.get('password'))

        formdata = aiohttp.FormData()
        for k, v in data.get('formdata', {}).items():
            formdata.add_field(k, v)

        return auth, formdata

    async def _fetchOAuthToken(self, url, auth, formdata, ssl_verify=True, retries=1):
        '''
        POST a token request and normalize the response.

        HTTP 200 responses are normalized with normOAuthTokenData(). Other
        responses with a status below 500 are returned as errors immediately.
        Responses with status 500 and above and request timeouts are retried,
        waiting 2 ** (attempts - 1) between attempts, up to ``retries``
        additional times. Any other exception is logged and returned as an
        error without retrying.

        Args:
            url (str): The token endpoint.
            auth: Optional aiohttp.BasicAuth for the request.
            formdata (aiohttp.FormData): The form fields to send.
            ssl_verify (bool): Passed to getCachedSslCtx() to select the SSL context.
            retries (int): Number of retries after the first attempt.

        Returns:
            tuple: (True, <normalized token data>) or (False, {'error': <mesg>}).
        '''
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': 'application/json',
        }

        attempts = 0
        # taken before the request so computed expirations err on the early side
        issued_at = s_common.now()

        timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
        ssl = self.getCachedSslCtx(verify=ssl_verify)

        async with aiohttp.ClientSession(timeout=timeout) as sess:

            while True:

                attempts += 1

                try:
                    async with sess.post(url, auth=auth, headers=headers, data=formdata, ssl=ssl) as resp:

                        if resp.status == 200:
                            data = await resp.json()
                            return True, normOAuthTokenData(issued_at, data)

                        if resp.status < 500:
                            data = await resp.json()
                            errmesg = data.get('error', 'unknown error')
                            if 'error_description' in data:
                                errmesg += f': {data["error_description"]}'
                            return False, {'error': errmesg + f' (HTTP code {resp.status})'}

                        retn = False, {'error': f'Token API returned HTTP code {resp.status}'}

                except asyncio.TimeoutError:
                    retn = False, {'error': f'Token API request timed out'}

                except Exception as e:
                    logger.exception(f'Error fetching token data from {url}')
                    return False, {'error': str(e)}

                if attempts <= retries:
                    # back off before the next attempt
                    await self.waitfini(2 ** (attempts - 1))
                    continue

                return retn

    async def addOAuthProvider(self, conf):
        '''
        Validate and add a new OAuth V2 provider configuration.

        Args:
            conf (dict): The provider configuration.

        Raises:
            s_exc.DupIden: If a provider with the same iden already exists.
            s_exc.BadArg: If the authentication options are inconsistent or
                cannot be used by this service.
        '''
        conf = s_schemas.reqValidOauth2Provider(conf)

        iden = conf['iden']
        if self._getOAuthProvider(iden) is not None:
            raise s_exc.DupIden(mesg=f'Duplicate OAuth V2 client iden ({iden})', iden=iden)

        # N.B. The schema ensures that the possible values in the conf are valid
        # when they are provided. Since writing multi-path schemas in draft07 is
        # overly complicated, some of the mutual exclusion values and logical
        # "is this meaningful?" type checks are made here before pushing the
        # nexus event to create the provider.

        client_secret = conf.get('client_secret')
        client_assertion = conf.get('client_assertion', {})

        if client_assertion and client_secret:
            mesg = 'client_assertion and client_secret provided. These are mutually exclusive options.'
            raise s_exc.BadArg(mesg=mesg)

        if not client_assertion and not client_secret:
            mesg = 'client_assertion and client_secret missing. These are mutually exclusive options and one must be provided.'
            raise s_exc.BadArg(mesg=mesg)

        auth_scheme = conf.get('auth_scheme')
        client_id = conf.get('client_id')

        if auth_scheme == 'basic':
            if not client_id:
                raise s_exc.BadArg(mesg='Must provide client_id for auth_scheme=basic')

            if not client_secret:
                raise s_exc.BadArg(mesg='Must provide client_secret for auth_scheme=basic')

        elif auth_scheme == 'client_assertion':

            if not client_assertion:
                # mirror the basic scheme check: a provider using this scheme
                # without any assertion configuration could never authenticate
                raise s_exc.BadArg(mesg='Must provide client_assertion for auth_scheme=client_assertion')

            if (info := client_assertion.get('cortex:callstorm')) is not None:

                if not hasattr(self, 'callStorm'):
                    mesg = f'cortex:callstorm client assertion not supported by {self.__class__.__name__}'
                    raise s_exc.BadArg(mesg=mesg)

                if not client_id:
                    raise s_exc.BadArg(mesg='Must provide client_id for with cortex:callstorm provider.')

                text = info['query']
                # Validate the query text
                try:
                    await self.reqValidStorm(text)
                except s_exc.BadSyntax as e:
                    raise s_exc.BadArg(mesg=f'Bad storm query: {e.get("mesg")}') from None

                view = self.getView(info['view'])
                if view is None:
                    raise s_exc.BadArg(mesg=f'View {info["view"]} does not exist.')

            elif (info := client_assertion.get('msft:azure:workloadidentity')) is not None:

                if not info.get('token'):
                    raise s_exc.BadArg(mesg='msft:azure:workloadidentity token key must be true')

                # confirm the environment can actually provide the assertion
                ok, tknkvalu = _getAzureTokenFile()
                if not ok:
                    raise s_exc.BadArg(mesg=f'Failed to get the client_assertion data: {tknkvalu}')

                if info.get('client_id'):
                    if client_id:
                        raise s_exc.BadArg(mesg='Cannot specify a fixed client_id and a dynamic client_id value.')

                    ok, idvalu = _getAzureClientId()
                    if not ok:
                        raise s_exc.BadArg(mesg=f'Failed to get the client_id data: {idvalu}')

        else:  # pragma: no cover
            raise s_exc.BadArg(mesg=f'Unknown auth_scheme={auth_scheme}')

        await self._push('oauth:provider:add', conf)

    @s_nexus.Pusher.onPush('oauth:provider:add')
    async def _addOAuthProvider(self, conf):
        '''
        Nexus handler which stores a provider configuration if its iden is not already present.
        '''
        iden = conf['iden']
        if self._getOAuthProvider(iden) is None:
            self._oauth_providers.set(iden, conf)

    def _getOAuthProvider(self, iden):
        '''
        Get a deep copy of the full provider configuration (including secrets) or None.
        '''
        conf = self._oauth_providers.get(iden)
        if conf is not None:
            return copy.deepcopy(conf)

    async def getOAuthProvider(self, iden):
        '''
        Get a provider configuration with the client_secret removed.

        Args:
            iden (str): The provider iden.

        Returns:
            dict: The provider configuration or None if it does not exist.
        '''
        conf = self._getOAuthProvider(iden)
        if conf is not None:
            conf.pop('client_secret', None)
        return conf

    async def listOAuthProviders(self):
        '''
        List the configured providers.

        Returns:
            list: List of (iden, conf) tuples with the client_secret removed from each conf.
        '''
        return [(iden, await self.getOAuthProvider(iden)) for iden in self._oauth_providers.keys()]

    async def delOAuthProvider(self, iden):
        '''
        Delete a provider and all of the client data stored for it.

        Args:
            iden (str): The provider iden.

        Returns:
            The result of the nexus push when the provider exists, otherwise None.
        '''
        if self._getOAuthProvider(iden) is not None:
            return await self._push('oauth:provider:del', iden)

    @s_nexus.Pusher.onPush('oauth:provider:del')
    async def _delOAuthProvider(self, iden):
        '''
        Nexus handler which removes a provider and its clients.

        Returns:
            dict: The removed provider configuration without the client_secret, or None.
        '''
        # client keys are <provideriden><useriden> so a prefix match finds the provider's clients
        for clientiden in list(self._oauth_clients.keys()):
            if clientiden.startswith(iden):
                self._oauth_clients.pop(clientiden)

        conf = self._oauth_providers.pop(iden)
        if conf is not None:
            conf.pop('client_secret', None)

        return conf

    async def getOAuthClient(self, provideriden, useriden):
        '''
        Get a deep copy of the stored client token data.

        Args:
            provideriden (str): The provider iden.
            useriden (str): The user iden.

        Returns:
            dict: The client token data or None if there is none.
        '''
        conf = self._oauth_clients.get(provideriden + useriden)
        if conf is not None:
            return copy.deepcopy(conf)
        return None

    def listOAuthClients(self):
        '''
        Returns:
            list: List of (provideriden, useriden, conf) for each client.
        '''
        # the stored key is split at KEY_LEN into its provider and user parts
        return [(iden[:KEY_LEN], iden[KEY_LEN:], copy.deepcopy(conf))
                for iden, conf in self._oauth_clients.items()]

    async def getOAuthAccessToken(self, provideriden, useriden):
        '''
        Get the current access token for a user from a provider.

        Args:
            provideriden (str): The provider iden.
            useriden (str): The user iden.

        Returns:
            tuple: (True, <access token>) when a valid token is available,
            otherwise (False, <reason>).

        Raises:
            s_exc.BadArg: If the provider has not been configured.
        '''
        if self._getOAuthProvider(provideriden) is None:
            raise s_exc.BadArg(mesg=f'OAuth V2 provider has not been configured ({provideriden})', iden=provideriden)

        clientconf = await self.getOAuthClient(provideriden, useriden)
        if clientconf is None:
            return False, 'Auth code has not been set'

        # if the client has an error report it so the caller can start the oauth flow again
        err = clientconf.get('error')
        if err is not None:
            logger.debug(f'OAuth V2 client token unavailable provider={provideriden} user={useriden} err={err}')
            return False, err

        # never return an expired token; data without an expiration is treated as expired
        expires_at = clientconf.get('expires_at')
        if expires_at is None or expires_at < s_common.now():
            logger.debug(f'OAuth V2 token is expired ({expires_at}) for provider={provideriden} user={useriden}')
            return False, 'Token is expired'

        return True, clientconf.get('access_token')

    async def clearOAuthAccessToken(self, provideriden, useriden):
        '''
        Remove a client access token by clearing the configuration.
        This will prevent further refreshes (if scheduled),
        and a new auth code will be required the next time an access token is requested.
        '''
        if self._oauth_clients.get(provideriden + useriden) is not None:
            return await self._push('oauth:client:data:clear', provideriden, useriden)

    @s_nexus.Pusher.onPush('oauth:client:data:clear')
    async def _clearOAuthAccessToken(self, provideriden, useriden):
        '''
        Nexus handler which removes and returns the stored client token data.
        '''
        return self._oauth_clients.pop(provideriden + useriden)

    async def setOAuthAuthCode(self, provideriden, useriden, authcode, code_verifier=None):
        '''
        Typically set as the end result of a successful OAuth flow.
        An initial access token and refresh token will be immediately requested,
        and the client will be loaded into the schedule to be background refreshed.

        Raises:
            s_exc.BadArg: If the provider has not been configured.
            s_exc.SynErr: If the token could not be retrieved.
        '''
        providerconf = self._getOAuthProvider(provideriden)
        if providerconf is None:
            raise s_exc.BadArg(mesg=f'OAuth V2 provider has not been configured ({provideriden})', iden=provideriden)

        await self.clearOAuthAccessToken(provideriden, useriden)

        ok, data = await self._getOAuthAccessToken(providerconf, useriden, authcode, code_verifier=code_verifier)
        if not ok:
            # failure data may originate from a Storm callback, so do not assume an error key
            errmesg = data.get('error', 'unknown error')
            raise s_exc.SynErr(mesg=f'Failed to get OAuth v2 token: {errmesg}')

        await self._setOAuthTokenData(provideriden, useriden, data)

    @s_nexus.Pusher.onPushAuto('oauth:client:data:set')
    async def _setOAuthTokenData(self, provideriden, useriden, data):
        '''
        Store client token data and (re)schedule the client for refresh.
        '''
        iden = provideriden + useriden
        self._oauth_clients.set(iden, data)
        self._scheduleOAuthItem(provideriden, useriden, data)