import os
import copy
import heapq
import asyncio
import logging
import aiohttp
import synapse.exc as s_exc
import synapse.common as s_common
import synapse.lib.coro as s_coro
import synapse.lib.nexus as s_nexus
import synapse.lib.schemas as s_schemas
import synapse.lib.lmdbslab as s_lmdbslab
logger = logging.getLogger(__name__)
KEY_LEN = 32 # length of a provider/user iden; client keys are <provider iden><user iden> and are split at this offset
REFRESH_WINDOW = 0.5 # fraction of expires_in after which a refresh is scheduled (refresh_at = issued_at + REFRESH_WINDOW * expires_in)
DEFAULT_TIMEOUT = 10 # secs; total timeout applied to the token API HTTP session
[docs]
def normOAuthTokenData(issued_at, data):
    '''
    Validate a token response and derive its absolute expiry/refresh times.

    Args:
        issued_at: Time the token was requested, in epoch millis.
        data (dict): Token response; checked with reqValidOauth2TokenResponse.

    Returns:
        dict: The access_token, expires_in, expires_at, refresh_at, and
        refresh_token (None when the response does not carry one).
    '''
    s_schemas.reqValidOauth2TokenResponse(data)

    expires_in = data['expires_in']

    # expires_in is scaled by 1000 to land in the same millis units as issued_at
    norm = {'access_token': data['access_token']}
    norm['expires_in'] = expires_in
    norm['expires_at'] = issued_at + expires_in * 1000
    norm['refresh_at'] = issued_at + (expires_in * REFRESH_WINDOW) * 1000
    norm['refresh_token'] = data.get('refresh_token')
    return norm
az_tfile_envar = 'AZURE_FEDERATED_TOKEN_FILE'

def _getAzureTokenFile() -> tuple[bool, str]:
    '''
    Read the federated token from the file named by AZURE_FEDERATED_TOKEN_FILE.

    Returns:
        tuple: (True, <file contents>) on success, otherwise (False, <message>)
        describing why the token could not be read.
    '''
    fp = os.getenv(az_tfile_envar, None)
    if fp is None:
        return False, f'{az_tfile_envar} environment variable is not set.'

    # Open directly rather than checking os.path.exists() first: the check is
    # racy, and a directory or unreadable file would otherwise raise out of
    # this function instead of producing the (False, mesg) result callers expect.
    try:
        with open(fp, 'r') as fd:
            return True, fd.read()
    except (FileNotFoundError, NotADirectoryError):
        return False, f'{az_tfile_envar} file does not exist {fp}'
    except (OSError, ValueError) as e:
        return False, f'{az_tfile_envar} file could not be read {fp}: {e}'
az_clientid_envar = 'AZURE_CLIENT_ID'

def _getAzureClientId() -> tuple[bool, str]:
    '''
    Look up the client id from the AZURE_CLIENT_ID environment variable.

    Returns:
        tuple: (True, <client id>) when the variable holds a non-empty value,
        otherwise (False, <message>) saying whether it was unset or empty.
    '''
    client_id = os.getenv(az_clientid_envar)

    if client_id is None:
        return False, f'{az_clientid_envar} environment variable is not set.'

    if not client_id:
        return False, f'{az_clientid_envar} is set to an empty string.'

    return True, client_id
[docs]
class OAuthMixin(s_nexus.Pusher):
'''
Mixin for Cells to organize and execute OAuth token refreshes.
'''
async def _initOAuthManager(self):
    '''
    Set up OAuth storage and scheduling state and register the refresh loop.
    '''
    store = self.slab

    # persistent state, both backed by named databases in the cell slab
    clientdb = store.initdb('oauth:v2:clients')
    self._oauth_clients = s_lmdbslab.SlabDict(store, db=clientdb)  # key=<provider><user>

    providerdb = store.initdb('oauth:v2:providers')
    self._oauth_providers = s_lmdbslab.SlabDict(store, db=providerdb)  # key=<provider>

    # in-memory refresh schedule: (provider, user) -> refresh_at, plus a heap ordered by refresh_at
    self._oauth_sched_map = {}
    self._oauth_sched_heap = []

    # set on fini so a waiting refresh loop is released
    self._oauth_sched_wake = asyncio.Event()
    self.onfini(self._oauth_sched_wake.set)

    self._oauth_actviden = self.addActiveCoro(self._runOAuthRefreshLoop)

    # For testing
    self._oauth_sched_ran = asyncio.Event()
    self._oauth_sched_empty = asyncio.Event()
async def _runOAuthRefreshLoop(self):
    '''
    Rebuild the refresh schedule from the stored clients and run the refresh loop.
    '''
    # drop any schedule state left over from a previous run
    for state in (self._oauth_sched_map, self._oauth_sched_heap, self._oauth_sched_wake):
        state.clear()

    for client in self.listOAuthClients():
        provideriden, useriden, clientconf = client
        self._scheduleOAuthItem(provideriden, useriden, clientconf)

    await self._oauthRefreshLoop()
def _scheduleOAuthItem(self, provideriden, useriden, clientconf):
    '''
    Add or move a client in the refresh schedule.

    Nothing is scheduled when the cell is not active, when the client conf
    carries an error, or when it has no refresh_token.

    Args:
        provideriden (str): The provider iden.
        useriden (str): The user iden.
        clientconf (dict): The client token data; refresh_at is read from it.
    '''
    if not self.isactive or clientconf.get('error'):
        return

    if not clientconf.get('refresh_token'):
        logger.warning(f'OAuth V2 client missing token to schedule refresh provider={provideriden} user={useriden}')
        return

    due = clientconf['refresh_at']
    key = (provideriden, useriden)

    prev = self._oauth_sched_map.get(key)
    if prev == due:
        # already scheduled for exactly this time
        return

    if prev is not None:
        # the client is queued under an older time; drop that entry and restore the heap invariant
        self._oauth_sched_heap.remove((prev, provideriden, useriden))
        heapq.heapify(self._oauth_sched_heap)

    entry = (due, provideriden, useriden)
    self._oauth_sched_map[key] = due
    heapq.heappush(self._oauth_sched_heap, entry)

    if self._oauth_sched_heap[0] == entry:
        # the new entry is now the next one due, so wake the loop if it is waiting
        self._oauth_sched_wake.set()
async def _oauthRefreshLoop(self):
    '''
    Refresh scheduled client tokens as they come due, until the cell is fini.

    The inner loop waits for the earliest scheduled item (or for the wake
    event, which signals that the schedule changed), then refreshes that
    client and stores the result. When the schedule is empty the loop blocks
    on the wake event.
    '''
    # NOTE(review): the nesting here was reconstructed from un-indented text;
    # confirm that _oauth_sched_ran.set() belongs at the end of the inner loop body.
    while not self.isfini:

        while self._oauth_sched_heap:

            due, provideriden, useriden = self._oauth_sched_heap[0]
            delay = int(max(0, due - s_common.now()) / 1000)

            woken = await s_coro.event_wait(self._oauth_sched_wake, timeout=delay)
            if woken:
                # the schedule changed while waiting; re-evaluate the head of the heap
                self._oauth_sched_wake.clear()
                continue

            if self.isfini: # pragma: no cover
                break

            _, provideriden, useriden = heapq.heappop(self._oauth_sched_heap)
            self._oauth_sched_map.pop((provideriden, useriden), None)

            logger.debug(f'Refreshing OAuth V2 token for provider={provideriden} user={useriden}')

            providerconf = self._getOAuthProvider(provideriden)
            if providerconf is None:
                logger.debug(f'OAuth V2 provider does not exist for provider={provideriden}')
                continue

            user = self.auth.user(useriden)
            if user is None or user.isLocked():
                # record why the refresh was abandoned in the client data
                mesg = 'User does not exist' if user is None else 'User is locked'
                await self._setOAuthTokenData(provideriden, useriden, {'error': mesg})
                continue

            clientconf = self._oauth_clients.get(provideriden + useriden)
            if clientconf is None:
                logger.debug(f'OAuth V2 client does not exist for provider={provideriden} user={useriden}')
                continue

            ok, data = await self._refreshOAuthAccessToken(providerconf, clientconf, useriden)
            if not ok:
                logger.warning(f'OAuth V2 token refresh failed provider={provideriden} user={useriden} data={data}')

            # stored on success and on failure alike (on failure data carries the error)
            await self._setOAuthTokenData(provideriden, useriden, data)

            self._oauth_sched_ran.set()

        self._oauth_sched_empty.set()

        # nothing scheduled; sleep until something is added or the cell is fini
        await s_coro.event_wait(self._oauth_sched_wake)
        self._oauth_sched_wake.clear()
        self._oauth_sched_ran.clear()
async def _getOAuthAccessToken(self, providerconf, useriden, authcode, code_verifier=None):
    '''
    Exchange an authorization code for token data at the provider token_uri.

    Args:
        providerconf (dict): The provider configuration.
        useriden (str): The user iden the token is being requested for.
        authcode (str): The authorization code to exchange.
        code_verifier (str): Optional code_verifier form field.

    Returns:
        tuple: (ok, data) as returned by _getAuthData on failure, otherwise
        as returned by _fetchOAuthToken.
    '''
    ok, authdata = await self._getAuthData(providerconf, useriden)
    if not ok:
        return ok, authdata

    token_uri = providerconf['token_uri']
    ssl_verify = providerconf['ssl_verify']

    auth, form = self._unpackAuthData(authdata)

    fields = [
        ('grant_type', 'authorization_code'),
        ('scope', providerconf['scope']),
        ('redirect_uri', providerconf['redirect_uri']),
        ('code', authcode),
    ]
    if code_verifier is not None:
        fields.append(('code_verifier', code_verifier))

    for name, valu in fields:
        form.add_field(name, valu)

    return await self._fetchOAuthToken(token_uri, auth, form, ssl_verify=ssl_verify)
async def _refreshOAuthAccessToken(self, providerconf, clientconf, useriden):
    '''
    Use a client's stored refresh_token to request new token data.

    Args:
        providerconf (dict): The provider configuration.
        clientconf (dict): The client token data holding the refresh_token.
        useriden (str): The user iden the token belongs to.

    Returns:
        tuple: (ok, data). On success data always has a refresh_token; the
        existing one is carried over when the response does not supply one.
    '''
    ok, authdata = await self._getAuthData(providerconf, useriden)
    if not ok:
        return ok, authdata

    token_uri = providerconf['token_uri']
    ssl_verify = providerconf['ssl_verify']
    current = clientconf['refresh_token']

    auth, form = self._unpackAuthData(authdata)
    form.add_field('grant_type', 'refresh_token')
    form.add_field('refresh_token', current)

    ok, tokendata = await self._fetchOAuthToken(token_uri, auth, form, ssl_verify=ssl_verify, retries=3)
    if not ok:
        return ok, tokendata

    if not tokendata.get('refresh_token'):
        # if a refresh_token is not provided in the response persist the existing token
        tokendata['refresh_token'] = current

    return ok, tokendata
async def _getAuthData(self, providerconf, useriden):
    '''
    Build the client authentication data for a token request.

    Args:
        providerconf (dict): The provider configuration; auth_scheme selects
            between 'basic' and 'client_assertion'.
        useriden (str): The user iden; used as the user when a
            cortex:callstorm assertion query is executed.

    Returns:
        tuple: (True, data) where data holds 'formdata' (and 'auth' for the
        basic scheme), or (False, data) describing the failure. Every failure
        produced by this method carries an 'error' key; a failed callStorm
        query returns whatever value the query returned.
    '''
    auth_scheme = providerconf['auth_scheme']

    if auth_scheme == 'basic':
        auth = {'login': providerconf['client_id'], 'password': providerconf['client_secret']}
        return True, {'auth': auth, 'formdata': {}}

    if auth_scheme != 'client_assertion':
        return False, {'error': f'Unknown authorization scheme: {auth_scheme}'}

    assertion = None
    client_id = providerconf.get('client_id', None)

    # A provider stored without a client_assertion falls through to the
    # "unknown" error below instead of raising KeyError.
    client_assertion = providerconf.get('client_assertion') or {}

    if (info := client_assertion.get('cortex:callstorm')):
        opts = {
            'view': info['view'],
            'vars': info.get('vars', {}),
            'user': useriden,
        }
        try:
            ok, info = await self.callStorm(info['query'], opts=opts)
        except Exception as e:
            return False, {'error': f'Error executing callStorm: {e}'}

        if not ok:
            return ok, info

        assertion = info.get('token')

    elif (info := client_assertion.get('msft:azure:workloadidentity')):
        ok, valu = _getAzureTokenFile()
        if not ok:
            return ok, {'error': valu}

        assertion = valu

        if info.get('client_id'):
            # the client_id is taken from the environment rather than the provider conf
            ok, valu = _getAzureClientId()
            if not ok:
                return ok, {'error': valu}
            client_id = valu

    else:
        return False, {'error': f'Unknown client_assertions data: {client_assertion}'}

    if not assertion:
        # Previously this returned (False, {}) with no 'error' key, which callers
        # index (setOAuthAuthCode) or persist as client data (the refresh loop).
        return False, {'error': 'Client assertion token was not provided.'}

    formdata = {
        'client_id': client_id,
        'client_assertion': assertion,
        'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
    }
    return True, {'formdata': formdata}
@staticmethod
def _unpackAuthData(data: dict) -> tuple[aiohttp.BasicAuth | None, aiohttp.FormData]:
    '''
    Convert auth data from _getAuthData into aiohttp request objects.

    Args:
        data (dict): May hold 'auth' (login/password) and 'formdata' (form fields).

    Returns:
        tuple: (auth, formdata) where auth is an aiohttp.BasicAuth when auth
        details are present, and formdata is an aiohttp.FormData with one
        field per 'formdata' entry.
    '''
    auth = data.get('auth', None)  # type: dict | None
    if auth:
        login = auth.get('login')
        password = auth.get('password')
        auth = aiohttp.BasicAuth(login, password=password)

    formdata = aiohttp.FormData()
    fields = data.get('formdata', {})
    for name in fields:
        formdata.add_field(name, fields[name])

    return auth, formdata
async def _fetchOAuthToken(self, url, auth, formdata, ssl_verify=True, retries=1):
    '''
    POST a token request and normalize the response.

    HTTP 5xx responses and timeouts are retried, waiting 2 ** (attempt - 1)
    seconds between attempts; any other failure returns immediately.

    Args:
        url (str): The token endpoint.
        auth: Optional basic auth passed to the request.
        formdata: The form body passed to the request.
        ssl_verify (bool): Passed to getCachedSslCtx to select the SSL context.
        retries (int): Number of additional attempts after the first.

    Returns:
        tuple: (True, <normalized token data>) or (False, {'error': <message>}).
    '''
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
    }

    attempts = 0

    # taken once, before the first attempt; expiry is computed relative to this
    issued_at = s_common.now()

    timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
    ssl = self.getCachedSslCtx(verify=ssl_verify)

    async with aiohttp.ClientSession(timeout=timeout) as sess:

        while True:

            attempts += 1

            try:
                async with sess.post(url, auth=auth, headers=headers, data=formdata, ssl=ssl) as resp:

                    if resp.status == 200:
                        data = await resp.json()
                        return True, normOAuthTokenData(issued_at, data)

                    if resp.status < 500:
                        # a non-5xx failure is reported as-is and never retried
                        data = await resp.json()
                        errmesg = data.get('error', 'unknown error')
                        if 'error_description' in data:
                            errmesg += f': {data["error_description"]}'
                        return False, {'error': errmesg + f' (HTTP code {resp.status})'}

                    retn = False, {'error': f'Token API returned HTTP code {resp.status}'}

            except asyncio.TimeoutError:
                retn = False, {'error': 'Token API request timed out'}

            except Exception as e:
                logger.exception(f'Error fetching token data from {url}')
                return False, {'error': str(e)}

            if attempts > retries:
                return retn

            # back off before the next attempt
            await self.waitfini(2 ** (attempts - 1))
[docs]
async def addOAuthProvider(self, conf):
    '''
    Validate an OAuth V2 provider configuration and add it.

    Args:
        conf (dict): The provider configuration; normalized by
            reqValidOauth2Provider before the checks below.

    Raises:
        s_exc.DupIden: If a provider with the same iden already exists.
        s_exc.BadArg: If the client_secret / client_assertion / client_id
            combination is not usable for the requested auth_scheme, or a
            client assertion source cannot be validated.
    '''
    conf = s_schemas.reqValidOauth2Provider(conf)

    iden = conf['iden']
    if self._getOAuthProvider(iden) is not None:
        raise s_exc.DupIden(mesg=f'Duplicate OAuth V2 client iden ({iden})', iden=iden)

    # N.B. The schema ensures that the possible values in the conf are valid
    # when they are provided. Since writing multi-path schemas in draft07 is
    # overly complicated, some of the mutual exclusion values and logical
    # "is this meaningful?" type checks are made here before pushing the
    # nexus event to create the provider.
    client_secret = conf.get('client_secret')
    client_assertion = conf.get('client_assertion', {})

    if client_assertion and client_secret:
        mesg = 'client_assertion and client_secret provided. These are mutually exclusive options.'
        raise s_exc.BadArg(mesg=mesg)

    if not client_assertion and not client_secret:
        mesg = 'client_assertion and client_secret missing. These are mutually exclusive options and one must be provided.'
        raise s_exc.BadArg(mesg=mesg)

    auth_scheme = conf.get('auth_scheme')
    client_id = conf.get('client_id')

    if auth_scheme == 'basic':
        if not client_id:
            raise s_exc.BadArg(mesg='Must provide client_id for auth_scheme=basic')

        if not client_secret:
            raise s_exc.BadArg(mesg='Must provide client_secret for auth_scheme=basic')

    elif auth_scheme == 'client_assertion':

        # Mirror of the basic scheme's client_secret check: without this a conf
        # holding only a client_secret passed validation here and produced a
        # provider that could never build its authentication data.
        if not client_assertion:
            raise s_exc.BadArg(mesg='Must provide client_assertion for auth_scheme=client_assertion')

        if (info := client_assertion.get('cortex:callstorm')) is not None:

            if not hasattr(self, 'callStorm'):
                mesg = f'cortex:callstorm client assertion not supported by {self.__class__.__name__}'
                raise s_exc.BadArg(mesg=mesg)

            if not client_id:
                raise s_exc.BadArg(mesg='Must provide client_id for with cortex:callstorm provider.')

            text = info['query']
            # Validate the query text
            try:
                await self.reqValidStorm(text)
            except s_exc.BadSyntax as e:
                raise s_exc.BadArg(mesg=f'Bad storm query: {e.get("mesg")}') from None

            view = self.getView(info['view'])
            if view is None:
                raise s_exc.BadArg(mesg=f'View {info["view"]} does not exist.')

        elif (info := client_assertion.get('msft:azure:workloadidentity')) is not None:

            if not info.get('token'):
                raise s_exc.BadArg(mesg='msft:azure:workloadidentity token key must be true')

            # confirm the token file is readable now, so misconfiguration fails at add time
            ok, tknkvalu = _getAzureTokenFile()
            if not ok:
                raise s_exc.BadArg(mesg=f'Failed to get the client_assertion data: {tknkvalu}')

            if info.get('client_id'):
                if client_id:
                    raise s_exc.BadArg(mesg='Cannot specify a fixed client_id and a dynamic client_id value.')

                ok, idvalu = _getAzureClientId()
                if not ok:
                    raise s_exc.BadArg(mesg=f'Failed to get the client_id data: {idvalu}')

    else: # pragma: no cover
        raise s_exc.BadArg(mesg=f'Unknown auth_scheme={auth_scheme}')

    await self._push('oauth:provider:add', conf)
@s_nexus.Pusher.onPush('oauth:provider:add')
async def _addOAuthProvider(self, conf):
    '''
    Handler for 'oauth:provider:add': store the provider conf under its iden.

    An existing provider with the same iden is left unchanged.
    '''
    iden = conf['iden']

    if self._getOAuthProvider(iden) is not None:
        return

    self._oauth_providers.set(iden, conf)
def _getOAuthProvider(self, iden):
    '''
    Get a deep copy of a stored provider conf, or None if there is none.
    '''
    stored = self._oauth_providers.get(iden)
    return None if stored is None else copy.deepcopy(stored)
[docs]
async def getOAuthProvider(self, iden):
    '''
    Get a provider conf by iden with the client_secret removed.

    Returns:
        dict: The provider conf, or None if the provider does not exist.
    '''
    conf = self._getOAuthProvider(iden)
    if conf is None:
        return conf

    # the secret is never handed back to callers
    conf.pop('client_secret', None)
    return conf
[docs]
async def listOAuthProviders(self):
    '''
    List the configured providers.

    Returns:
        list: List of (iden, conf) tuples, with each conf as returned by getOAuthProvider.
    '''
    retn = []
    for iden in self._oauth_providers.keys():
        conf = await self.getOAuthProvider(iden)
        retn.append((iden, conf))
    return retn
[docs]
async def delOAuthProvider(self, iden):
    '''
    Delete a provider if it exists.

    Returns:
        The result of the 'oauth:provider:del' push, or None if the provider does not exist.
    '''
    if self._getOAuthProvider(iden) is None:
        return None

    return await self._push('oauth:provider:del', iden)
@s_nexus.Pusher.onPush('oauth:provider:del')
async def _delOAuthProvider(self, iden):
    '''
    Handler for 'oauth:provider:del': remove a provider and its clients.

    Returns:
        dict: The removed provider conf without its client_secret, or None
        if no provider was stored under the iden.
    '''
    # client keys are <provider><user>, so the provider iden is their prefix;
    # collect the matches first so the store is not modified while iterating it
    doomed = [key for key in self._oauth_clients.keys() if key.startswith(iden)]
    for key in doomed:
        self._oauth_clients.pop(key)

    conf = self._oauth_providers.pop(iden)
    if conf is None:
        return conf

    conf.pop('client_secret', None)
    return conf
[docs]
async def getOAuthClient(self, provideriden, useriden):
    '''
    Get a deep copy of the stored client data for a provider/user pair.

    Returns:
        dict: The client data, or None if nothing is stored for the pair.
    '''
    stored = self._oauth_clients.get(provideriden + useriden)
    if stored is None:
        return None

    return copy.deepcopy(stored)
[docs]
def listOAuthClients(self):
    '''
    List the stored clients.

    Returns:
        list: List of (provideriden, useriden, conf) for each client.
    '''
    retn = []
    for key, conf in self._oauth_clients.items():
        # keys are <provider><user>; the provider iden is the first KEY_LEN characters
        provideriden, useriden = key[:KEY_LEN], key[KEY_LEN:]
        retn.append((provideriden, useriden, copy.deepcopy(conf)))
    return retn
[docs]
async def getOAuthAccessToken(self, provideriden, useriden):
    '''
    Get the current access token for a provider/user pair.

    Returns:
        tuple: (True, <access token>) when a usable token is stored,
        otherwise (False, <reason>).

    Raises:
        s_exc.BadArg: If the provider has not been configured.
    '''
    if self._getOAuthProvider(provideriden) is None:
        raise s_exc.BadArg(mesg=f'OAuth V2 provider has not been configured ({provideriden})', iden=provideriden)

    clientconf = await self.getOAuthClient(provideriden, useriden)
    if clientconf is None:
        return False, 'Auth code has not been set'

    # if the client has an error report it so the caller can start the oauth flow again
    err = clientconf.get('error')
    if err is not None:
        logger.debug(f'OAuth V2 client token unavailable provider={provideriden} user={useriden} err={err}')
        return False, err

    # never return an expired token
    expires_at = clientconf['expires_at']
    if s_common.now() > expires_at:
        logger.debug(f'OAuth V2 token is expired ({expires_at}) for provider={provideriden} user={useriden}')
        return False, 'Token is expired'

    return True, clientconf.get('access_token')
[docs]
async def clearOAuthAccessToken(self, provideriden, useriden):
    '''
    Remove a client access token by clearing the configuration.

    This will prevent further refreshes (if scheduled),
    and a new auth code will be required the next time an access token is requested.

    Returns:
        The result of the 'oauth:client:data:clear' push, or None if no
        client data is stored for the provider/user pair.
    '''
    if self._oauth_clients.get(provideriden + useriden) is None:
        return None

    return await self._push('oauth:client:data:clear', provideriden, useriden)
@s_nexus.Pusher.onPush('oauth:client:data:clear')
async def _clearOAuthAccessToken(self, provideriden, useriden):
    '''
    Handler for 'oauth:client:data:clear': pop the stored client data.

    Returns:
        The value returned by popping the <provider><user> key from the client store.
    '''
    key = provideriden + useriden
    return self._oauth_clients.pop(key)
[docs]
async def setOAuthAuthCode(self, provideriden, useriden, authcode, code_verifier=None):
    '''
    Typically set as the end result of a successful OAuth flow.

    Any existing client data is cleared first. An initial access token and
    refresh token are then requested immediately, and the resulting token
    data is stored, which also submits the client to the refresh schedule.

    Raises:
        s_exc.BadArg: If the provider has not been configured.
        s_exc.SynErr: If the token request fails.
    '''
    providerconf = self._getOAuthProvider(provideriden)
    if providerconf is None:
        raise s_exc.BadArg(mesg=f'OAuth V2 provider has not been configured ({provideriden})', iden=provideriden)

    await self.clearOAuthAccessToken(provideriden, useriden)

    result = await self._getOAuthAccessToken(providerconf, useriden, authcode, code_verifier=code_verifier)
    ok, data = result
    if not ok:
        raise s_exc.SynErr(mesg=f'Failed to get OAuth v2 token: {data["error"]}')

    await self._setOAuthTokenData(provideriden, useriden, data)
@s_nexus.Pusher.onPushAuto('oauth:client:data:set')
async def _setOAuthTokenData(self, provideriden, useriden, data):
    '''
    Store client token data and hand it to the refresh scheduler.

    Args:
        provideriden (str): The provider iden.
        useriden (str): The user iden.
        data (dict): The token data (or error data) to store for the client.
    '''
    # clients are keyed by <provider><user>
    self._oauth_clients.set(provideriden + useriden, data)
    self._scheduleOAuthItem(provideriden, useriden, data)