fragile.core.state
Module Contents
Classes
- State: Data structure that handles the data defining a population of walkers.
- SwarmState: A dictionary-style container for storing the current state of the swarm. It allows you to update the status and metadata of the walkers in the swarm.
- class fragile.core.state.State(n_walkers, param_dict)
Data structure that handles the data defining a population of walkers.
Each population attribute will be stored as a tensor with its first dimension (batch size) representing each walker.
To define the tensor attributes, pass a param_dict dictionary with the following structure:
Example
>>> attr_dict = {'name': {'shape': Optional[tuple|None], 'dtype': dtype},
...              'biases': {'shape': (10,), 'dtype': float},
...              'vector': {'dtype': 'float32'},
...              'sequence': {'shape': None, 'dtype': 'float32'}
...              }
Here, shape is a tuple giving the shape of each walker's entry; the created tensor has an extra leading dimension of size n_walkers. The created arrays are accessible through the name attribute of the instance, or by indexing it with state["name"].
If the 'shape' key is omitted (as for 'vector'), the attribute is a vector of length n_walkers. If 'shape' is None (as for 'sequence'), the attribute is stored as a list with one entry per walker.
- Parameters
n_walkers (int) – The number of items in the first dimension of the tensors.
param_dict (StateDict) – Dictionary defining the attributes of the tensors.
- param_dict
A dictionary containing the shape and type of each walker’s attribute.
- Type
StateDict
- Return type
fragile.core.typing.StateDict
- names
The names of the walker’s attributes tracked by this instance.
- tensor_names
The names of the walker’s attributes that correspond to tensors.
- list_names
The names of the walker’s attributes that correspond to lists of objects.
- vector_names
The names of the walker’s attributes that correspond to vectors of scalars.
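The four name groups above sort the attributes by how they are stored. The following is a minimal sketch, assuming a simplified rule that this page does not state explicitly: an entry whose 'shape' is None is a list attribute, an entry without a 'shape' key is a vector, and every other entry is a tensor. The helper classify_names is hypothetical, and the groups it returns are disjoint, which may not match the real properties.

```python
from typing import Any, Dict, Set, Tuple


def classify_names(param_dict: Dict[str, Dict[str, Any]]) -> Tuple[Set[str], Set[str], Set[str]]:
    """Hypothetical helper: split attribute names into (tensor, list, vector) groups."""
    tensor_names: Set[str] = set()
    list_names: Set[str] = set()
    vector_names: Set[str] = set()
    for name, spec in param_dict.items():
        if "shape" in spec and spec["shape"] is None:
            # Assumed rule: an explicit None shape means one Python object per walker.
            list_names.add(name)
        elif "shape" not in spec:
            # Assumed rule: no shape means one scalar per walker.
            vector_names.add(name)
        else:
            # Any explicit shape tuple means a per-walker tensor.
            tensor_names.add(name)
    return tensor_names, list_names, vector_names
```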
- property param_dict
Return a dictionary containing the shape and type of each walker’s attribute.
- Return type
fragile.core.typing.StateDict
- property names
Return the names of the walker’s attributes tracked by this instance.
- Return type
Tuple[str]
- property tensor_names
Return the names of the walker’s attributes that correspond to tensors.
- Return type
Set[str]
- property list_names
Return the names of the walker’s attributes that correspond to lists of objects.
- Return type
Set[str]
- property vector_names
Return the names of the walker’s attributes that correspond to vectors of scalars.
- Return type
Set[str]
- __setitem__(key, value)
Allow the class to set its attributes as if it were a dict.
- Parameters
key – Attribute to be set.
value (Union[Tuple, List, fragile.core.typing.Tensor]) – Value of the target attribute.
- Returns
None.
- __getitem__(item)
Query an attribute of the class as if it were a dictionary.
- Parameters
item (str) – Name of the attribute to be selected.
- Returns
The corresponding item.
- Return type
Union[fragile.core.typing.Tensor, List[fragile.core.typing.Tensor], SwarmState]
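The dict-style access provided by __setitem__ and __getitem__ can be illustrated with a minimal sketch that forwards item access to attribute access. DictStyleAttrs is a hypothetical class for illustration, not part of fragile; a missing name raises AttributeError here, which the real State may handle differently.

```python
from typing import Any


class DictStyleAttrs:
    """Hypothetical container mirroring the dict-style access described above."""

    def __init__(self, **attrs: Any) -> None:
        for key, value in attrs.items():
            setattr(self, key, value)

    def __setitem__(self, key: str, value: Any) -> None:
        # obj["x"] = value is equivalent to obj.x = value
        setattr(self, key, value)

    def __getitem__(self, item: str) -> Any:
        # obj["x"] returns the same object as obj.x
        return getattr(self, item)
```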
- __repr__()
Return a string that provides a readable representation of this instance's attributes.
- Return type
str
- get(name, default=None, raise_error=True)
Get an attribute by key and return the default value if it does not exist.
- items()
Return a generator for the attribute names and the values of the stored data.
- Return type
Generator
- itervals()
Iterate over the state's attributes walker by walker.
- Returns
Tuple containing all the names of the attributes, and the values that correspond to a given walker.
- iteritems()
Iterate over the state's attributes walker by walker.
- Returns
Tuple containing all the names of the attributes, and the values that correspond to a given walker.
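Iterating by walker means taking the i-th entry of every attribute at the same time. A minimal sketch over a plain dictionary of per-walker sequences, where iter_walkers is a hypothetical helper that yields the attribute names together with one walker's values, in the spirit of iteritems:

```python
from typing import Any, Dict, Iterator, Sequence, Tuple


def iter_walkers(
    data: Dict[str, Sequence[Any]], n_walkers: int
) -> Iterator[Tuple[Tuple[str, ...], Tuple[Any, ...]]]:
    """Hypothetical per-walker iterator over a dict of per-walker sequences."""
    names = tuple(data)
    for i in range(n_walkers):
        # Take the i-th entry of every attribute: a row of a tensor or an item of a list.
        yield names, tuple(data[name][i] for name in names)
```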
- update(other=None, **kwargs)
Modify the data stored in this instance.
Existing attributes will be updated, and no new attributes can be created.
- Parameters
other (Union[SwarmState, Dict[str, Tensor]], optional) – Other SwarmState instance to copy upon update. Defaults to None.
**kwargs – Extra key-value pairs of attributes to add to or modify the current state.
- Return type
None
Example
>>> s = State(2, {'name': {'shape': (3, 4), "dtype": bool}})
>>> s.update({'name': np.ones((3, 4))})
>>> len(s.names)
1
>>> s['name']
array([[ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])
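The rule that update modifies existing attributes without creating new ones can be sketched over a plain dictionary. update_existing is a hypothetical helper; raising KeyError for unknown names is a choice made for this sketch, since the documentation above does not say whether the real method raises or ignores them.

```python
from typing import Any, Dict, Optional


def update_existing(store: Dict[str, Any], other: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
    """Hypothetical helper: overwrite existing keys only, rejecting unknown names."""
    new_values = dict(other or {})
    new_values.update(kwargs)  # keyword arguments take precedence over `other`
    unknown = set(new_values) - set(store)
    if unknown:
        # Validate before mutating, so a rejected update leaves the store unchanged.
        raise KeyError(f"Unknown attributes: {sorted(unknown)}")
    store.update(new_values)
```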
- params_to_arrays(param_dict, n_walkers)
Create a dictionary containing arrays specified by param_dict, the attribute dictionary.
The attribute dictionary defines the attributes of the tensors, and n_walkers is the number of items in the first dimension of the data tensors. The method returns a dictionary with the same keys as param_dict, containing arrays specified by the values in the attribute dictionary. It iterates through each item in the attribute dictionary, creates a copy of the value, and checks whether the shape of the value is specified.
If the shape is specified, the method initializes the tensor to zeros with shape (n_walkers, *shape) and the dtype given in the attribute dictionary; if the 'shape' key is missing, the result is a vector of n_walkers zeros. If the shape is None or the key is in self.list_names, the method initializes the value as a list of None with n_walkers items. If the key is 'label', the method initializes the value as a list of empty strings with n_walkers items.
- Parameters
param_dict (fragile.core.typing.StateDict) – Dictionary defining the attributes of the tensors.
n_walkers (int) – Number of items in the first dimension of the data tensors.
- Returns
Dictionary with the same names as the attribute dictionary, containing arrays specified by the values in the attribute dictionary.
- Return type
Dict[str, fragile.core.typing.Tensor]
Example
>>> attr_dict = {'weights': {'shape': (10, 5), 'dtype': 'float32'},
...              'biases': {'shape': (10,), 'dtype': 'float32'},
...              'label': {'shape': None, 'dtype': 'str'},
...              'vector': {'dtype': 'float32'}}
>>> n_walkers = 3
>>> state = State(param_dict=attr_dict, n_walkers=n_walkers)
>>> tensor_dict = state.params_to_arrays(attr_dict, n_walkers)
>>> tensor_dict.keys()
dict_keys(['weights', 'biases', 'label', 'vector'])
>>> tensor_dict['weights'].shape
(3, 10, 5)
>>> tensor_dict['biases'].shape
(3, 10)
>>> tensor_dict['label']
[None, None, None]
>>> tensor_dict['vector'].shape
(3,)
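The allocation rules described for params_to_arrays can be sketched with NumPy. params_to_arrays_sketch is hypothetical: it uses numpy.zeros in place of whatever tensor backend fragile uses, assumes each spec holds only 'shape' plus keyword arguments accepted by numpy.zeros (here just 'dtype'), and the precedence between the rules is chosen so that the sketch reproduces the example above.

```python
import copy
from typing import Any, Dict, Iterable

import numpy as np


def params_to_arrays_sketch(
    param_dict: Dict[str, Dict[str, Any]],
    n_walkers: int,
    list_names: Iterable[str] = (),
) -> Dict[str, Any]:
    """Hypothetical allocator following the rules described for params_to_arrays."""
    list_names = set(list_names)
    arrays: Dict[str, Any] = {}
    for key, spec in param_dict.items():
        spec = copy.deepcopy(spec)     # never mutate the caller's dictionary
        shape = spec.pop("shape", ())  # missing "shape" -> one scalar per walker
        if shape is None or key in list_names:
            arrays[key] = [None] * n_walkers
        elif key == "label":
            arrays[key] = [""] * n_walkers
        else:
            # Prepend the walker dimension to the per-walker shape.
            arrays[key] = np.zeros((n_walkers, *shape), **spec)
    return arrays
```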
- _parse_value(name, value)
Ensure that the input value has correct dimensions and shape for a given attribute.
- Parameters
name (str) – Name of the attribute.
value (Any) – New value to set to the attribute.
- Returns
Parsed and validated value of the new state element.
- Return type
Any
Example
>>> s = State(2, {'name': {'shape': (3, 4), "dtype": int}})
>>> parsed_val = s._parse_value('name', [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
>>> parsed_val
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12]])
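A minimal sketch of the kind of validation _parse_value performs, assuming NumPy arrays and assuming that "correct shape" means the trailing dimensions match the declared per-walker shape (which accepts both a single walker's (3, 4) value, as in the example above, and a full (n_walkers, 3, 4) batch). parse_value_sketch is hypothetical and takes the attribute spec directly instead of the attribute name.

```python
from typing import Any, Dict

import numpy as np


def parse_value_sketch(spec: Dict[str, Any], value: Any) -> Any:
    """Hypothetical validator: coerce to the declared dtype and check trailing dims."""
    shape = spec.get("shape", ())
    if shape is None:
        # List attributes hold arbitrary objects; keep them as a plain list.
        return list(value)
    array = np.asarray(value, dtype=spec.get("dtype"))
    n_dims = len(shape)
    # Skip the check for vectors (empty shape), since shape[-0:] would be the full shape.
    if n_dims and tuple(array.shape[-n_dims:]) != tuple(shape):
        raise ValueError(f"expected trailing shape {shape}, got {array.shape}")
    return array
```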
- class fragile.core.state.SwarmState(n_walkers, param_dict)
Bases: State
A dictionary-style container for storing the current state of the swarm. It allows you to update the status and metadata of the walkers in the swarm.
The keys of the instance must be its attributes. An attribute value can be of any type (a tensor, a list, or any Python object). Lists and tensors can have a different length than n_walkers if necessary, but tensors should have the same number of rows as there are walkers (whether active or not).
- Parameters
n_walkers (int) – Number of walkers.
param_dict (StateDict) – Dictionary defining the attributes of the tensors.
Example
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=100, param_dict=param_dict)
- property clone_names
Return the name of the attributes that will be copied when a walker clones.
- Return type
Set[str]
- property actives
Get the indices of the active walkers.
- Return type
fragile.core.typing.Tensor
- clone(will_clone, compas_clone, clone_names)
Clone all the stored data according to the provided index.
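In cloning, each walker selected to clone overwrites its data with a copy of its companion's data. A minimal sketch, assuming that will_clone is a boolean mask with one entry per walker and compas_clone is an integer array holding each walker's companion index; clone_sketch is hypothetical and operates on a plain dictionary of NumPy arrays and lists.

```python
from typing import Any, Dict, Iterable

import numpy as np


def clone_sketch(data: Dict[str, Any], will_clone: np.ndarray,
                 compas_clone: np.ndarray, clone_names: Iterable[str]) -> None:
    """Hypothetical clone: walkers flagged in will_clone copy their companion's data."""
    for name in clone_names:
        values = data[name]
        if isinstance(values, np.ndarray):
            # The right-hand side uses advanced indexing, which returns a copy, so
            # every clone reads its companion's value from before this update.
            values[will_clone] = values[compas_clone[will_clone]]
        else:
            # Lists: build a new list from the old one for the same reason.
            data[name] = [values[c] if w else v
                          for v, w, c in zip(values, will_clone, compas_clone)]
```

For example, with will_clone = [True, False, True, False] and compas_clone = [1, 0, 3, 2], walker 0 takes walker 1's data and walker 2 takes walker 3's data, while walkers 1 and 3 are left untouched.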
- export_walker(index, names=None, copy=False)
Export the data of the walker at position index as a dictionary.
- Parameters
index (int) – The index of the target walker.
names (Optional[List[str]], optional) – The list of attribute names to be included in the output. If None, all attributes will be included. Defaults to None.
copy (bool, optional) – If True, the returned dictionary will be a copy of the original data. Defaults to False.
- Returns
A dictionary containing the requested attributes and their corresponding values for the specified walker.
- Return type
Dict[str, Union[Tensor, numpy.ndarray]]
Examples
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> walker_dict = s.export_walker(0, names=["x"])
>>> print(walker_dict)
{'x': array([[[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]])}
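The leading dimension of size 1 in the exported array suggests slicing rather than indexing. A minimal sketch over a plain dictionary, where export_walker_sketch is hypothetical; it also shows why copy=False can return NumPy views that alias the stored data, which is an assumption about the real method.

```python
import copy as copy_module
from typing import Any, Dict, List, Optional

import numpy as np


def export_walker_sketch(data: Dict[str, Any], index: int,
                         names: Optional[List[str]] = None, copy: bool = False) -> Dict[str, Any]:
    """Hypothetical export: return one walker's entries, keeping the batch dimension."""
    names = list(data) if names is None else names
    exported = {}
    for name in names:
        # Slicing index:index + 1 (instead of indexing with index) keeps a leading
        # dimension of size 1, so a (10, 3, 4) tensor exports as (1, 3, 4).
        item = data[name][index:index + 1]
        # NumPy slices are views; deepcopy detaches them from the stored data.
        exported[name] = copy_module.deepcopy(item) if copy else item
    return exported


state = {"x": np.zeros((10, 3, 4), dtype=int)}
walker = export_walker_sketch(state, 0, names=["x"])
print(walker["x"].shape)  # (1, 3, 4)
```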
- import_walker(data, index=0)
Import a data dictionary into the state at walker position index.
- Parameters
data (Dict) – Dictionary containing the data to be imported.
index (int, optional) – Walker index to receive the data. Defaults to 0.
- Return type
None
Examples
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> data = {"x": judo.ones((3, 4), dtype=int)}
>>> s.import_walker(data, index=0)
>>> s.get("x")[0, 0, :3]
array([1, 1, 1])
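Importing is the inverse write into a single walker row. A minimal sketch, assuming every attribute is a NumPy array whose first dimension indexes walkers; import_walker_sketch is hypothetical and reshapes each value to the per-walker shape, so it accepts both a bare (3, 4) value, as in the example above, and a (1, 3, 4) slice like the one export_walker returns.

```python
from typing import Any, Dict

import numpy as np


def import_walker_sketch(data: Dict[str, np.ndarray], walker: Dict[str, Any], index: int = 0) -> None:
    """Hypothetical import: write one walker's values into row `index` of each tensor."""
    for name, value in walker.items():
        target = data[name]
        # Reshape to the per-walker shape, so a bare (3, 4) value and a
        # (1, 3, 4) slice are both accepted.
        target[index] = np.reshape(value, target.shape[1:])
```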
- reset(root_walker=None)
Completely reset both the current and the history data held in the state.
Optionally, a root walker can be passed whose values are used to reset individual attributes.
- Parameters
root_walker (Optional[Dict[str, Tensor]], optional) – The initial state when resetting.
- Return type
None
Examples
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> walker_dict = s.export_walker(0, names=["x"])
>>> print(walker_dict["x"].shape)
(1, 3, 4)
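A minimal sketch of a reset, assuming that tensors are zeroed, lists are refilled with None, and the root walker's values are broadcast to every walker. reset_sketch is hypothetical; the broadcasting behavior is an assumption, since the documentation above only says that root_walker is the initial state when resetting, and the sketch also assumes root_walker only contains array attributes.

```python
from typing import Any, Dict, Optional

import numpy as np


def reset_sketch(data: Dict[str, Any], root_walker: Optional[Dict[str, Any]] = None) -> None:
    """Hypothetical reset: clear every attribute, then optionally copy a root walker."""
    for name, values in list(data.items()):
        if isinstance(values, np.ndarray):
            values[...] = 0  # zero the tensor in place, keeping its shape and dtype
        else:
            data[name] = [None] * len(values)  # one empty slot per walker
    if root_walker:
        for name, value in root_walker.items():
            target = data[name]
            # Assumption: the root walker is broadcast to every walker row.
            target[...] = np.reshape(value, target.shape[1:])
```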
- get(name, default=None, raise_error=True, inactives=False)
Get an attribute by key and return the default value if it does not exist.
- Parameters
- Returns
Target attribute if found in the instance, otherwise returns the default value.
Examples
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> print(s.get("x").shape)
(10, 3, 4)
- update(other=None, inactives=False, **kwargs)
Modify the data stored in the SwarmState instance.
Existing attributes will be updated, and new attributes will be created if needed.
- Parameters
other (Union[SwarmState, Dict[str, fragile.core.typing.Tensor]]) – State class that will be copied upon update.
inactives (bool) – Whether to update the walkers marked as inactive.
**kwargs – It is possible to specify the update as name value attributes, where name is the name of the attribute to be updated, and value is the new value for the attribute.
- Returns
None.
- Return type
None
Examples
>>> param_dict = {"x": {"shape": (3, 4), "dtype": int}}
>>> s = SwarmState(n_walkers=10, param_dict=param_dict)
>>> s.reset()
>>> s.update(x=judo.ones((10, 3, 4), dtype=int))
>>> print(s.get("x")[0, 0, 0])
1
- update_actives(actives)
Set the walkers marked as active.
Example
To set the first 10 walkers as active, call the function with a tensor of size equal to the number of walkers, where the first ten elements are True and the remaining elements are False:
>>> param_dict = {"vector": {"dtype": int}}
>>> s = SwarmState(n_walkers=20, param_dict=param_dict)
>>> active_walkers = np.concatenate([np.ones(10), np.zeros(10)]).astype(bool)
>>> s.update_actives(active_walkers)
This will mark those walkers as active, and any attribute updated with inactives=False (this is the default) will only modify the data from those walkers.
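The effect of the active mask on an update with inactives=False can be sketched with NumPy boolean indexing. masked_update is a hypothetical helper that receives the mask explicitly instead of reading it from the state, and it assumes every updated attribute is a NumPy array with one row per walker.

```python
from typing import Any, Dict

import numpy as np


def masked_update(data: Dict[str, np.ndarray], actives: np.ndarray,
                  inactives: bool = False, **kwargs: Any) -> None:
    """Hypothetical update that respects a boolean mask of active walkers."""
    for name, value in kwargs.items():
        value = np.asarray(value)
        if inactives:
            # Overwrite every walker, active or not.
            data[name] = value.copy()
        else:
            # Only the rows of active walkers change; inactive rows keep their data.
            data[name][actives] = value[actives]
```

With the mask from the example above (first ten walkers active), updating a 20-element vector with np.arange(20) changes only the first ten entries.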
- Return type
None