Source code for synapse.common

import io
import os
import ssl
import sys
import enum
import json
import http
import stat
import time
import heapq
import types
import base64
import shutil
import struct
import typing
import asyncio
import decimal
import fnmatch
import hashlib
import logging
import binascii
import builtins
import tempfile
import warnings
import functools
import itertools
import threading
import traceback
import contextlib
import collections

import http.cookies

import yaml
import regex

import synapse.exc as s_exc
import synapse.lib.const as s_const
import synapse.lib.msgpack as s_msgpack
import synapse.lib.structlog as s_structlog

import synapse.vendor.cpython.lib.ipaddress as ipaddress
import synapse.vendor.cpython.lib.http.cookies as v_cookies


try:
    from yaml import CSafeLoader as Loader
    from yaml import CSafeDumper as Dumper
except ImportError:  # pragma: no cover
    from yaml import SafeLoader as Loader
    from yaml import SafeDumper as Dumper

class NoValu:
    '''
    Empty marker class; instances carry no state or behavior.
    '''
# Components of the running interpreter version, captured at import time.
major = sys.version_info.major
minor = sys.version_info.minor
micro = sys.version_info.micro

majmin = (major, minor)
version = (major, minor, micro)

# Exactly 32 lowercase hex characters (guid) / 64 lowercase hex characters (hex encoded buid).
guidre = regex.compile('^[0-9a-f]{32}$')
buidre = regex.compile('^[0-9a-f]{64}$')

# Shared NoValu instance, compared by identity as a "no value" sentinel.
novalu = NoValu()

logger = logging.getLogger(__name__)

# Loader is yaml.SafeLoader only when the C accelerated loader failed to import.
if Loader == yaml.SafeLoader:  # pragma: no cover
    logger.warning('*****************************************************************************************************')
    logger.warning('* PyYAML is using the pure python fallback implementation. This will impact performance negatively. *')
    logger.warning('* See PyYAML docs (https://pyyaml.org/wiki/PyYAMLDocumentation) for tips on resolving this issue. *')
    logger.warning('*****************************************************************************************************')
def now():
    '''
    Return the current wall-clock epoch time in milliseconds.

    Notes:
        The value is derived from time.time_ns(), whose resolution is
        system-dependent.

    Returns:
        int: Milliseconds since the epoch.
    '''
    nanos = time.time_ns()
    return nanos // 1_000_000
def mononow():
    '''
    Return the monotonic clock reading in milliseconds.

    Notes:
        The value is derived from time.monotonic_ns() and is only
        meaningful relative to other readings of the same clock.

    Returns:
        int: Monotonic clock time in milliseconds.
    '''
    nanos = time.monotonic_ns()
    return nanos // 1_000_000
def guid(valu=None):
    '''
    Produce a 16 byte guid rendered as hex text.

    Args:
        valu: Optional object to derive a stable guid from. It must be
              msgpack serializable. When omitted, a random guid is made.

    Returns:
        str: 32 character, lowercase ascii hex string.
    '''
    if valu is not None:
        # Stable guid: md5 digest of the msgpack encoding of the value.
        packed = s_msgpack.en(valu)
        return hashlib.md5(packed, usedforsecurity=False).hexdigest()

    return os.urandom(16).hex()
def buid(valu=None):
    '''
    Produce a 32 byte binary identifier.

    Args:
        valu (object): Optional. When given, the sha256 digest of the msgpack
                       encoded object is returned, which yields stable buids.

    Notes:
        With no argument, 32 random bytes are returned.

    Returns:
        bytes: A 32 byte value.
    '''
    if valu is not None:
        packed = s_msgpack.en(valu)
        return hashlib.sha256(packed).digest()

    return os.urandom(32)
def flatten(item):
    '''
    Build a normalized copy of a python primitive, suitable for signing.

    Args:
        item: The python primitive object to normalize.

    Notes:
        Accepted inputs are None, bool, int, bytes, str, list, tuple and dict.
        Lists become tuples. Dictionary keys must be sortable; the resulting
        dictionary is populated in sorted key order.

    Returns:
        A new, normalized copy of the object.

    Raises:
        s_exc.BadDataValu: If an unsupported type is encountered.
    '''
    if item is None or isinstance(item, (str, int, bytes)):
        return item

    if isinstance(item, (tuple, list)):
        return tuple(flatten(elem) for elem in item)

    if isinstance(item, dict):
        return {flatten(key): flatten(item[key]) for key in sorted(item)}

    raise s_exc.BadDataValu(mesg=f'Unknown type: {type(item)}')
def ehex(byts):
    '''
    Hex encode bytes into text via binascii.hexlify.

    Args:
        byts (bytes): Bytes to encode.

    Returns:
        str: The hex representation of the bytes.
    '''
    hexed = binascii.hexlify(byts)
    return hexed.decode('utf8')
def uhex(text):
    '''
    Convert hex text back into the bytes it encodes.

    Args:
        text (str): Hex text to decode.

    Returns:
        bytes: The decoded bytes.
    '''
    # a2b_hex is the same routine binascii exposes as unhexlify.
    return binascii.a2b_hex(text)
def isguid(text):
    '''
    Check whether text matches the module guid pattern (guidre).

    Args:
        text (str): The text to test.

    Returns:
        bool: True when the pattern matches, otherwise False.
    '''
    if guidre.match(text) is None:
        return False
    return True
def isbuidhex(text):
    '''
    Check whether text matches the module hex-buid pattern (buidre).

    Args:
        text (str): The text to test.

    Returns:
        bool: True when the pattern matches, otherwise False.
    '''
    if buidre.match(text) is None:
        return False
    return True
def intify(x):
    '''
    Coerce a value to an integer, yielding None when that is not possible.

    Args:
        x (obj): The object to coerce.

    Notes:
        Values which are already ints are returned unchanged. Anything else
        is parsed with ``int(x, 0)``, so prefixed text such as ``0x10`` is
        accepted.

    Returns:
        (int): The integer value, or None.
    '''
    if isinstance(x, int):
        return x

    try:
        retn = int(x, 0)
    except (TypeError, ValueError):
        retn = None

    return retn
# Decimal arithmetic context with 49 digits of precision, shared by the huge* helpers.
hugectx = decimal.Context(prec=49)
def hugenum(valu):
    '''
    Build a decimal.Decimal from a value using the hugenum context.

    Args:
        valu: The value to convert. Floats are converted through their
              string form, and text starting with ``0x`` is parsed as an
              integer first.

    Returns:
        decimal.Decimal: The resulting decimal value.
    '''
    if isinstance(valu, float):
        valu = str(valu)

    ishex = isinstance(valu, str) and valu.startswith('0x')
    if ishex:
        valu = int(valu, 0)

    return decimal.Decimal(valu, context=hugectx)
def hugeadd(x, y):
    '''
    Sum two decimal.Decimal values using the hugenum context.
    '''
    total = hugectx.add(x, y)
    return total
def hugesub(x, y):
    '''
    Compute x minus y for two decimal.Decimal values using the hugenum context.
    '''
    diff = hugectx.subtract(x, y)
    return diff
def hugemul(x, y):
    '''
    Compute the product of two decimal.Decimal values using the hugenum context.
    '''
    prod = hugectx.multiply(x, y)
    return prod
def hugediv(x, y):
    '''
    Compute x divided by y for two decimal.Decimal values using the hugenum context.
    '''
    quot = hugectx.divide(x, y)
    return quot
def hugepow(x, y):
    '''
    Raise x to the power y using the hugenum context.
    '''
    powr = hugectx.power(x, y)
    return powr
def hugescaleb(x, y):
    '''
    Return x with its exponent adjusted by y, using the hugenum context.
    '''
    scaled = hugectx.scaleb(x, y)
    return scaled
# Quantization exponent used by hugeround(): 24 digits after the decimal point.
hugeexp = decimal.Decimal('1E-24')
def hugeround(x):
    '''
    Quantize a decimal.Decimal to the hugenum exponent (hugeexp) using the hugenum context.
    '''
    rounded = hugectx.quantize(x, hugeexp)
    return rounded
def hugemod(x, y):
    '''
    Return the result of the hugenum context divmod() for x and y.

    Notes:
        This is the value of ``decimal.Context.divmod()``, not only the remainder.
    '''
    retn = hugectx.divmod(x, y)
    return retn
def vertup(vstr):
    '''
    Parse a dotted version string into a tuple of integers.

    Example:

        ver = vertup('1.3.30')

    '''
    return tuple(int(part) for part in vstr.split('.'))
def todo(_todoname, *args, **kwargs):
    '''
    Bundle a name with positional and keyword arguments as a todo tuple.

    Note:
        The first parameter is named unusually so that it does not
        collide with keys passed in kwargs.

    Returns:
        tuple: A (name, args, kwargs) tuple.
    '''
    return _todoname, args, kwargs
def tuplify(obj):
    '''
    Round trip a nested set of python primitives through msgpack to obtain
    their tupleized form.
    '''
    packed = s_msgpack.en(obj)
    return s_msgpack.un(packed)
def genpath(*paths):
    '''
    Join path elements and return the resulting absolute path.

    Home directory (``~``) and environment variable expansion are applied
    to the joined path, in that order, before it is made absolute.

    Args:
        *paths ([str,...]): A list of path elements

    Note:
        All paths used by Synapse operations (i.e. everything but the data)
        shall use this function or one of its callers before storing as
        object properties.
    '''
    joined = os.path.join(*paths)
    expanded = os.path.expandvars(os.path.expanduser(joined))
    return os.path.abspath(expanded)
def switchext(*paths, ext):
    '''
    Join path elements into an absolute path and swap its file extension.

    If the path has no extension, the new extension is simply appended.

    Args:
        *paths ([str,...]): A list of path elements
        ext (str): A file extension (e.g. '.txt'). It should begin with a period.
    '''
    root, _ = os.path.splitext(genpath(*paths))
    return root + ext
def reqpath(*paths):
    '''
    Join path elements into an absolute path and require that a file exists there.

    Args:
        *paths ([str,...]): A list of path elements

    Returns:
        str: The absolute path.

    Raises:
        s_exc.NoSuchFile: If no file exists at the resulting path.
    '''
    path = genpath(*paths)
    if os.path.isfile(path):
        return path
    raise s_exc.NoSuchFile(mesg=f'No such path {path}', path=path)
def reqfile(*paths, **opts):
    '''
    Open the file at the joined path, requiring that it already exists.

    Args:
        *paths ([str,...]): A list of path elements
        **opts: arguments as kwargs to io.open. The mode defaults to ``rb``.

    Returns:
        A file object opened with io.open.

    Raises:
        s_exc.NoSuchFile: If no file exists at the resulting path.
    '''
    path = genpath(*paths)
    if not os.path.isfile(path):
        raise s_exc.NoSuchFile(mesg=f'No such file {path}', path=path)

    if 'mode' not in opts:
        opts['mode'] = 'rb'

    return io.open(path, **opts)
def getfile(*paths, **opts):
    '''
    Open the file at the joined path, or return None when it does not exist.

    Args:
        *paths ([str,...]): A list of path elements
        **opts: arguments as kwargs to io.open. The mode defaults to ``rb``.

    Returns:
        A file object opened with io.open, or None.
    '''
    path = genpath(*paths)
    if not os.path.isfile(path):
        return None

    if 'mode' not in opts:
        opts['mode'] = 'rb'

    return io.open(path, **opts)
def getbytes(*paths, **opts):
    '''
    Read the full contents of the file at the joined path.

    Args:
        *paths ([str,...]): A list of path elements
        **opts: arguments passed through to getfile().

    Returns:
        The contents of the file, or None if getfile() returns None.
    '''
    fd = getfile(*paths, **opts)
    if fd is not None:
        with fd:
            return fd.read()
    return None
def reqbytes(*paths):
    '''
    Read the full contents of the file opened by reqfile() for the joined path.

    Args:
        *paths ([str,...]): A list of path elements

    Returns:
        The contents of the file.
    '''
    fd = reqfile(*paths)
    with fd:
        return fd.read()
def genfile(*paths) -> typing.BinaryIO:
    '''
    Open a file for binary read/write at the joined path, creating it if needed.

    Args:
        *paths: A list of paths to join together to make the file.

    Notes:
        The parent directory is created through gendir(). An existing file
        is opened in ``r+b`` mode; otherwise the file is opened in ``w+b``
        mode. The file position starts at the beginning of the file, so the
        caller must truncate (``fd.truncate()``) to discard existing content
        or seek to the end (``fd.seek(0, 2)``) to append.

    Returns:
        A file-object which can be read from and written to.
    '''
    path = genpath(*paths)
    gendir(os.path.dirname(path))
    mode = 'r+b' if os.path.isfile(path) else 'w+b'
    return io.open(path, mode)
@contextlib.contextmanager
def getTempDir(dirn=None):
    '''
    Context manager which yields a new temporary directory and removes it on exit.

    Args:
        dirn (str): Optional parent directory, passed to tempfile.mkdtemp().

    Yields:
        str: The path of the temporary directory.
    '''
    tempdir = tempfile.mkdtemp(dir=dirn)
    with contextlib.ExitStack() as stack:
        # Removal runs on both normal and exceptional exit; errors are ignored.
        stack.callback(shutil.rmtree, tempdir, ignore_errors=True)
        yield tempdir
def listdir(*paths, glob=None):
    '''
    Return the full paths of the entries in a directory, optionally glob filtered.

    Args:
        *paths ([str,...]): A list of path elements
        glob (str): An optional fnmatch glob str applied to the entry names.

    Returns:
        list: The full paths of the matching entries.
    '''
    dirn = genpath(*paths)

    names = os.listdir(dirn)
    if glob is not None:
        names = fnmatch.filter(names, glob)

    return [os.path.join(dirn, name) for name in names]
def gendir(*paths, **opts):
    '''
    Join path elements into an absolute path and make sure a directory exists there.

    Home directory (~) and environment variable expansion are performed.
    If the resulting path is a symlink, the link target is used instead.

    Args:
        *paths ([str,...]): A list of path elements
        **opts: Only ``mode`` is consulted (default 0o700); it is passed to os.makedirs.

    Returns:
        str: The directory path.
    '''
    path = genpath(*paths)

    if os.path.islink(path):
        path = os.readlink(path)

    if not os.path.isdir(path):
        os.makedirs(path, mode=opts.get('mode', 0o700), exist_ok=True)

    return path
def reqdir(*paths):
    '''
    Join path elements into an absolute path and require that a directory exists there.

    Home directory (~) and environment variable expansion are performed.

    Args:
        *paths ([str,...]): A list of path elements

    Returns:
        str: The absolute directory path.

    Raises:
        s_exc.NoSuchDir: If no directory exists at the resulting path.
    '''
    path = genpath(*paths)
    if os.path.isdir(path):
        return path
    raise s_exc.NoSuchDir(path=path)
def getDirSize(*paths):
    '''
    Compute the real and apparent size of a directory tree.

    Args:
        *paths (str): A list of path elements.

    Notes:
        This is equivalent to ``du -B 1 -s`` and ``du -bs``.

    Returns:
        tuple: Tuple of total real and total apparent size of all normal
        files and directories underneath ``*paths`` plus ``*paths`` itself.
    '''
    def _sizes(path):
        # (real, apparent) sizes; anything that is not a regular file or
        # a directory counts as zero, as does a path which cannot be stat'd.
        try:
            info = os.lstat(path)
        except OSError:  # pragma: no cover
            return 0, 0

        if stat.S_ISREG(info.st_mode) or stat.S_ISDIR(info.st_mode):
            return info.st_blocks * 512, info.st_size

        return 0, 0

    realsum, apprsum = _sizes(genpath(*paths))

    for dirpath, dirnames, filenames in os.walk(reqdir(*paths)):
        for name in itertools.chain(filenames, dirnames):
            real, appr = _sizes(genpath(dirpath, name))
            realsum += real
            apprsum += appr

    return realsum, apprsum
def jsload(*paths):
    '''
    Load a JSON document from the file at the joined path.

    Notes:
        The file is opened through genfile(), so it is created if missing.

    Returns:
        The decoded JSON value, or None if the file is empty.
    '''
    with genfile(*paths) as fd:
        byts = fd.read()

    if byts:
        return json.loads(byts.decode('utf8'))

    return None
def jslines(*paths):
    '''
    Yield one decoded JSON value per line of the file at the joined path.

    Notes:
        The file is opened through genfile().

    Yields:
        The decoded JSON value for each line.
    '''
    with genfile(*paths) as fd:
        yield from map(json.loads, fd)
def jssave(js, *paths):
    '''
    Write an object as indented, key sorted JSON to the file at the joined path.

    Args:
        js: The object to serialize.
        *paths ([str,...]): A list of path elements

    Returns:
        None
    '''
    path = genpath(*paths)
    with io.open(path, 'wb') as fd:
        text = json.dumps(js, sort_keys=True, indent=2)
        fd.write(text.encode('utf8'))
def yamlloads(data):
    '''
    Parse yaml from the given data using the module Loader.

    Args:
        data: The yaml content (or stream) to hand to yaml.load.

    Returns:
        The object produced by yaml.load.
    '''
    return yaml.load(data, Loader=Loader)
def yamlload(*paths):
    '''
    Load yaml from the file at the joined path.

    Args:
        *paths ([str,...]): A list of path elements

    Returns:
        The loaded object, or None if no file exists at the path.
    '''
    path = genpath(*paths)
    if os.path.isfile(path):
        with io.open(path, 'rb') as fd:
            return yamlloads(fd)
    return None
def yamldump(obj, stream: typing.Optional[typing.BinaryIO] = None) -> bytes:
    '''
    Serialize an object to yaml.

    Args:
        obj: The object to serialize.
        stream: Optional stream to write the yaml to.

    Returns:
        The raw yaml bytes if stream is not provided.
    '''
    dumpopts = {
        'allow_unicode': True,
        'default_flow_style': False,
        'default_style': '',
        'explicit_start': True,
        'explicit_end': True,
        'encoding': 'utf8',
        'stream': stream,
        'Dumper': Dumper,
    }
    return yaml.dump(obj, **dumpopts)
def yamlsave(obj, *paths):
    '''
    Serialize an object as yaml into the file at the joined path, replacing its contents.

    Args:
        obj: The object to serialize.
        *paths ([str,...]): A list of path elements

    Returns:
        None
    '''
    fd = genfile(genpath(*paths))
    with fd:
        # genfile() does not truncate existing files, so clear it first.
        fd.truncate(0)
        yamldump(obj, stream=fd)
def yamlmod(obj, *paths):
    '''
    Merge obj into a yaml file, creating the file when there is nothing to merge with.

    Keys from obj take precedence over keys already in the file. Both obj
    and the file contents must be maps/dicts or empty. Nothing is written
    when obj is None.
    '''
    oldobj = yamlload(*paths)
    if obj is None:
        return

    newobj = {**oldobj, **obj} if oldobj else obj
    yamlsave(newobj, *paths)
def yamlpop(key, *paths):
    '''
    Remove a key from a yaml file and save the result.

    Args:
        key (str): Name of the key to remove.
        *paths: Path to a yaml file. The file must be a map / dictionary.

    Notes:
        Nothing is written when yamlload() returns None for the path.

    Returns:
        None
    '''
    obj = yamlload(*paths)
    if obj is None:
        return

    obj.pop(key, None)
    yamlsave(obj, *paths)
def verstr(vtup):
    '''
    Render a version tuple as a dotted string.
    '''
    return '.'.join(map(str, vtup))
def excinfo(e):
    '''
    Build a dictionary describing an exception.

    Args:
        e (Exception): An exception with a populated traceback.

    Returns:
        dict: ``err`` (class name), ``errmsg`` (str of the exception),
        ``errfile`` and ``errline`` (taken from the last traceback entry),
        plus ``errinfo`` when the exception is a s_exc.SynErr.
    '''
    path, line, _name, _sorc = traceback.extract_tb(e.__traceback__)[-1]

    ret = {
        'err': e.__class__.__name__,
        'errmsg': str(e),
        'errfile': path,
        'errline': line,
    }

    if isinstance(e, s_exc.SynErr):
        ret['errinfo'] = e.errinfo

    return ret
def errinfo(name, mesg):
    '''
    Build an error dictionary with ``err`` and ``errmsg`` keys.

    Args:
        name: Value stored under ``err``.
        mesg: Value stored under ``errmsg``.

    Returns:
        dict: The error dictionary.
    '''
    return dict(err=name, errmsg=mesg)
def chunks(item, size):
    '''
    Break an item up into pieces of at most ``size`` elements.

    Args:
        item: Item to slice
        size (int): Maximum chunk size.

    Notes:
        Generator objects are consumed with islice and produce tuples. Any
        other object must support ``__getitem__()`` with a slice object.

    Yields:
        Slices of the item containing up to "size" number of items.
    '''
    if isinstance(item, types.GeneratorType):
        # Pull tuples off of the generator until an empty one comes back.
        yield from iter(lambda: tuple(itertools.islice(item, size)), ())
        return

    # An empty sequence yields a single empty slice of itself. Mappings and
    # sets fail here since they do not support slicing.
    if not item:
        yield item[0:0]
        return

    start = 0
    while (piece := item[start:start + size]):
        yield piece
        start += size
def iterfd(fd, size=10000000):
    '''
    Generator which reads a file object from its start and yields its content in pieces.

    Args:
        fd (file): A file-like object to read bytes from.
        size (int): Maximum number of bytes requested per read call.

    Notes:
        If the very first read returns an empty bytestring, that empty
        bytestring is yielded once and the generator is then exhausted.
        This allows the contents of a zero byte file to be yielded.

    Yields:
        bytes: Bytes from the file descriptor.
    '''
    fd.seek(0)

    first = fd.read(size)
    yield first

    if len(first) == 0:
        return

    while (more := fd.read(size)):
        yield more
def spin(genr):
    '''
    Exhaust a generator or iterable, throwing away everything it yields.

    Args:
        genr: Any generator or iterable valu.

    Notes:
        The iterable is consumed by the ``collections.deque()`` constructor
        with ``maxlen=0``, which drains an iterator while staying in C code
        as much as possible.

    Returns:
        None
    '''
    collections.deque(genr, maxlen=0)
async def aspin(genr):
    '''
    Async version of spin: exhaust an async iterable and discard what it yields.
    '''
    iterator = genr.__aiter__()
    while True:
        try:
            await iterator.__anext__()
        except StopAsyncIteration:
            break
async def agen(*items):
    '''
    Async generator which yields each of the given positional arguments in order.
    '''
    for valu in items:
        yield valu
def firethread(f):
    '''
    Decorator which makes each call of the function run in a new thread.

    The decorated function returns the value of worker(), which is the
    started thread.
    '''
    @functools.wraps(f)
    def callmeth(*args, **kwargs):
        return worker(f, *args, **kwargs)

    return callmeth
def worker(meth, *args, **kwargs):
    '''
    Start a daemon thread which calls meth(*args, **kwargs).

    Returns:
        threading.Thread: The thread, already started.
    '''
    thread = threading.Thread(target=meth, args=args, kwargs=kwargs, daemon=True)
    thread.start()
    return thread
# Exception classes grouped under the name sockerrs.
sockerrs = (builtins.ConnectionError, builtins.FileNotFoundError)

# Precompiled struct: unsigned 64-bit integer, big-endian.
_Int64be = struct.Struct('>Q')
def int64en(i):
    '''
    Pack an unsigned 64-bit integer into 8 big-endian bytes.
    '''
    packed = _Int64be.pack(i)
    return packed
def int64un(b):
    '''
    Unpack an unsigned 64-bit integer from 8 big-endian bytes.
    '''
    (valu,) = _Int64be.unpack(b)
    return valu
# Precompiled struct: signed 64-bit integer, big-endian.
_SignedInt64be = struct.Struct('>q')
def signedint64en(i):
    '''
    Pack a signed 64-bit integer into 8 big-endian bytes.
    '''
    packed = _SignedInt64be.pack(i)
    return packed
def signedint64un(b):
    '''
    Unpack a signed 64-bit integer from 8 big-endian bytes.
    '''
    (valu,) = _SignedInt64be.unpack(b)
    return valu
def enbase64(b):
    '''
    Base64 encode bytes and return the result as text.

    Args:
        b (bytes): The bytes to encode.

    Returns:
        str: The base64 text.
    '''
    encoded = base64.b64encode(b)
    return encoded.decode('utf8')
def debase64(b):
    '''
    Decode base64 text into bytes.

    Args:
        b (str): The base64 text; it is utf8 encoded before decoding.

    Returns:
        bytes: The decoded bytes.
    '''
    raw = b.encode('utf8')
    return base64.b64decode(raw)
def makedirs(path, mode=0o777):
    '''
    Create a directory (and any missing parents) without failing if it already exists.

    Args:
        path (str): The directory path to create.
        mode (int): Mode passed to os.makedirs.

    Returns:
        None
    '''
    os.makedirs(path, mode=mode, exist_ok=True)
    return None
def iterzip(*args, fillvalue=None):
    '''
    Zip iterables together, padding the shorter ones with fillvalue.

    This is a thin wrapper around itertools.zip_longest.
    '''
    zipped = itertools.zip_longest(*args, fillvalue=fillvalue)
    return zipped
def _getLogConfFromEnv(defval=None, structlog=None, datefmt=None):
    '''
    Resolve logging settings, letting environment variables override the arguments.

    Args:
        defval: Fallback log level used when SYN_LOG_LEVEL is not set.
        structlog: Fallback (by truthiness) used when SYN_LOG_STRUCT is not set.
        datefmt: Fallback date format used when SYN_LOG_DATEFORMAT is not set.

    Returns:
        dict: A dictionary with ``defval``, ``structlog`` and ``datefmt`` keys.
    '''
    # envbool() expects a string default, so map truthiness to text.
    structdef = 'true' if structlog else 'false'

    return {
        'defval': os.getenv('SYN_LOG_LEVEL', defval),
        'structlog': envbool('SYN_LOG_STRUCT', structdef),
        'datefmt': os.getenv('SYN_LOG_DATEFORMAT', datefmt),
    }
def normLogLevel(valu):
    '''
    Normalize a log level value to an integer.

    Args:
        valu: The value to norm ( a string or integer ).

    Notes:
        Integers must be present in s_const.LOG_LEVEL_INVERSE_CHOICES.
        Strings are stripped; numeric text is normalized as an integer, and
        any other text is upper cased and looked up in
        s_const.LOG_LEVEL_CHOICES.

    Returns:
        int: A valid Logging log level.

    Raises:
        s_exc.BadArg: If the value is not a valid level or is an unsupported type.
    '''
    if isinstance(valu, int):
        if valu in s_const.LOG_LEVEL_INVERSE_CHOICES:
            return valu
        raise s_exc.BadArg(mesg=f'Invalid log level provided: {valu}', valu=valu)

    if isinstance(valu, str):
        text = valu.strip()
        try:
            numb = int(text)
        except ValueError:
            name = text.upper()
            level = s_const.LOG_LEVEL_CHOICES.get(name)
            if level is None:
                raise s_exc.BadArg(mesg=f'Invalid log level provided: {name}', valu=name) from None
            return level
        return normLogLevel(numb)

    raise s_exc.BadArg(mesg=f'Unknown log level type: {type(valu)} {valu}', valu=valu)
def setlogging(mlogger, defval=None, structlog=None, log_setup=True, datefmt=None):
    '''
    Configure synapse logging.

    Args:
        mlogger (logging.Logger): Reference to a logging.Logger()
        defval (str): Default log level. May be an integer.
        structlog (bool): Enabled structured (jsonl) logging output.
        log_setup (bool): When true, log the level which was configured.
        datefmt (str): Optional strftime format string.

    Notes:
        This calls logging.basicConfig and should only be called once per process.
        Nothing is configured when no log level is resolved.

    Returns:
        dict: The resolved logging configuration from _getLogConfFromEnv().
    '''
    ret = _getLogConfFromEnv(defval, structlog, datefmt)

    fmt = ret.get('datefmt')
    level = ret.get('defval')

    if level:  # pragma: no cover

        level = normLogLevel(level)

        if ret.get('structlog'):
            handler = logging.StreamHandler()
            handler.setFormatter(s_structlog.JsonFormatter(datefmt=fmt))
            bconf = {'handlers': (handler,)}
        else:
            bconf = {'format': s_const.LOG_FORMAT, 'datefmt': fmt}

        logging.basicConfig(level=level, **bconf)

        if log_setup:
            mlogger.info('log level set to %s', s_const.LOG_LEVEL_INVERSE_CHOICES.get(level))

    return ret
# The synapse directory comes from the SYN_DIR environment variable,
# falling back to syndir_default when the variable is not set.
syndir_default = '~/.syn'
syndir = os.getenv('SYN_DIR')
if syndir is None:
    syndir = syndir_default
def envbool(name, defval='false'):
    '''
    Interpret an environment variable as a boolean.

    Args:
        name (str): Environment variable to resolve.
        defval (str): String value used when the variable is not set.

    Notes:
        After lower casing, the strings "0" and "false" are considered
        False. Any other value is considered True.

    Returns:
        boolean: False if the resolved value is a false value, otherwise True.
    '''
    valu = os.getenv(name, defval)
    return valu.lower() not in ('0', 'false')
def getSynPath(*paths):
    '''
    Return the absolute path of the given elements joined beneath the synapse directory (syndir).
    '''
    path = genpath(syndir, *paths)
    return path
def getSynDir(*paths):
    '''
    Return a directory path beneath the synapse directory (syndir), creating it via gendir() if needed.
    '''
    path = gendir(syndir, *paths)
    return path
def result(retn):
    '''
    Unpack a retn tuple, returning its value or raising the error it describes.

    Args:
        retn (tuple): An (ok, valu) tuple. When ok is false, valu is a
                      (name, info) tuple describing the error.

    Returns:
        The value from a successful retn tuple.

    Raises:
        The exception class from synapse.exc matching the error name, or
        s_exc.SynErr (with ``errx`` set to the name) if there is no match.
    '''
    ok, valu = retn
    if ok:
        return valu

    name, info = valu

    ctor = getattr(s_exc, name, None)
    if ctor is None:
        info['errx'] = name
        ctor = s_exc.SynErr

    raise ctor(**info)
def err(e, fulltb=False):
    '''
    Build a (name, info) tuple describing an exception.

    Args:
        e (Exception): The exception to describe.
        fulltb (bool): When true, include the formatted traceback as ``etb``.

    Notes:
        Location details come from sys.exc_info(), i.e. the exception
        currently being handled, and are only present when a traceback is
        available.

    Returns:
        tuple: The exception class name and an info dictionary.
    '''
    info = {}

    frames = traceback.extract_tb(sys.exc_info()[2])
    if frames:
        path, line, tbname, src = frames[-1]
        info = {
            'efile': os.path.basename(path),
            'eline': line,
            'esrc': src,
            'ename': tbname,
        }

    if isinstance(e, s_exc.SynErr):
        info.update(e.items())
    else:
        info['mesg'] = str(e)

    if fulltb:
        text = traceback.format_exc()
        # Drop a single trailing newline, if present.
        if text.endswith('\n'):
            text = text[:-1]
        info['etb'] = text

    return (e.__class__.__name__, info)
def retnexc(e):
    '''
    Build a failed retn tuple, (False, err(e)), for the given exception.
    '''
    errtup = err(e)
    return (False, errtup)
def config(conf, confdefs):
    '''
    Return a copy of a config dict with missing entries filled from confdef tuples.

    Args:
        conf (dict): The provided configuration; it is not modified.
        confdefs: An iterable of (name, info) tuples; ``info.get('defval')``
                  supplies the default for each name.

    Returns:
        dict: The populated copy of the configuration.
    '''
    retn = conf.copy()

    # for now just populate defval
    for name, info in confdefs:
        defval = info.get('defval')
        if name not in retn:
            retn[name] = defval

    return retn
def deprecated(name, curv='2.x', eolv='3.0.0'):
    '''
    Emit a DeprecationWarning for a named feature.

    Args:
        name (str): The name of the deprecated feature.
        curv (str): The version in which it is deprecated.
        eolv (str): The version in which it will be removed.

    Returns:
        str: The warning message.
    '''
    mesg = f'"{name}" is deprecated in {curv} and will be removed in {eolv}'
    warnings.warn(mesg, category=DeprecationWarning)
    return mesg
def deprdate(name, date):  # pragma: no cover
    '''
    Emit a DeprecationWarning for a named feature which will be removed on a date.

    Args:
        name (str): The name of the deprecated feature.
        date (str): The removal date.

    Returns:
        None
    '''
    warnings.warn(f'{name} is deprecated and will be removed on {date}.', category=DeprecationWarning)
def reqjsonsafe(item):
    '''
    Require that an item can be serialized by the built-in json.dumps.

    Default json.dumps type coercion is allowed.

    Returns:
        None

    Raises:
        s_exc.MustBeJsonSafe: If json.dumps raises a TypeError for the item.
    '''
    try:
        json.dumps(item)
    except TypeError as exc:
        raise s_exc.MustBeJsonSafe(mesg=str(exc)) from None
    return None
def jsonsafe_nodeedits(nodeedits):
    '''
    Return copies of the node edits with the first element (the buid) of each hex encoded.

    Returns:
        list: A list of tuples, one per node edit.
    '''
    return [(ehex(edit[0]), *edit[1:]) for edit in nodeedits]
def unjsonsafe_nodeedits(nodeedits):
    '''
    Reverse jsonsafe_nodeedits: hex decode the first element of each node edit.

    Notes:
        Node edits whose first element is not a string are passed through
        unchanged.

    Returns:
        list: The list of node edits.
    '''
    return [
        (uhex(edit[0]), *edit[1:]) if isinstance(edit[0], str) else edit
        for edit in nodeedits
    ]
def reprauthrule(rule):
    '''
    Render a rule as text.

    Args:
        rule: A sequence whose first element is a flag and whose second
              element is a sequence of strings.

    Returns:
        str: The strings joined with ``.``, prefixed with ``!`` when the flag is false.
    '''
    text = '.'.join(rule[1])
    return text if rule[0] else '!' + text
def reqJsonSafeStrict(item):
    '''
    Require that an item can be serialized to JSON without any type coercion.

    Args:
        item: The python primitive to check.

    Notes:
        Only None, str, int, list, tuple and dict (with string keys) are
        accepted; containers are checked recursively.

    Returns:
        None

    Raise:
        s_exc.BadArg: If the item contains invalid data.
    '''
    if item is None or isinstance(item, (str, int)):
        return

    if isinstance(item, (list, tuple)):
        for elem in item:
            reqJsonSafeStrict(elem)
        return

    if not isinstance(item, dict):
        raise s_exc.BadArg(mesg=f'Invalid item type encountered: {item.__class__.__name__}')

    for key, valu in item.items():
        if not isinstance(key, str):
            raise s_exc.BadArg(mesg='Non-string keys are not valid json', key=key)
        reqJsonSafeStrict(valu)
async def merggenr(genrs, cmprkey):
    '''
    Iterate multiple sorted async generators and yield their results in order.

    Args:
        genrs (Sequence[AsyncGenerator[T]]): a sequence of async generator that each yield sorted items
        cmprkey(Callable[T, T, bool]): a comparison function over the items yielded

    Note:
        If the genrs yield increasing items, cmprkey should return True if the first parameter is less
        than the second parameter, e.g lambda x, y: x < y.
    '''

    size = len(genrs)
    genrs = list(genrs)
    # Indexes of the generators which have not yet been exhausted.
    indxs = list(range(size))

    async def genrnext(g):
        # Return the next item from g, or the novalu sentinel once g is exhausted.
        try:
            ret = await g.__anext__()
            return ret
        except StopAsyncIteration:
            return novalu

    # The current (not yet yielded) item for each generator.
    curvs = [await genrnext(g) for g in genrs]

    while True:

        nextindx = None
        nextvalu = novalu
        toremove = []

        # Linear scan of the live generators for the item which sorts first.
        for i in indxs:

            curv = curvs[i]
            if curv is novalu:
                toremove.append(i)
                continue

            # in the case where we're the first, initialize...
            if nextvalu is novalu:
                nextindx = i
                nextvalu = curv
                continue

            # On ties the earlier generator wins, since cmprkey must
            # return True for a later candidate to replace it.
            if cmprkey(curv, nextvalu):
                nextindx = i
                nextvalu = curv

        # check if we're done
        if nextvalu is novalu:
            return

        # Remove spent genrs
        for i in toremove:
            indxs.remove(i)

        yield nextvalu

        # Refill the slot of the generator whose item was just yielded.
        curvs[nextindx] = await genrnext(genrs[nextindx])
async def merggenr2(genrs, cmprkey=None, reverse=False):
    '''
    Optimized version of merggenr based on heapq.merge

    Args:
        genrs: An indexable sequence of async generators to merge.
        cmprkey: Optional one-argument key function; it is applied to each
                 item and the result is used for ordering. (Unlike
                 merggenr(), this is not a two-argument comparison.)
        reverse (bool): When true, max-heap operations are used instead of
                        min-heap operations.

    Yields:
        The items from the generators, in heap order.
    '''
    h = []
    h_append = h.append

    # NOTE(review): the *_max functions are underscore-prefixed heapq
    # helpers; confirm they exist on every interpreter version targeted.
    if reverse:
        _heapify = heapq._heapify_max
        _heappop = heapq._heappop_max
        _heapreplace = heapq._heapreplace_max
        direction = -1
    else:
        _heapify = heapq.heapify
        _heappop = heapq.heappop
        _heapreplace = heapq.heapreplace
        direction = 1

    if cmprkey is None:

        # Heap entries are [valu, order * direction, nxt]. The second field
        # is unique per generator, so the list comparison used by the heap
        # never reaches the nxt callable.
        for order, genr in enumerate(genrs):
            try:
                nxt = genr.__anext__
                h_append([await nxt(), order * direction, nxt])
            except StopAsyncIteration:
                pass

        _heapify(h)

        while len(h) > 1:
            try:
                while True:
                    valu, _, nxt = s = h[0]
                    yield valu
                    # Reuse the entry in place with the generator's next item.
                    s[0] = await nxt()
                    _heapreplace(h, s)
            except StopAsyncIteration:
                # The generator at the top of the heap is exhausted.
                _heappop(h)

        if h:
            # Only one generator remains; drain it directly without the heap.
            valu, order, _ = h[0]
            yield valu
            # abs() undoes the negation applied to order in reverse mode.
            async for valu in genrs[abs(order)]:
                yield valu

        return

    # Same algorithm as above, with entries of
    # [cmprkey(valu), order * direction, valu, nxt].
    for order, genr in enumerate(genrs):
        try:
            nxt = genr.__anext__
            valu = await nxt()
            h_append([cmprkey(valu), order * direction, valu, nxt])
        except StopAsyncIteration:
            pass

    _heapify(h)

    while len(h) > 1:
        try:
            while True:
                _, _, valu, nxt = s = h[0]
                yield valu
                valu = await nxt()
                s[0] = cmprkey(valu)
                s[2] = valu
                _heapreplace(h, s)
        except StopAsyncIteration:
            _heappop(h)

    if h:
        _, order, valu, _ = h[0]
        yield valu
        async for valu in genrs[abs(order)]:
            yield valu
def getSslCtx(cadir, purpose=ssl.Purpose.SERVER_AUTH):
    '''
    Create a default SSL context and load verify locations from the files in a directory.

    Args:
        cadir (str): Path to load certificates from.
        purpose: SSLContext purposes flags.

    Notes:
        Only regular files directly inside cadir are considered. A file
        which fails to load is logged and skipped.

    Returns:
        ssl.SSLContext: A SSL Context object.
    '''
    sslctx = ssl.create_default_context(purpose=purpose)

    for name in os.listdir(cadir):
        certpath = os.path.join(cadir, name)
        if os.path.isfile(certpath):
            try:
                sslctx.load_verify_locations(cafile=certpath)
            except Exception:  # pragma: no cover
                logger.exception(f'Error loading {certpath}')

    return sslctx
def httpcodereason(code):
    '''
    Look up the reason phrase for an HTTP status code.

    Args:
        code (int): The code.

    Note:
        For a status code which http.HTTPStatus does not know, a string
        stating that the code is unknown is returned.

    Returns:
        str: A string describing the status code.
    '''
    try:
        status = http.HTTPStatus(code)
    except ValueError:
        return f'Unknown HTTP status code {code}'
    else:
        return status.phrase
def trimText(text: str, n: int = 256, placeholder: str = '...') -> str:
    '''
    Shorten text longer than n characters, ending the result with a placeholder.

    Args:
        text: String to trim.
        n: Number of characters to allow.
        placeholder: Placeholder text.

    Notes:
        When trimming is needed, the placeholder must be non-empty and
        shorter than n; this is enforced with assertions.

    Returns:
        The original string or the trimmed string.
    '''
    if len(text) <= n:
        return text

    plen = len(placeholder)
    assert plen > 0
    assert n > plen

    return text[:n - plen] + placeholder
def _patch_http_cookies():
    '''
    Patch stdlib http.cookies._unquote from the 3.11.10 implementation if
    the interpreter we are using is not patched for CVE-2024-7592.
    '''
    # NOTE(review): the presence of http.cookies._QuotePatt is used as the
    # marker for an unpatched interpreter - confirm against the CPython fix.
    if not hasattr(http.cookies, '_QuotePatt'):
        return
    # Replace the stdlib function with the vendored implementation.
    http.cookies._unquote = v_cookies._unquote

# Applied once, as a side effect of importing this module.
_patch_http_cookies()

# TODO: Switch back to using asyncio.wait_for when we are using py 3.12+
# This is a workaround for a race where asyncio.wait_for can end up
# ignoring cancellation https://github.com/python/cpython/issues/86296
async def wait_for(fut, timeout):
    '''
    Wait for an awaitable to complete, with a timeout.

    Args:
        fut: The awaitable to wait for.
        timeout: Timeout in seconds, or None to wait without a timeout.

    Returns:
        The result of the awaitable.

    Raises:
        TimeoutError: If the timeout expires before the awaitable completes.
    '''

    if timeout is not None and timeout <= 0:
        # Zero or negative timeout: only a result which is already
        # available can be returned, otherwise cancel immediately.
        fut = asyncio.ensure_future(fut)

        if fut.done():
            return fut.result()

        await _cancel_and_wait(fut)
        try:
            # The future may have completed despite the cancel request.
            return fut.result()
        except asyncio.CancelledError as exc:
            raise TimeoutError from exc

    # With timeout=None the context manager never schedules a timeout.
    async with _timeout(timeout):
        return await fut
def _release_waiter(waiter, *args): if not waiter.done(): waiter.set_result(None) async def _cancel_and_wait(fut): """Cancel the *fut* future or task and wait until it completes.""" loop = asyncio.get_running_loop() waiter = loop.create_future() cb = functools.partial(_release_waiter, waiter) fut.add_done_callback(cb) try: fut.cancel() # We cannot wait on *fut* directly to make # sure _cancel_and_wait itself is reliably cancellable. await waiter finally: fut.remove_done_callback(cb) class _State(enum.Enum): CREATED = "created" ENTERED = "active" EXPIRING = "expiring" EXPIRED = "expired" EXITED = "finished" class _Timeout: """Asynchronous context manager for cancelling overdue coroutines. Use `timeout()` or `timeout_at()` rather than instantiating this class directly. """ def __init__(self, when): """Schedule a timeout that will trigger at a given loop time. - If `when` is `None`, the timeout will never trigger. - If `when < loop.time()`, the timeout will trigger on the next iteration of the event loop. 
""" self._state = _State.CREATED self._timeout_handler = None self._task = None self._when = when def when(self): # pragma: no cover """Return the current deadline.""" return self._when def reschedule(self, when): """Reschedule the timeout.""" assert self._state is not _State.CREATED if self._state is not _State.ENTERED: # pragma: no cover raise RuntimeError( f"Cannot change state of {self._state.value} Timeout", ) self._when = when if self._timeout_handler is not None: # pragma: no cover self._timeout_handler.cancel() if when is None: self._timeout_handler = None else: loop = asyncio.get_running_loop() if when <= loop.time(): # pragma: no cover self._timeout_handler = loop.call_soon(self._on_timeout) else: self._timeout_handler = loop.call_at(when, self._on_timeout) def expired(self): # pragma: no cover """Is timeout expired during execution?""" return self._state in (_State.EXPIRING, _State.EXPIRED) def __repr__(self): # pragma: no cover info = [''] if self._state is _State.ENTERED: when = round(self._when, 3) if self._when is not None else None info.append(f"when={when}") info_str = ' '.join(info) return f"<Timeout [{self._state.value}]{info_str}>" async def __aenter__(self): self._state = _State.ENTERED self._task = asyncio.current_task() self._cancelling = self._task.cancelling() if self._task is None: # pragma: no cover raise RuntimeError("Timeout should be used inside a task") self.reschedule(self._when) return self async def __aexit__(self, exc_type, exc_val, exc_tb): assert self._state in (_State.ENTERED, _State.EXPIRING) if self._timeout_handler is not None: self._timeout_handler.cancel() self._timeout_handler = None if self._state is _State.EXPIRING: self._state = _State.EXPIRED if self._task.uncancel() <= self._cancelling and exc_type is asyncio.CancelledError: # Since there are no new cancel requests, we're # handling this. 
raise TimeoutError from exc_val elif self._state is _State.ENTERED: self._state = _State.EXITED return None def _on_timeout(self): assert self._state is _State.ENTERED self._task.cancel() self._state = _State.EXPIRING # drop the reference early self._timeout_handler = None def _timeout(delay): """Timeout async context manager. Useful in cases when you want to apply timeout logic around block of code or in cases when asyncio.wait_for is not suitable. For example: >>> async with asyncio.timeout(10): # 10 seconds timeout ... await long_running_task() delay - value in seconds or None to disable timeout logic long_running_task() is interrupted by raising asyncio.CancelledError, the top-most affected timeout() context manager converts CancelledError into TimeoutError. """ loop = asyncio.get_running_loop() return _Timeout(loop.time() + delay if delay is not None else None)
def format(text, **kwargs):
    '''
    Substitute ``{name}`` tokens in text, treating each token as an opaque string.

    Similar to python str.format(), except that tokens are replaced
    literally, one keyword argument at a time, in the order given.

    Args:
        text (str): The text containing tokens.
        **kwargs: Token names mapped to their replacement strings.

    Returns:
        str: The text with the tokens replaced.
    '''
    for name, valu in kwargs.items():
        text = text.replace(f'{{{name}}}', valu)
    return text