Source code for markovchain.util

from itertools import islice, repeat
from copy import deepcopy
from custom_inherit import DocInheritMeta


# Metaclass that makes classes inherit numpy-style docstrings from their bases.
DOC_INHERIT = DocInheritMeta(style='numpy', abstract_base_class=False)

# Same docstring inheritance, built with ``abstract_base_class=True``
# (per custom_inherit, the resulting metaclass derives from ABCMeta).
DOC_INHERIT_ABSTRACT = DocInheritMeta(style='numpy', abstract_base_class=True)


class SaveLoad(metaclass=DOC_INHERIT_ABSTRACT):
    """Base class for converting to/from JSON.

    Attributes
    ----------
    classes : `dict`
        Class group: maps class names to classes that `load` can create.

    Examples
    --------
    >>> class SaveLoadGroup(SaveLoad):
    ...     classes = {}
    ...
    >>> class SaveLoadObject(SaveLoadGroup):
    ...     def __init__(self, attr=None):
    ...         self.attr = attr
    ...     def save(self):
    ...         data = super().save()
    ...         data['attr'] = self.attr
    ...         return data
    ...
    >>> SaveLoadGroup.add_class(SaveLoadObject)
    >>> SaveLoadGroup.classes
    {'SaveLoadObject': <class '__main__.SaveLoadObject'>}
    >>> obj = SaveLoadObject(0)
    >>> data = obj.save()
    >>> data
    {'__class__': 'SaveLoadObject', 'attr': 0}
    >>> obj2 = SaveLoadGroup.load(data)
    >>> type(obj2)
    <class '__main__.SaveLoadObject'>
    >>> obj2.attr
    0
    >>> data
    {'__class__': 'SaveLoadObject', 'attr': 0}
    """
    # Class registry. Subclasses that do not define their own `classes`
    # share this dict through normal class-attribute lookup.
    classes = {}

    @classmethod
    def add_class(cls, *args):
        """Add classes to the group.

        Parameters
        ----------
        *args : `type`
            Classes to add (registered under their ``__name__``).
        """
        for cls2 in args:
            cls.classes[cls2.__name__] = cls2

    @classmethod
    def remove_class(cls, *args):
        """Remove classes from the group.

        Classes that are not in the group are ignored.

        Parameters
        ----------
        *args : `type`
            Classes to remove.
        """
        for cls2 in args:
            cls.classes.pop(cls2.__name__, None)

    @classmethod
    def load(cls, data):
        """Create an object from JSON data.

        `data` itself is never modified.

        Parameters
        ----------
        data : `dict`
            JSON data. ``data['__class__']`` selects the class; every other
            key is passed to its constructor as a keyword argument.

        Returns
        -------
        `object`
            Created object.

        Raises
        ------
        KeyError
            If `data` does not have the '__class__' key or the necessary class
            is not in the class group.
        """
        target_cls = cls.classes[data['__class__']]
        # Build the keyword arguments from a filtered copy instead of
        # temporarily deleting '__class__' from the caller's dict.
        kwargs = {key: val for key, val in data.items() if key != '__class__'}
        return target_cls(**kwargs)

    def save(self):
        """Convert an object to JSON.

        Returns
        -------
        `dict`
            JSON data containing the '__class__' key (the class name).
        """
        return {'__class__': self.__class__.__name__}
class ObjectWrapper:  # pylint:disable=too-few-public-methods
    """Base class for wrapping objects.

    Creating a subclass instance from an existing object re-classes the new
    instance into a freshly created type. That type is named after the wrapped
    object's class and derives from both the wrapper subclass and the wrapped
    object's class. The instance also shares the wrapped object's attribute
    dictionary.

    Example
    -------
    >>> class Object:
    ...     def method(self):
    ...         return 2
    ...
    >>> class Wrapper(ObjectWrapper):
    ...     def method(self):
    ...         return super().method() * 2
    ...
    >>> obj = Object()
    >>> wrapped = Wrapper(obj)
    >>> wrapped.method()
    4
    """

    def __init__(self, obj):
        wrapped_cls = obj.__class__
        combined_bases = (self.__class__, wrapped_cls)
        # Wrapper methods come first in the MRO; super() calls inside them
        # continue on to the wrapped object's class.
        self.__class__ = type(wrapped_cls.__name__, combined_bases, {})
        # Share (not copy) the attribute dict so attribute writes are visible
        # through both the wrapper and the original object.
        self.__dict__ = obj.__dict__
def const(x):
    """Build a function that ignores its arguments and always returns `x`.

    Parameters
    ----------
    x
        Value to return.

    Returns
    -------
    `function`
        Callable accepting any positional/keyword arguments.
    """
    def _constant(*_args, **_kwargs):
        return x
    return _constant
def to_list(x):
    """Convert a value to a list.

    Lists are returned unchanged (the same object). Dicts and non-iterable
    values are wrapped in a one-element list. Any other iterable is expanded
    with ``list()``.

    Parameters
    ----------
    x
        Value.

    Returns
    -------
    `list`

    Examples
    --------
    >>> to_list(0)
    [0]
    >>> to_list({'x': 0})
    [{'x': 0}]
    >>> to_list(x ** 2 for x in range(3))
    [0, 1, 4]
    >>> x = [1, 2, 3]
    >>> to_list(x)
    [1, 2, 3]
    >>> _ is x
    True
    """
    if isinstance(x, list):
        return x
    if isinstance(x, dict):
        return [x]
    try:
        return list(x)
    except TypeError:
        return [x]
def fill(xs, length, copy=False):
    """Convert a value to a list of specified length.

    If the input is too long it is truncated. If it is too short it is padded
    with its last element. A list that already has the requested length is
    returned as-is (the same object). A dict is treated as a single value,
    as in `to_list`.

    Parameters
    ----------
    xs
        Input list, iterable or value.
    length : `int`
        Output list length.
    copy : `bool`, optional
        Deep copy the last element to fill the list (default: False).

    Returns
    -------
    `list`

    Raises
    ------
    ValueError
        If `xs` is an empty iterable and `length` > 0.

    Examples
    --------
    >>> fill(0, 3)
    [0, 0, 0]
    >>> fill((x ** 2 for x in range(3)), 1)
    [0]
    >>> fill({'a': 1}, 2)
    [{'a': 1}, {'a': 1}]
    >>> x = [{'x': 0}, {'x': 1}]
    >>> y = fill(x, 4)
    >>> y
    [{'x': 0}, {'x': 1}, {'x': 1}, {'x': 1}]
    >>> y[2] is y[1]
    True
    >>> y[3] is y[2]
    True
    >>> y = fill(x, 4, True)
    >>> y
    [{'x': 0}, {'x': 1}, {'x': 1}, {'x': 1}]
    >>> y[2] is y[1]
    False
    >>> y[3] is y[2]
    False
    """
    if isinstance(xs, list) and len(xs) == length:
        return xs
    if length <= 0:
        return []
    if isinstance(xs, dict):
        # A mapping is one value, consistent with `to_list`; iterating it
        # would yield its keys instead of repeating the dict.
        xs = [xs]
    else:
        try:
            # islice always builds a new list, so the caller's input is
            # never mutated by the padding below.
            xs = list(islice(xs, 0, length))
        except TypeError:
            # Not iterable: treat as a single value.
            xs = [xs]
        else:
            if not xs:
                raise ValueError('empty input')
    missing = length - len(xs)
    if missing > 0:
        last = xs[-1]
        if copy:
            xs.extend(deepcopy(last) for _ in range(missing))
        else:
            xs.extend(repeat(last, missing))
    return xs
def int_enum(cls, val):
    """Get int enum value.

    Parameters
    ----------
    cls : `type`
        Int enum class.
    val : `int` or `str`
        Member name (case-insensitive; upper-cased before lookup) or value.

    Returns
    -------
    `IntEnum`

    Raises
    ------
    ValueError
        If `val` is a string that names no attribute of `cls`, or if
        ``cls(val)`` rejects the value.
    """
    if isinstance(val, str):
        name = val.upper()
        try:
            return getattr(cls, name)
        except AttributeError:
            # Use the class name instead of its repr (e.g. "<enum 'X'>") and
            # drop the irrelevant AttributeError context from the traceback.
            raise ValueError('{0}.{1}'.format(cls.__name__, name)) from None
    return cls(val)
def load(obj, cls, default_factory):
    """Create or load an object if necessary.

    Parameters
    ----------
    obj : `object` or `dict` or `None`
        Existing object, JSON data for ``cls.load``, or `None`.
    cls : `type`
        Class whose ``load`` classmethod is called when `obj` is a `dict`.
    default_factory : `function`
        Zero-argument callable used when `obj` is `None`.

    Returns
    -------
    `object`
        ``default_factory()``, ``cls.load(obj)``, or `obj` itself.
    """
    if isinstance(obj, dict):
        return cls.load(obj)
    return default_factory() if obj is None else obj
def _extend(dst, src):
    """Recursively merge `src` into `dst` in place (helper for `extend`)."""
    for key, val in src.items():
        if isinstance(val, dict):
            old = dst.get(key)
            if not isinstance(old, dict):
                # Merge into a fresh dict rather than storing `val` itself:
                # otherwise a later merge into dst[key] would mutate `src`.
                old = dst[key] = {}
            _extend(old, val)
        else:
            dst[key] = val


def extend(dst, *args):
    """Recursively update a dictionary.

    Nested dicts from the sources are merged key by key. They are copied into
    `dst` rather than shared, so the source dictionaries are never modified.
    Non-dict values are assigned as-is.

    Parameters
    ----------
    dst : `dict`
        Dictionary to update.
    *args : `dict`
        Source dictionaries, applied left to right.

    Returns
    -------
    `dict`
        Updated dictionary (`dst` itself).

    Examples
    --------
    >>> extend({'x': {'y': 0}}, {'x': {'z': 1}})
    {'x': {'y': 0, 'z': 1}}
    """
    for src in args:
        _extend(dst, src)
    return dst
def truncate(string, maxlen, end=True):
    """Shorten a string to at most `maxlen` characters using '...'.

    Parameters
    ----------
    string : `str`
        String to truncate.
    maxlen : `int`
        Maximum string length.
    end : `boolean`, optional
        Remove characters from the end (default: `True`); otherwise remove
        them from the start.

    Raises
    ------
    ValueError
        If `maxlen` <= 3.

    Returns
    -------
    `str`
        Truncated string.

    Examples
    --------
    >>> truncate('str', 6)
    'str'
    >>> truncate('long string', 8)
    'long ...'
    >>> truncate('long string', 8, False)
    '...tring'
    """
    if maxlen <= 3:
        raise ValueError('maxlen <= 3')
    if len(string) > maxlen:
        keep = maxlen - 3  # room left after the 3-character ellipsis
        string = string[:keep] + '...' if end else '...' + string[-keep:]
    return string
def state_size_dataset(sz):
    """Build the dataset key fragment for a state size.

    Parameters
    ----------
    sz : `int`
        State size.

    Returns
    -------
    `str`
        Key fragment of the form ``'_ss<sz>'``.
    """
    key_part = '_ss%d' % sz
    return key_part
def level_dataset(lv):
    """Build the dataset key fragment for a level.

    Parameters
    ----------
    lv : `int`
        Level.

    Returns
    -------
    `str`
        Key fragment of the form ``'_lv<lv>'``.
    """
    key_part = '_lv%d' % lv
    return key_part