fragile.core.state

Module Contents

Classes

State

Data structure that handles the data defining a population of walkers.

SwarmState

A dictionary-style container for storing the current state of the swarm. It allows you to update the status and metadata of the walkers in the swarm.

class fragile.core.state.State(n_walkers, param_dict)

Data structure that handles the data defining a population of walkers.

Each population attribute will be stored as a tensor with its first dimension (batch size) representing each walker.

To define the tensor attributes, pass a param_dict dictionary that maps each attribute name to a dictionary with an optional 'shape' entry (a tuple, or None) and a 'dtype' entry:

Example

>>> attr_dict = {'name_1': {'shape': (3, 4), 'dtype': bool},
...              'biases': {'shape': (10,), 'dtype': float},
...              'vector': {'dtype': 'float32'},
...              'sequence': {'shape': None, 'dtype': 'float32'},
...              }

Here, shape is a tuple indicating the shape of each walker's tensor; the walker dimension (n_walkers) is prepended to it. The created arrays are accessible through the name_1 attribute of the class, or by indexing the instance with states["name_1"].

If shape is not defined, the attribute is considered a vector of length n_walkers. If shape is None, the attribute is stored as a list with one Python object per walker.

Parameters
  • n_walkers (int) – The number of items in the first dimension of the tensors.

  • param_dict (StateDict) – Dictionary defining the attributes of the tensors.

n_walkers

The number of walkers that this instance represents.

Type

int

Return type

int

param_dict

A dictionary containing the shape and type of each walker’s attribute.

Type

StateDict

Return type

fragile.core.typing.StateDict

names

The names of the walker attributes tracked by this instance.

Type

Tuple[str]

Return type

Tuple[str]

tensor_names

The names of the walker attributes that are stored as tensors.

Type

Set[str]

Return type

Set[str]

list_names

The names of the walker attributes that are stored as lists of objects.

Type

Set[str]

Return type

Set[str]

vector_names

The names of the walker attributes that are vectors of scalars.

Type

Set[str]

Return type

Set[str]
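The three name sets partition the attributes according to how they are declared in param_dict. The following is a minimal sketch of that split, assuming (as the params_to_arrays example below suggests) that shape=None marks a list attribute and an omitted shape marks a vector. It also assumes every non-list attribute counts as a tensor, so vectors appear in both sets. The helper classify_names is hypothetical and not part of fragile.

```python
from typing import Any, Dict, Set, Tuple


def classify_names(
    param_dict: Dict[str, Dict[str, Any]]
) -> Tuple[Set[str], Set[str], Set[str]]:
    """Hypothetical helper: split attribute names into tensor, list and vector sets.

    Assumptions (not taken from fragile's source):
    - "shape": None declares a list holding one Python object per walker.
    - A missing "shape" key declares a vector with one scalar per walker.
    - Every non-list attribute is a tensor, so vectors are tensors too.
    """
    list_names = {k for k, v in param_dict.items() if "shape" in v and v["shape"] is None}
    tensor_names = set(param_dict) - list_names
    vector_names = {k for k, v in param_dict.items() if "shape" not in v}
    return tensor_names, list_names, vector_names


params = {
    "weights": {"shape": (10, 5), "dtype": "float32"},
    "label": {"shape": None, "dtype": "str"},
    "vector": {"dtype": "float32"},
}
tensors, lists, vectors = classify_names(params)
print(sorted(tensors), sorted(lists), sorted(vectors))
# ['vector', 'weights'] ['label'] ['vector']
```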

property n_walkers

Return the number of walkers that this instance represents.

Return type

int

property param_dict

Return a dictionary containing the shape and type of each walker’s attribute.

Return type

fragile.core.typing.StateDict

property names

Return the names of the walker attributes tracked by this instance.

Return type

Tuple[str]

property tensor_names

Return the names of the walker attributes that are stored as tensors.

Return type

Set[str]

property list_names

Return the names of the walker attributes that are stored as lists of objects.

Return type

Set[str]

property vector_names

Return the names of the walker attributes that are vectors of scalars.

Return type

Set[str]

__len__()

Length is equal to n_walkers.

Return type

int

__setitem__(key, value)

Allow the class to set its attributes as if it were a dict.

Parameters
  • key – Attribute to be set.

  • value (Union[Tuple, List, fragile.core.typing.Tensor]) – Value of the target attribute.

Returns

None.

__getitem__(item)

Query an attribute of the class as if it were a dictionary.

Parameters

item (str) – Name of the attribute to be selected.

Returns

The corresponding item.

Return type

Union[fragile.core.typing.Tensor, List[fragile.core.typing.Tensor], SwarmState]
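__setitem__ and __getitem__ make states["name_1"] and attribute access interchangeable. The following is a minimal sketch of that dict-style pattern using a hypothetical MiniState class that keeps every value in an internal dict. It only mirrors the access behaviour: it does no shape or dtype validation, and raising KeyError for unknown names is an assumption of this sketch.

```python
from typing import Any, Dict, Iterable


class MiniState:
    """Hypothetical container where s["x"] and s.x read the same storage."""

    def __init__(self, names: Iterable[str]) -> None:
        # All attribute values live in this dict, keyed by attribute name.
        self._data: Dict[str, Any] = {name: None for name in names}

    def __getitem__(self, item: str) -> Any:
        return self._data[item]

    def __setitem__(self, key: str, value: Any) -> None:
        # Assumption: only declared attributes may be assigned.
        if key not in self._data:
            raise KeyError(f"{key!r} is not an attribute of this state")
        self._data[key] = value

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails; "_data" itself is found normally.
        data = self.__dict__.get("_data", {})
        if name in data:
            return data[name]
        raise AttributeError(name)


s = MiniState(["name_1"])
s["name_1"] = [1, 2, 3]
print(s.name_1)  # [1, 2, 3]
```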

__repr__()

Return a human-readable string representation of this instance's attributes.

Return type

str

__hash__()

Return an integer that represents the hash of the current instance.

Return type

int

hash_attribute(name)

Return a unique id (hash) for the value of a given attribute.

Parameters

name (str) – Name of the attribute to hash.

Return type

int

hash_batch(name)

Return a unique id (hash) for the value of the given attribute of each walker.

Parameters

name (str) – Name of the attribute to hash.

Return type

List[int]
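hash_attribute returns one integer for a whole attribute, while hash_batch returns one integer per walker. The following is a minimal NumPy sketch of that distinction. It assumes arrays are hashed from their raw bytes, which fragile's actual scheme may not do, and hash_tensor and hash_rows are hypothetical names. Hashing raw bytes ignores dtype and shape, so two different arrays with identical bytes would collide.

```python
from typing import List

import numpy as np


def hash_tensor(x: np.ndarray) -> int:
    """Hypothetical: one hash for the whole array (cf. hash_attribute)."""
    return hash(np.ascontiguousarray(x).tobytes())


def hash_rows(x: np.ndarray) -> List[int]:
    """Hypothetical: one hash per walker, i.e. per row of the first axis (cf. hash_batch)."""
    return [hash_tensor(row) for row in x]


obs = np.array([[0, 1], [0, 1], [2, 3]])
row_hashes = hash_rows(obs)
print(row_hashes[0] == row_hashes[1])  # True: walkers 0 and 1 hold equal data
print(row_hashes[0] == row_hashes[2])  # False (barring a hash collision)
```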

get(name, default=None, raise_error=True)

Get an attribute by key and return the default value if it does not exist.

Parameters
  • name (str) – Attribute to be recovered.

  • default – Value returned in case the attribute is not part of state.

  • raise_error (bool) – If True, raise an AttributeError if name is not present in the state.

Returns

Target attribute if found in the instance, otherwise returns the default value.

keys()

Return a view of the names of the stored attributes.

Return type

_dict_keys[str, Dict[str, Any]]

values()

Return a generator for the values of the stored data.

Return type

Generator

items()

Return a generator for the attribute names and the values of the stored data.

Return type

Generator

itervals()

Iterate over the state attributes walker by walker.

Returns

A tuple containing the names of all the attributes and the values that correspond to a given walker.

iteritems()

Iterate over the state attributes walker by walker.

Returns

A tuple containing the names of all the attributes and the values that correspond to a given walker.
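Iterating "by walker" transposes the per-attribute storage: instead of one array per attribute, each step yields the values of every attribute for a single walker. The following is a minimal sketch with plain dictionaries. iter_walkers is a hypothetical helper, and it does not settle whether fragile's itervals yields the names alongside the values.

```python
from typing import Any, Dict, Iterator, Tuple


def iter_walkers(data: Dict[str, Any]) -> Iterator[Tuple[Tuple[str, ...], Tuple[Any, ...]]]:
    """Hypothetical: yield (names, values) once per walker.

    data maps attribute names to sequences whose first axis is the walker axis.
    """
    names = tuple(data.keys())
    # zip walks all attributes in lockstep, producing one tuple per walker.
    for values in zip(*(data[name] for name in names)):
        yield names, values


data = {"x": [0.5, 1.5, 2.5], "label": ["a", "b", "c"]}
for names, values in iter_walkers(data):
    print(dict(zip(names, values)))
```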

update(other=None, **kwargs)

Modify the data stored in this instance.

Existing attributes will be updated, and no new attributes can be created.

Parameters
  • other (Union[SwarmState, Dict[str, Tensor]], optional) – State instance or dictionary whose values will be copied upon update. Defaults to None.

  • **kwargs – Extra key-value pairs of existing attributes to modify in the current state.

Return type

None

Example

>>> import numpy as np
>>> from fragile.core.state import State
>>> s = State(2, {'name': {'shape': (3, 4), "dtype": bool}})
>>> s.update({'name': np.ones((3, 4))})
>>> len(s.names)
1
>>> s['name']
array([[ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])
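State.update only overwrites attributes that already exist. The following is a minimal sketch of that rule on a plain dict. It merges other and **kwargs in the same order as dict.update and rejects unknown names before changing anything. update_existing, and the choice to raise KeyError, are assumptions of this sketch, not fragile's documented behaviour.

```python
from typing import Any, Dict, Optional


def update_existing(store: Dict[str, Any], other: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
    """Hypothetical: update store in place, refusing to create new keys."""
    # Merge both sources first; keyword arguments win over other, as in dict.update.
    changes = {**(other or {}), **kwargs}
    unknown = set(changes) - set(store)
    if unknown:
        # Validate before writing, so a rejected update leaves store untouched.
        raise KeyError(f"Cannot create new attributes: {sorted(unknown)}")
    store.update(changes)


store = {"name": 0, "score": 1.0}
update_existing(store, {"name": 5}, score=2.0)
print(store)  # {'name': 5, 'score': 2.0}
```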

to_dict()

Return the stored data as a dictionary of arrays.

Return type

Dict[str, Union[fragile.core.typing.Tensor, list]]

copy()

Create a copy of the current instance.

Return type

SwarmState

reset()

Reset the values stored in this instance.

Return type

None

params_to_arrays(param_dict, n_walkers)

Create a dictionary containing arrays specified by param_dict, the attribute dictionary.

The attribute dictionary param_dict defines the attributes of the tensors, and n_walkers is the number of items in the first dimension of each tensor. The method iterates over param_dict, copies each attribute definition, and allocates a value according to its shape, returning a dictionary with the same keys as param_dict.

If a shape tuple is specified, the method initializes a tensor of zeros with shape (n_walkers, *shape) and the dtype given in the attribute dictionary. If the shape is omitted, the attribute is initialized as a vector of length n_walkers. If the shape is None, or if the key is in self.list_names, the value is initialized as a list of n_walkers None values. If the key is 'label', the method initializes the value as a list of empty strings with n_walkers items.
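The shape-based initialization rules described above can be sketched with NumPy. allocate_arrays is a simplified, hypothetical helper, not fragile's implementation: it omits the special handling of self.list_names and 'label'. Following the example below, it treats a missing shape as a vector and shape=None as a list.

```python
from typing import Any, Dict

import numpy as np


def allocate_arrays(param_dict: Dict[str, Dict[str, Any]], n_walkers: int) -> Dict[str, Any]:
    """Hypothetical sketch of the params_to_arrays allocation rules."""
    out: Dict[str, Any] = {}
    for key, spec in param_dict.items():
        if "shape" not in spec:
            # No shape given: one scalar per walker.
            out[key] = np.zeros(n_walkers, dtype=spec.get("dtype"))
        elif spec["shape"] is None:
            # shape=None: a list holding one arbitrary Python object per walker.
            out[key] = [None] * n_walkers
        else:
            # Explicit shape: prepend the walker (batch) dimension.
            out[key] = np.zeros((n_walkers, *spec["shape"]), dtype=spec.get("dtype"))
    return out


arrays = allocate_arrays(
    {"weights": {"shape": (10, 5), "dtype": "float32"},
     "label": {"shape": None, "dtype": "str"},
     "vector": {"dtype": "float32"}},
    n_walkers=3,
)
print(arrays["weights"].shape, arrays["label"], arrays["vector"].shape)
# (3, 10, 5) [None, None, None] (3,)
```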

Parameters
  • param_dict (fragile.core.typing.StateDict) – Dictionary defining the attributes of the tensors.

  • n_walkers (int) – Number of items in the first dimension of the data tensors.

Returns

Dictionary with the same names as the attribute dictionary, containing arrays specified by the values in the attribute dictionary.

Return type

Dict[str, fragile.core.typing.Tensor]

Example

>>> from fragile.core.state import State
>>> attr_dict = {'weights': {'shape': (10, 5), 'dtype': 'float32'},
...              'biases': {'shape': (10,), 'dtype': 'float32'},
...              'label': {'shape': None, 'dtype': 'str'},
...              'vector': {'dtype': 'float32'}}
>>> n_walkers = 3
>>> state = State(param_dict=attr_dict, n_walkers=n_walkers)
>>> tensor_dict = state.params_to_arrays(attr_dict, n_walkers)
>>> tensor_dict.keys()
dict_keys(['weights', 'biases', 'label', 'vector'])
>>> tensor_dict['weights'].shape
(3, 10, 5)
>>> tensor_dict['biases'].shape
(3, 10)
>>> tensor_dict['label']
[None, None, None]
>>> tensor_dict['vector'].shape
(3,)

_parse_value(name, value)

Ensure that the input value has correct dimensions and shape for a given attribute.

Parameters
  • name (str) – Name of the attribute.

  • value (Any) – New value to set to the attribute.

Returns

Parsed and validated value of the new state element.

Return type

Any

Example

>>> from fragile.core.state import State
>>> s = State(2, {'name': {'shape': (3, 4), "dtype": int}})
>>> parsed_val = s._parse_value('name', [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
>>> parsed_val
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12]])

class fragile.core.state.SwarmState(n_walkers, param_dict)

Bases: State

A dictionary-style container for storing the current state of the swarm. It allows you to update the status and metadata of the walkers in the swarm.

The keys of the instance are its attribute names. Attribute values can be of any type (tensor, list, or any Python object). Lists can have a length different from n_walkers if necessary, but tensors should have one row per walker, whether the walker is active or not.

Parameters
  • n_walkers (int) – Number of walkers.

  • param_dict (StateDict) – Dictionary defining the attributes of the tensors.

Example

>>> from fragile.core.state import SwarmState
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=100, param_dict=param_dict)

property clone_names

Return the names of the attributes that will be copied when a walker clones.

Return type

Set[str]

property actives

Return the indices of the active walkers.

Return type

fragile.core.typing.Tensor

property n_actives

Get the number of active walkers.

Return type

int

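clone(will_clone, compas_clone, clone_names)

Clone the data stored in the attributes listed in clone_names: each walker flagged in will_clone receives the values of its companion walker, whose index is given by compas_clone.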
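Cloning copies the data of a companion walker onto each walker that decides to clone. The following is a minimal NumPy sketch of that operation. It assumes will_clone is a boolean mask with one entry per walker and compas_clone holds one companion index per walker. clone_arrays is hypothetical, and fragile's clone may handle list attributes and dtypes differently.

```python
from typing import Dict, Iterable

import numpy as np


def clone_arrays(data: Dict[str, np.ndarray], will_clone: np.ndarray,
                 compas_clone: np.ndarray, clone_names: Iterable[str]) -> None:
    """Hypothetical: overwrite cloning walkers with their companion's data, in place."""
    for name in clone_names:
        values = data[name]
        # Gather every companion row, then copy only the rows where will_clone is True.
        values[will_clone] = values[compas_clone][will_clone]


data = {"x": np.array([[0.0], [1.0], [2.0], [3.0]]),
        "score": np.array([10.0, 11.0, 12.0, 13.0])}
will_clone = np.array([True, False, True, False])
compas_clone = np.array([3, 0, 1, 2])
clone_arrays(data, will_clone, compas_clone, clone_names=["x"])
print(data["x"].ravel())  # [3. 1. 1. 3.]
print(data["score"])      # unchanged: [10. 11. 12. 13.]
```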

export_walker(index, names=None, copy=False)

Export the data of the walker at position index as a dictionary.

Parameters
  • index (int) – The index of the target walker.

  • names (Optional[List[str]], optional) – The list of attribute names to be included in the output. If None, all attributes will be included. Defaults to None.

  • copy (bool, optional) – If True, the returned dictionary will be a copy of the original data. Defaults to False.

Returns

A dictionary containing the requested attributes and their corresponding values for the specified walker.

Return type

Dict[str, Union[Tensor, numpy.ndarray]]

Examples

>>> from fragile.core.state import SwarmState
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> walker_dict = s.export_walker(0, names=["x"])
>>> print(walker_dict)  
{'x': array([[[0, 0, 0, 0],
              [0, 0, 0, 0],
              [0, 0, 0, 0]]])}

import_walker(data, index=0)

Import the values in the data dictionary into the walker at position index.

Parameters
  • data (Dict) – Dictionary containing the data to be imported.

  • index (int, optional) – Walker index to receive the data. Defaults to 0.

Return type

None

Examples

>>> import judo
>>> from fragile.core.state import SwarmState
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> data = {"x": judo.ones((3, 4), dtype=int)}
>>> s.import_walker(data, index=0)
>>> s.get("x")[0, 0, :3]
array([1, 1, 1])
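The export_walker examples show the exported value keeping a leading walker axis of length 1 (shape (1, 3, 4)). The following is a minimal NumPy sketch of an export/import round trip that reproduces this with slicing. export_row and import_row are hypothetical helpers that handle only tensor attributes, and returning views when copy=False is an assumption of this sketch.

```python
from typing import Dict, Iterable, Optional

import numpy as np


def export_row(data: Dict[str, np.ndarray], index: int,
               names: Optional[Iterable[str]] = None, copy: bool = False) -> Dict[str, np.ndarray]:
    """Hypothetical: return one walker's data, keeping a walker axis of length 1."""
    names = list(data) if names is None else list(names)
    # Slicing index:index + 1 (rather than indexing) keeps the leading axis.
    out = {name: data[name][index:index + 1] for name in names}
    # Basic slices are views; copy them if the caller must not alias the originals.
    return {k: v.copy() for k, v in out.items()} if copy else out


def import_row(data: Dict[str, np.ndarray], walker: Dict[str, np.ndarray], index: int = 0) -> None:
    """Hypothetical: write one walker's data into position index, in place."""
    for name, value in walker.items():
        # The target slice has shape (1, ...), so values with or without
        # the leading walker axis both broadcast into it.
        data[name][index:index + 1] = value


data = {"x": np.zeros((10, 3, 4), dtype=int)}
import_row(data, {"x": np.ones((3, 4), dtype=int)}, index=0)
walker = export_row(data, 0, names=["x"])
print(walker["x"].shape, walker["x"][0, 0, :3])  # (1, 3, 4) [1 1 1]
```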

reset(root_walker=None)

Completely reset both the current and the historical data held in the state.

Optionally, a root walker can be provided whose values are used to reset the attributes.

Parameters

root_walker (Optional[Dict[str, Tensor]], optional) – The initial state when resetting.

Return type

None

Examples

>>> from fragile.core.state import SwarmState
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> walker_dict = s.export_walker(0, names=["x"])
>>> print(walker_dict["x"].shape)  
(1, 3, 4)
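When reset receives a root_walker, every walker can start from the same values. The following is a minimal NumPy sketch of that broadcast. It assumes root_walker stores each attribute with a leading axis of length 1 (the format export_walker returns) and that a reset without a root zeroes the arrays. reset_arrays is hypothetical.

```python
from typing import Dict, Optional

import numpy as np


def reset_arrays(data: Dict[str, np.ndarray],
                 root_walker: Optional[Dict[str, np.ndarray]] = None) -> None:
    """Hypothetical: zero every array, then optionally copy the root walker into all rows."""
    for name, values in data.items():
        values[...] = 0
        if root_walker is not None and name in root_walker:
            # A (1, ...) root value broadcasts across the walker axis.
            values[...] = root_walker[name]


data = {"x": np.arange(24).reshape(2, 3, 4)}
root = {"x": np.full((1, 3, 4), 7)}
reset_arrays(data, root_walker=root)
print(data["x"][:, 0, 0])  # [7 7]
```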

get(name, default=None, raise_error=True, inactives=False)

Get an attribute by key and return the default value if it does not exist.

Parameters
  • name (str) – Attribute to be recovered.

  • default – Value returned in case the attribute is not part of state.

  • raise_error (bool) – If True, raise an AttributeError if name is not present in the state.

  • inactives (bool) – Whether to also return the values of the walkers marked as inactive.

Returns

Target attribute if found in the instance, otherwise returns the default value.

Examples

>>> from fragile.core.state import SwarmState
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> print(s.get("x").shape)
(10, 3, 4)
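With inactives=False, reads can be restricted to the walkers currently marked as active. The following is a minimal sketch of that masking. It assumes the active set is kept as a boolean mask with one entry per walker. get_values is hypothetical.

```python
from typing import List, Union

import numpy as np


def get_values(values: Union[np.ndarray, List], actives: np.ndarray,
               inactives: bool = False) -> Union[np.ndarray, List]:
    """Hypothetical: return all rows, or only the rows of active walkers."""
    if inactives:
        return values
    if isinstance(values, np.ndarray):
        return values[actives]  # boolean-mask indexing returns a copy
    # Python lists do not support boolean-mask indexing, so filter explicitly.
    return [v for v, active in zip(values, actives) if active]


x = np.arange(5) * 10
actives = np.array([True, False, True, False, True])
print(get_values(x, actives))                          # [ 0 20 40]
print(get_values(["a", "b", "c", "d", "e"], actives))  # ['a', 'c', 'e']
```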

_update_active_list(current_vals, new_vals)

Update the list of active walkers.

update(other=None, inactives=False, **kwargs)

Modify the data stored in the SwarmState instance.

Existing attributes will be updated, and new attributes will be created if needed.

Parameters
  • other (Union[SwarmState, Dict[str, fragile.core.typing.Tensor]]) – State instance or dictionary whose values will be copied upon update.

  • inactives (bool) – Whether to update the walkers marked as inactive.

  • **kwargs – It is possible to specify the update as name value attributes, where name is the name of the attribute to be updated, and value is the new value for the attribute.

Returns

None.

Return type

None

Examples

>>> import judo
>>> from fragile.core.state import SwarmState
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> s.update(x=judo.ones((10, 3, 4), dtype=int))
>>> print(s.get("x")[0,0,0])
1

update_actives(actives)

Set the walkers marked as active.

Example

To set the first 10 walkers as active, call the function with a tensor of size equal to the number of walkers, where the first ten elements are True and the remaining elements are False:

>>> import numpy as np
>>> from fragile.core.state import SwarmState
>>> param_dict = {"vector": {"dtype": int}}
>>> s = SwarmState(n_walkers=20, param_dict=param_dict)
>>> active_walkers = np.concatenate([np.ones(10), np.zeros(10)]).astype(bool)
>>> s.update_actives(active_walkers)

This will mark those walkers as active, and any attribute updated with inactives=False (this is the default) will only modify the data from those walkers.

Return type

None
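The effect of update_actives on later updates can be sketched as a masked write: with inactives=False, only the rows of active walkers receive new values. The following is a minimal NumPy sketch. It assumes the new value has one row per walker and that the active set is a boolean mask. masked_update is hypothetical.

```python
import numpy as np


def masked_update(current: np.ndarray, new: np.ndarray, actives: np.ndarray,
                  inactives: bool = False) -> None:
    """Hypothetical: write new into current in place, honouring the active mask."""
    if inactives:
        current[...] = new
    else:
        current[actives] = new[actives]


n_walkers = 20
vector = np.zeros(n_walkers, dtype=int)
# First 10 walkers active, remaining 10 inactive.
actives = np.concatenate([np.ones(10), np.zeros(10)]).astype(bool)
masked_update(vector, np.full(n_walkers, 5), actives)
print(vector.sum())  # 50: only the 10 active walkers changed
```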