diff --git a/batchflow/config.py b/batchflow/config.py index b4a606dbe..cb8430fe1 100644 --- a/batchflow/config.py +++ b/batchflow/config.py @@ -1,324 +1,267 @@ """ Config class""" from pprint import pformat -import numpy as np +class Config(dict): + """ Class for configs that can be represented as nested dicts with easy indexing by slashes. """ -class Config: - """ Class for configs that can be represented as nested dicts with easy indexing by slashes """ - + # Should be defined temporarily for the already pickled configs class IAddDict(dict): - """ dict that supports update via += """ - def __iadd__(self, other): - if isinstance(other, dict): - self.update(other) - else: - raise TypeError(f"unsupported operand type(s) for +=: 'IAddDict' and '{type(other)}'") - return self + pass def __init__(self, config=None, **kwargs): - """ Create Config + """ Create Config. Parameters ---------- - config : dict, Config or None - an object to initialize Config - if dict, all keys and values slashes will be parsed into nested structure of dicts - and the resulting dictionary will be saved into self.config - if an instance on Config, config.config will be saved to self.config (not a copy!) - if None, empty dictionary will be created + config : dict, Config, list, tuple or None + An object to initialize Config. + + If dict, all keys with slashes and values are parsed into nested structure of dicts, + and the resulting dictionary is saved to self.config. + For example, `{'a/b': 1, 'c/d/e': 2}` will be parsed into `{'a': {'b': 1}, 'c': {'d': {'e': 2}}}`. + + If list or tuples, should contain key-value pairs with the length of 2. + For example, `[('a/b', 1), ('c/d/e', 2)]` will be parsed into `{'a': {'b': 1}, 'c': {'d': {'e': 2}}}`. + + If an instance of Config, config is saved to self.config. + + If None, empty dictionary is created. kwargs : - parameters from kwargs also will be parsed and saved into self.config + Parameters from kwargs also are parsed and saved to self.config. """ + # pylint: disable=super-init-not-called + self.config = {} + if config is None: - self.config = Config.IAddDict() - elif isinstance(config, (dict, list)): - self.config = self.parse(config) + pass elif isinstance(config, Config): - self.config = config.config + self.parse(config.config) + elif isinstance(config, (dict, list, tuple)): + self.parse(config) else: - raise TypeError(f'config must be dict, Config or list but {type(config)} was given') + raise TypeError(f'Config must be dict, Config, list or tuple but {type(config)} was given') for key, value in kwargs.items(): self.put(key, value) - def pop(self, variables, config=None, **kwargs): - """ Returns variables and remove them from config + def parse(self, config): + """ Parses flatten config with slashes. Parameters ---------- - variables : str or list of strs - names of variables. '/' is used to get value from nested dict - config : dict, Config or None - if None, variables will be getted from self.config else from config + config : dict, Config, list or tuple Returns ------- - single value or a tuple + self : Config + """ if isinstance(config, Config): - value = config.pop(variables, None, **kwargs) + # suppose we have config = {'a': {'b': {'c': 1}}}, + # and we try to update config with other = {'a': {'b': {'d': 3}}}, + # and expect to see config = {'a': {'b': {'c': 1, 'd': 3}}} + items = config.items(flatten=True) + elif isinstance(config, dict): + items = config.items() else: - value = self._get(variables, config, pop=True, **kwargs) - return value + items = dict(config).items() - def get(self, variables, default=None, config=None): - """ Returns variables from config + for key, value in items: + if isinstance(key, str): # if key contains multiple consecutive '/' + key = '/'.join(s for s in key.split('/') if s) + self.put(key, value) + + return self + + def put(self, key, value): + """ Put a new key into config recursively. Parameters ---------- - variables : str or list of str or tuple of str - names of variables. '/' is used to get value from nested dict. - config : dict, Config or None - if None variables will be getted from self.config else from config - default : masc - default value if variable doesn't exist in config + key : hashable object + Key to add. '/' is used to put value into nested dict. + value : misc - Returns - ------- - single value or a tuple """ - if isinstance(config, Config): - val = config.get(variables, default=default) - else: - val = self._get(variables, config, default=default, pop=False) - return val + if isinstance(value, dict): # for example, value = {'a/b': 3}, and we need to parse it before put + value = Config(value).config + + if isinstance(key, str): - def _get(self, variables, config=None, **kwargs): - if config is None: config = self.config - pop = kwargs.get('pop', False) - has_default = 'default' in kwargs - default = kwargs.get('default') + levels = key.split('/') + last_level = levels[-1] + + for level in levels[:-1]: + prev_config = config + if level not in config: + config[level] = {} + config = config[level] + + if isinstance(value, dict) and last_level in config and isinstance(config[last_level], dict): + config[last_level].update(value) + else: + if isinstance(config, dict): + config[last_level] = value + # for example, we try to set my_config['a/b/c'] = 3, + # where my_config = Config({'a/b': 1}) and don't want error here + else: + prev_config[level] = {last_level: value} # pylint: disable=undefined-loop-variable + else: + self.config[key] = value + + def _get(self, key, default=None, has_default=False, pop=False): + """ Consecutively retrieve values for a given key if the key contains '/'. + This method supports the `default` to be unique for each variable in key. + """ + method = 'get' if not pop else 'pop' + method = getattr(self.config, method) unpack = False - if not isinstance(variables, (list, tuple)): - variables = list([variables]) + if not isinstance(key, list): + key = [key] unpack = True + n = len(key) + if n > 1: + default = [default] * n if not isinstance(default, list) else default + if len(default) != n: + raise ValueError('The length of `default` must be equal to the length of `key`') + else: + default = [default] + ret_vars = [] - for variable in variables: - _config = config - if '/' in variable: - var = variable.split('/') - prefix = var[:-1] - var_name = var[-1] - else: - prefix = [] - var_name = variable + for ix, variable in enumerate(key): + + if isinstance(variable, str) and '/' in variable: + + value = self.config + levels = variable.split('/') + values = [] + + for level in levels: + + if not isinstance(value, dict): + if not has_default: + raise KeyError(level) + value = default[ix] + values.append(value) + break + + if level not in value: + if not has_default: + raise KeyError(level) + value = default[ix] + values.append(value) + break + + value = value[level] + values.append(value) - for p in prefix: - if p in _config: - _config = _config[p] - else: - _config = None - break - if isinstance(_config, dict): if pop: - if has_default: - val = _config.pop(var_name, default) - else: - val = _config.pop(var_name) - else: - if has_default: - val = _config.get(var_name, default) - else: - val = _config[var_name] + # delete the last level from the parent dict + values[-2].pop(level, default[ix]) # pylint: disable=undefined-loop-variable + else: - if has_default: - val = default + + if variable not in self.config: + if not has_default: + raise KeyError(variable) + value = default[ix] + else: - raise KeyError(f"Key '{variable}' not found") + value = method(variable) + + if isinstance(value, dict): + value = Config(value) + ret_vars.append(value) - val = Config(val) if isinstance(val, (dict, Config.IAddDict)) else val - ret_vars.append(val) + ret_vars = ret_vars[0] if unpack else tuple(ret_vars) - if unpack: - ret_vars = ret_vars[0] - else: - ret_vars = tuple(ret_vars) return ret_vars - def put(self, variable, value, config=None): - """ Put a new variable into config + def get(self, key, default=None): + """ Returns the value or tuple of values for key in the config. + If not found, returns a default value. Parameters ---------- - variable : str - variable to add. '/' is used to put value into nested dict - value : masc - config : dict, Config or None - if None value will be putted into self.config else from config + key : str or list of hashable objects + A key in the dictionary. '/' is used to get value from nested dict. + default : misc + Default value if key doesn't exist in config. + By default None, so this method never raises a KeyError. + If key has several variables, `default` can be a list with defaults for each variable. + + Returns + ------- + value : misc + Single value or a tuple. """ - if config is None: - config = self.config - elif isinstance(config, Config): - config = config.config - if isinstance(value, dict): - value = Config(value) - variable = variable.strip('/') - if '/' in variable: - var = variable.split('/') - prefix = var[:-1] - var_name = var[-1] - else: - prefix = [] - var_name = variable - - for i, p in enumerate(prefix): - if p not in config: - config[p] = Config.IAddDict() - if isinstance(config[p], dict): - config = config[p] - else: # for example, we put value with key 'a/b' into `{a: c}` - value = Config({'/'.join(prefix[i+1:] + [var_name]): value}) - var_name = p - break - if var_name in config and isinstance(config[var_name], dict) and isinstance(value, Config): - config[var_name] = Config(config[var_name]) - config[var_name].update(value) - config[var_name] = config[var_name].config - else: - if isinstance(value, Config): - config[var_name] = value.config - else: - config[var_name] = value + value = self._get(key, default=default, has_default=True) - def parse(self, config): - """ Parses flatten config with slashes + return value + + def pop(self, key, **kwargs): + """ Returns the value or tuple of values for key in the config. + If not found, returns a default value. Parameters ---------- - config : dict, Config or list + key : str or list of hashable objects + A key in the dictionary. '/' is used to get value from nested dict. + default : misc + Default value if key doesn't exist in config. + If key has several variables, `default` can be a list with defaults for each variable. Returns ------- - new_config : dict + value : misc + Single value or a tuple. """ - if isinstance(config, Config): - return config.config - if isinstance(config, dict): - items = config.items() - elif isinstance(config, list): - items = config - if np.any([len(item) != 2 for item in items]): - raise ValueError('tuples in list should represent pairs key-value' - ', and therefore must be always the length of 2') - else: - raise TypeError(f'config must be dict, Config or list but {type(config)} was given') - new_config = Config.IAddDict() - for key, value in items: - if isinstance(value, dict): - value = self.parse(value) - if not isinstance(key, str): - raise TypeError(f'only str keys are supported, "{str(key)}" is of {type(key)} type') - key = '/'.join(filter(None, key.split('/'))) #merge multiple consecutive slashes '/' to one - self.put(key, value, new_config) - return new_config + has_default = 'default' in kwargs + default = kwargs.get('default') + value = self._get(key, has_default=has_default, default=default, pop=True) + + return value + + def update(self, other=None, **kwargs): + other = other or {} + if not isinstance(other, (dict, tuple, list)): + raise TypeError(f'{type(other)} object is not iterable') + + self.parse(Config(other)) + + for key, value in kwargs.items(): + self.put(key, value) def flatten(self, config=None): - """ Transforms nested dict into flatten dict + """ Transforms nested dict into flatten dict. Parameters ---------- config : dict, Config or None - if None self.config will be parsed else config + If None `self.config` will be parsed else config. Returns ------- new_config : dict + """ - if config is None: - config = self.config - elif isinstance(config, Config): - config = config.config - new_config = Config.IAddDict() + config = self.config if config is None else config + new_config = {} for key, value in config.items(): - if isinstance(value, Config): - value = value.config if isinstance(value, dict) and len(value) > 0: value = self.flatten(value) for _key, _value in value.items(): - new_config[key+'/'+_key] = _value + if isinstance(_key, str): + new_config[key + '/' + _key] = _value + else: + new_config[key] = {_key: _value} else: new_config[key] = value - return new_config - - def __add__(self, other): - if isinstance(other, dict): - other = Config(other) - if isinstance(other, Config): - return Config([*self.flatten().items(), *other.flatten().items()]) - return NotImplemented - - def __radd__(self, other): - if isinstance(other, dict): - other = Config(other) - return other.__add__(self) - - def __getitem__(self, key): - value = self._get(key) - return value - - def __setitem__(self, key, value): - self.pop(key, default=None) - self.put(key, value) - - def __delitem__(self, key): - self.pop(key) - - def __getattr__(self, key): - if key in self.config: - value = self._get(key) - value = Config(value) if isinstance(value, dict) else value - return value - raise AttributeError(key) - - def __getstate__(self): - """ Must be explicitly defined for pickling to work. """ - return vars(self) - def __setstate__(self, state): - """ Must be explicitly defined for pickling to work. """ - vars(self).update(state) - - def __len__(self): - return len(self.config) - - def __rshift__(self, other): - """ - Parameters - ---------- - other : Pipeline - - Returns - ------- - Pipeline - Pipeline object with an updated config - """ - return other << self - - def __eq__(self, other): - self_ = self.flatten() if isinstance(self, Config) else self - other_ = Config(other).flatten() if isinstance(other, (dict, Config)) else other - return self_.__eq__(other_) - - def items(self, flatten=False): - """ Returns config items - - Parameters - ---------- - flatten : bool - if False, keys and values will be getted from first level of nested dict, else from the last - - Returns - ------- - dict_items - """ - if flatten: - items = self.flatten().items() - else: - items = self.config.items() - return items + return new_config def keys(self, flatten=False): """ Returns config keys @@ -326,11 +269,12 @@ def keys(self, flatten=False): Parameters ---------- flatten : bool - if False, keys will be getted from first level of nested dict, else from the last + If False, keys will be got from first level of nested dict, else from the last. Returns ------- - dict_keys + keys : dict_keys + """ if flatten: keys = self.flatten().keys() @@ -344,11 +288,12 @@ def values(self, flatten=False): Parameters ---------- flatten : bool - if False, values will be getted from first level of nested dict, else from the last + If False, values will be got from first level of nested dict, else from the last. Returns ------- - dict_values + values : dict_values + """ if flatten: values = self.flatten().values() @@ -356,31 +301,98 @@ def values(self, flatten=False): values = self.config.values() return values - def update(self, other=None, **kwargs): - """ Update config with values from other + def items(self, flatten=False): + """ Returns config items Parameters ---------- - other : dict or Config + flatten : bool + If False, keys and values will be got from first level of nested dict, else from the last. + + Returns + ------- + items : dict_items - kwargs : - parameters from kwargs also will be included into the resulting config """ - other = {} if other is None else other - if isinstance(other, (dict, Config)): - for key, value in other.items(): - self.put(key, value) + if flatten: + items = self.flatten().items() else: - for key, value in kwargs.items(): - self.put(key, value) + items = self.config.items() + return items def copy(self): """ Create a shallow copy of the instance. """ return Config(self.config.copy()) + def __getitem__(self, key): + value = self._get(key) + return value + + def __setitem__(self, key, value): + if key in self.config: + self.pop(key, default=None) + self.put(key, value) + + def __delitem__(self, key): + self.pop(key) + + def __getattr__(self, key): + if key in self.config: + value = self.config.get(key) + value = Config(value) if isinstance(value, dict) else value + return value + raise AttributeError(key) + + def __add__(self, other): + if isinstance(other, dict) and not isinstance(other, Config): + other = Config(other) + if isinstance(other, Config): + return Config([*self.flatten().items(), *other.flatten().items()]) + return NotImplemented + + def __iadd__(self, other): + if isinstance(other, dict): + self.update(other) + else: + raise TypeError(f"unsupported operand type(s) for +=: 'IAddDict' and '{type(other)}'") + return self + + def __radd__(self, other): + if isinstance(other, dict): + other = Config(other) + return other.__add__(self) + + def __eq__(self, other): + self_ = self.flatten() + other_ = Config(other).flatten() if isinstance(other, dict) else other + return self_.__eq__(other_) + + def __len__(self): + return len(self.config) + def __iter__(self): return iter(self.config) def __repr__(self): lines = ['\n' + 4 * ' ' + line for line in pformat(self.config).split('\n')] return f"Config({''.join(lines)})" + + def __rshift__(self, other): + """ Parameters + ---------- + other : Pipeline + + Returns + ------- + Pipeline + Pipeline object with an updated config. + """ + return other << self + + def __getstate__(self): + """ Must be explicitly defined for pickling to work. """ + return vars(self) + + def __setstate__(self, state): + """ Must be explicitly defined for pickling to work. """ + vars(self).update(state)