:py:mod:`fragile.core.state`
============================

.. py:module:: fragile.core.state


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   fragile.core.state.State
   fragile.core.state.SwarmState


.. py:class:: State(n_walkers, param_dict)

   Data structure that handles the data defining a population of walkers.

   Each population attribute is stored as a tensor whose first dimension (the batch size)
   indexes the walkers. To define the tensor attributes, pass a ``param_dict`` dictionary with
   the following structure:

   .. rubric:: Example

   >>> attr_dict = {'name': {'shape': Optional[tuple|None], 'dtype': dtype},  # doctest: +SKIP
   ...              'biases': {'shape': (10,), 'dtype': float},
   ...              'vector': {'dtype': 'float32'},
   ...              'sequence': {'shape': None, 'dtype': 'float32'}
   ...              }

   Here ``shape`` is a tuple giving the shape of each walker's entry; ``n_walkers`` is prepended
   to it as the first dimension of the created array. The created arrays are accessible through
   the attribute of the same name (for example ``states.name``), or by indexing the instance
   with ``states["name"]``.

   If ``shape`` is not defined, the attribute is treated as a vector of length ``n_walkers``.
   If ``shape`` is ``None``, the attribute is stored as a list with one item per walker.

   :param n_walkers: The number of items in the first dimension of the tensors.
   :type n_walkers: int
   :param param_dict: Dictionary defining the attributes of the tensors.
   :type param_dict: StateDict

   .. attribute:: n_walkers

      The number of walkers that this instance represents.

      :type: int

   .. attribute:: param_dict

      A dictionary containing the shape and type of each walker's attribute.

      :type: StateDict

   .. attribute:: names

      The names of the walkers' attributes tracked by this instance.

      :type: Tuple[str]

   .. attribute:: tensor_names

      The names of the walkers' attributes that correspond to tensors.

      :type: Set[str]

   .. attribute:: list_names

      The names of the walkers' attributes that correspond to lists of objects.

      :type: Set[str]

   .. attribute:: vector_names

      The names of the walkers' attributes that correspond to vectors of scalars.

      :type: Set[str]

   .. py:method:: n_walkers()
      :property:

      Return the number of walkers that this instance represents.

   .. py:method:: param_dict()
      :property:

      Return a dictionary containing the shape and type of each walker's attribute.

   .. py:method:: names()
      :property:

      Return the names of the walkers' attributes tracked by this instance.

   .. py:method:: tensor_names()
      :property:

      Return the names of the walkers' attributes that correspond to tensors.

   .. py:method:: list_names()
      :property:

      Return the names of the walkers' attributes that correspond to lists of objects.

   .. py:method:: vector_names()
      :property:

      Return the names of the walkers' attributes that correspond to vectors of scalars.

   .. py:method:: __len__()

      Return the length of the instance, which is equal to ``n_walkers``.

   .. py:method:: __setitem__(key, value)

      Set an attribute of the instance as if it were a dictionary.

      :param key: Attribute to be set.
      :param value: Value of the target attribute.
      :returns: None.

   .. py:method:: __getitem__(item)

      Query an attribute of the instance as if it were a dictionary.

      :param item: Name of the attribute to be selected.
      :returns: The corresponding item.

   .. py:method:: __repr__()

      Return a string that provides a readable representation of this instance's attributes.

   .. py:method:: __hash__()

      Return an integer that represents the hash of the current instance.

   .. py:method:: hash_attribute(name)

      Return a unique id for a given attribute.

   .. py:method:: hash_batch(name)

      Return a unique id for each walker's value of the given attribute.

   .. py:method:: get(name, default=None, raise_error = True)

      Get an attribute by key and return the default value if it does not exist.

      :param name: Attribute to be recovered.
      :param default: Value returned in case the attribute is not part of the state.
      :param raise_error: If True, raise ``AttributeError`` if ``name`` is not present in the state.
      :returns: Target attribute if found in the instance, otherwise the default value.

   .. py:method:: keys()

      Return a generator for the names of the stored attributes.

   .. py:method:: values()

      Return a generator for the values of the stored data.

   .. py:method:: items()

      Return a generator for the attribute names and the values of the stored data.

   .. py:method:: itervals()

      Iterate over the state's attributes walker by walker.

      :returns: Tuple containing all the names of the attributes, and the values that correspond to a given walker.

   .. py:method:: iteritems()

      Iterate over the state's attributes walker by walker.

      :returns: Tuple containing all the names of the attributes, and the values that correspond to a given walker.

   .. py:method:: update(other = None, **kwargs)

      Modify the data stored in this instance.

      Existing attributes will be updated; no new attributes can be created.

      :param other: Other ``SwarmState`` instance, or dictionary of attributes, to copy upon update. Defaults to None.
      :type other: Union[SwarmState, Dict[str, Tensor]], optional
      :param \*\*kwargs: Extra key-value pairs of attributes to modify in the current state.

      .. rubric:: Example

      >>> s = State(2, {'name': {'shape': (3, 4), "dtype": bool}})
      >>> s.update({'name': np.ones((3, 4))})
      >>> len(s.names)
      1
      >>> s['name']
      array([[ True,  True,  True,  True],
             [ True,  True,  True,  True],
             [ True,  True,  True,  True]])

   .. py:method:: to_dict()

      Return the stored data as a dictionary of arrays.

   .. py:method:: copy()

      Create a copy of the current instance.

   .. py:method:: reset()

      Reset the values of the instance.

   .. py:method:: params_to_arrays(param_dict, n_walkers)

      Create a dictionary containing the arrays specified by the attribute dictionary
      ``param_dict``.

      ``param_dict`` defines the attributes of the tensors, and ``n_walkers`` is the number of
      items in the first dimension of the data tensors. The returned dictionary has the same keys
      as ``param_dict``, and each value is the array (or list) specified by the corresponding
      entry of the attribute dictionary.
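      The layout of the returned dictionary can be illustrated with a minimal NumPy sketch. It is
      not the library's implementation: ``params_to_arrays_sketch`` and its ``list_names``
      argument are hypothetical, NumPy arrays stand in for the backend tensors, and the order of
      the checks is an assumption chosen to reproduce the example output shown for this method.

      ```python
      import copy
      from typing import Any, Dict, Iterable

      import numpy as np


      def params_to_arrays_sketch(
          param_dict: Dict[str, Dict[str, Any]],
          n_walkers: int,
          list_names: Iterable[str] = (),
      ) -> Dict[str, Any]:
          """Hypothetical sketch of the allocation rules; not fragile's implementation."""
          list_names = set(list_names)
          arrays: Dict[str, Any] = {}
          for key, spec in param_dict.items():
              spec = copy.deepcopy(spec)  # work on a copy so the caller's dict is untouched
              # An omitted "shape" key means one scalar per walker, i.e. a vector.
              shape = spec.pop("shape", ())
              if shape is None or key in list_names:
                  # Attributes without a fixed shape are stored as Python lists.
                  arrays[key] = [None] * n_walkers
              elif key == "label":
                  arrays[key] = [""] * n_walkers
              else:
                  # Prepend the batch dimension and fill with zeros of the requested dtype.
                  arrays[key] = np.zeros((n_walkers, *shape), dtype=spec.get("dtype"))
          return arrays


      attr_dict = {
          "weights": {"shape": (10, 5), "dtype": "float32"},
          "biases": {"shape": (10,), "dtype": "float32"},
          "label": {"shape": None, "dtype": "str"},
          "vector": {"dtype": "float32"},
      }
      arrays = params_to_arrays_sketch(attr_dict, n_walkers=3)
      print(arrays["weights"].shape)  # (3, 10, 5)
      print(arrays["label"])  # [None, None, None]
      print(arrays["vector"].shape)  # (3,)
      ```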
      The method iterates through each item in the attribute dictionary, works on a copy of its
      value, and checks the shape it specifies. If a shape is given, the tensor is initialized to
      zeros with shape ``(n_walkers, *shape)`` and the ``dtype`` from the attribute dictionary; if
      the ``shape`` key is omitted, the result is a vector of length ``n_walkers``. If ``shape``
      is ``None`` or the key is in ``self.list_names``, the value is initialized as a list of
      ``None`` with ``n_walkers`` items. If the key is ``'label'``, the value is initialized as a
      list of empty strings with ``n_walkers`` items.

      :param param_dict: Dictionary defining the attributes of the tensors.
      :param n_walkers: Number of items in the first dimension of the data tensors.
      :returns: Dictionary with the same names as the attribute dictionary, containing the arrays specified by its values.

      .. rubric:: Example

      >>> attr_dict = {'weights': {'shape': (10, 5), 'dtype': 'float32'},
      ...              'biases': {'shape': (10,), 'dtype': 'float32'},
      ...              'label': {'shape': None, 'dtype': 'str'},
      ...              'vector': {'dtype': 'float32'}}
      >>> n_walkers = 3
      >>> state = State(param_dict=attr_dict, n_walkers=n_walkers)
      >>> tensor_dict = state.params_to_arrays(attr_dict, n_walkers)
      >>> tensor_dict.keys()
      dict_keys(['weights', 'biases', 'label', 'vector'])
      >>> tensor_dict['weights'].shape
      (3, 10, 5)
      >>> tensor_dict['biases'].shape
      (3, 10)
      >>> tensor_dict['label']
      [None, None, None]
      >>> tensor_dict['vector'].shape
      (3,)

   .. py:method:: _parse_value(name, value)

      Ensure that the input value has the correct dimensions and shape for a given attribute.

      :param name: Name of the attribute.
      :type name: str
      :param value: New value to set to the attribute.
      :type value: Any
      :returns: Parsed and validated value of the new state element.
      :rtype: Any

      .. rubric:: Example

      >>> s = State(2, {'name': {'shape': (3, 4), "dtype": int}})
      >>> parsed_val = s._parse_value('name', [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
      >>> parsed_val
      array([[ 1,  2,  3,  4],
             [ 5,  6,  7,  8],
             [ 9, 10, 11, 12]])


.. py:class:: SwarmState(n_walkers, param_dict)

   Bases: :py:obj:`State`

   A dictionary-style container for storing the current state of the swarm. It allows you to
   update the status and metadata of the walkers in the swarm.

   The keys of the instance must be its attributes. The attribute values can be of any type
   (tensor, list or any Python object). Lists and tensors can have a length different from
   ``n_walkers`` if necessary, but tensors should have the same number of rows as there are
   walkers (whether active or not).

   :param n_walkers: Number of walkers.
   :type n_walkers: int
   :param param_dict: Dictionary defining the attributes of the tensors.
   :type param_dict: StateDict

   .. rubric:: Example

   >>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
   >>> s = SwarmState(n_walkers=100, param_dict=param_dict)

   .. py:method:: clone_names()
      :property:

      Return the names of the attributes that will be copied when a walker clones.

   .. py:method:: actives()
      :property:

      Get the indices of the active walkers.

   .. py:method:: n_actives()
      :property:

      Get the number of active walkers.

   .. py:method:: clone(will_clone, compas_clone, clone_names)

      Clone all the stored data according to the provided indices.

   .. py:method:: export_walker(index, names = None, copy = False)

      Export the data of the walker at index ``index`` as a dictionary.

      :param index: The index of the target walker.
      :type index: int
      :param names: The list of attribute names to be included in the output. If None, all attributes will be included. Defaults to None.
      :type names: Optional[List[str]], optional
      :param copy: If True, the returned dictionary will be a copy of the original data. Defaults to False.
      :type copy: bool, optional
      :returns: A dictionary containing the requested attributes and their corresponding values for the specified walker.
      :rtype: Dict[str, Union[Tensor, numpy.ndarray]]

      .. rubric:: Examples

      >>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
      >>> s = SwarmState(n_walkers=10, param_dict=param_dict)
      >>> s.reset()
      >>> walker_dict = s.export_walker(0, names=["x"])
      >>> print(walker_dict)  # doctest: +NORMALIZE_WHITESPACE
      {'x': array([[[0, 0, 0, 0],
              [0, 0, 0, 0],
              [0, 0, 0, 0]]])}

   .. py:method:: import_walker(data, index = 0)

      Import a data dictionary into the state at walker index ``index``.

      :param data: Dictionary containing the data to be imported.
      :type data: Dict
      :param index: Walker index to receive the data. Defaults to 0.
      :type index: int, optional

      .. rubric:: Examples

      >>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
      >>> s = SwarmState(n_walkers=10, param_dict=param_dict)
      >>> s.reset()
      >>> data = {"x": judo.ones((3, 4), dtype=int)}
      >>> s.import_walker(data, index=0)
      >>> s.get("x")[0, 0, :3]
      array([1, 1, 1])

   .. py:method:: reset(root_walker = None)

      Completely reset both the current and the history data held in the state. Optionally, a
      root walker can be provided to reset individual attributes.

      :param root_walker: The initial state when resetting.
      :type root_walker: Optional[Dict[str, Tensor]], optional

      .. rubric:: Examples

      >>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
      >>> s = SwarmState(n_walkers=10, param_dict=param_dict)
      >>> s.reset()
      >>> walker_dict = s.export_walker(0, names=["x"])
      >>> print(walker_dict["x"].shape)  # doctest: +NORMALIZE_WHITESPACE
      (1, 3, 4)

   .. py:method:: get(name, default=None, raise_error = True, inactives = False)

      Get an attribute by key and return the default value if it does not exist.

      :param name: Attribute to be recovered.
      :param default: Value returned in case the attribute is not part of the state.
      :param raise_error: If True, raise ``AttributeError`` if ``name`` is not present in the state.
      :param inactives: Whether to include the walkers marked as inactive.
      :returns: Target attribute if found in the instance, otherwise the default value.

      .. rubric:: Examples

      >>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
      >>> s = SwarmState(n_walkers=10, param_dict=param_dict)
      >>> s.reset()
      >>> print(s.get("x").shape)
      (10, 3, 4)

   .. py:method:: _update_active_list(current_vals, new_vals)

      Update the list of active walkers.

   .. py:method:: update(other = None, inactives = False, **kwargs)

      Modify the data stored in the SwarmState instance.

      Existing attributes will be updated, and new attributes will be created if needed.

      :param other: State instance that will be copied upon update.
      :param inactives: Whether to update the walkers marked as inactive.
      :param \*\*kwargs: The update can also be given as ``name=value`` pairs, where ``name`` is the attribute to be updated and ``value`` is its new value.
      :returns: None.

      .. rubric:: Examples

      >>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
      >>> s = SwarmState(n_walkers=10, param_dict=param_dict)
      >>> s.reset()
      >>> s.update(x=judo.ones((10, 3, 4), dtype=int))
      >>> print(s.get("x")[0, 0, 0])
      1

   .. py:method:: update_actives(actives)

      Set the walkers marked as active.

      .. rubric:: Example

      To set the first 10 walkers as active, call the function with a tensor of size equal to the
      number of walkers, where the first ten elements are ``True`` and the remaining elements are
      ``False``:

      >>> param_dict = {"vector": {"dtype": int}}
      >>> s = SwarmState(n_walkers=20, param_dict=param_dict)
      >>> active_walkers = np.concatenate([np.ones(10), np.zeros(10)]).astype(bool)
      >>> s.update_actives(active_walkers)

      This marks those walkers as active, and any attribute updated with ``inactives=False``
      (the default) will only modify the data of those walkers.
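
The masking rule described for ``SwarmState.update_actives`` can be illustrated with plain
NumPy. The following is a hypothetical sketch, not the library's implementation:
``masked_update`` is an illustrative name, and it assumes that new values are supplied with one
row per walker, so that only the rows selected by the boolean ``actives`` mask are overwritten
unless ``inactives=True``.

```python
import numpy as np


def masked_update(current: np.ndarray, new: np.ndarray, actives: np.ndarray,
                  inactives: bool = False) -> np.ndarray:
    """Hypothetical sketch of an update restricted to active walkers.

    Assumes ``new`` has one row per walker, like ``current``; this is not
    fragile's implementation.
    """
    if inactives:
        # Overwrite every walker, active or not.
        return np.array(new, copy=True)
    result = current.copy()
    # A boolean mask on the first (batch) axis selects only the active walkers.
    result[actives] = new[actives]
    return result


# First 10 walkers active, last 10 inactive, as in the update_actives example.
actives = np.concatenate([np.ones(10), np.zeros(10)]).astype(bool)
current = np.zeros((20, 3), dtype=int)
updated = masked_update(current, np.ones((20, 3), dtype=int), actives)
print(updated[:10].sum(), updated[10:].sum())  # 30 0
```

Returning a modified copy keeps the sketch side-effect free; the library may instead write into
its stored tensors in place.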