Source code for fragile.core.api_classes

import copy
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

import judo
from judo import dtype, random_state

from fragile.core.state import SwarmState
from fragile.core.typing import InputDict, StateData, StateDict, Tensor


class SwarmComponent:
    """
    A component of a swarm simulation.

    Every class that stores its data in :class:`SwarmState` inherits from this class.

    Args:
        swarm: Reference to the :class:`Swarm` that incorporates this
            :class:`SwarmComponent`. Defaults to None.
        param_dict: Describes the multi-dimensional arrays that define the internal
            state of the component. Defaults to None.
        inputs: Describes the data that this component processes as inputs. This data
            is present in the internal state of the :class:`Swarm`, and generated by
            other components. Defaults to None.
        outputs: Names of the data attributes that this component makes available as
            part of its internal state. Defaults to None.

    Attributes:
        swarm (Swarm): The parent :class:`Swarm` instance.
        n_walkers (int): Number of walkers in the :class:`Swarm`.
        inputs (Dict[str, Dict[str, Any]]): A dictionary mapping input names to their
            corresponding type and default values.
        outputs (Tuple[str]): A tuple containing the names of the data attributes.
        param_dict (Dict[str, Dict[str, Any]]): A dictionary mapping parameter names.

    Methods:
        get(...)
        get_input_data(...)
        update(...)
        reset(...)
        _prepare_tensors(...)
    """

    # Class-level defaults; the values passed to ``__init__`` are merged on top of them.
    default_inputs = {}
    default_outputs = tuple()
    default_param_dict = {}

    def __init__(
        self,
        swarm: Optional["SwarmAPI"] = None,
        param_dict: Optional[StateDict] = None,
        inputs: Optional[InputDict] = None,
        outputs: Optional[Tuple[str]] = None,
    ):
        """
        Initialize a :class:`SwarmComponent`.

        Args:
            swarm (Optional["SwarmAPI"], optional): Reference to the :class:`Swarm` that
                incorporates this :class:`SwarmComponent`. Defaults to None.
            param_dict (Optional[StateDict], optional): Describes the multi-dimensional
                arrays that define the internal state of the component. Entries override
                ``default_param_dict``. Defaults to None.
            inputs (Optional[InputDict], optional): Describes the data that this
                component processes as inputs. Entries override ``default_inputs``.
                Defaults to None.
            outputs (Optional[Tuple[str]], optional): Names of the data attributes that
                this component makes available as part of its internal state. They are
                appended to ``default_outputs``. Defaults to None.
        """
        param_dict = {**self.default_param_dict, **(param_dict or {})}
        inputs = {**self.default_inputs, **(inputs or {})}
        outputs = tuple(outputs) if outputs is not None else tuple()
        # De-duplicate while keeping first-seen order. A ``set`` would make the order of
        # the output names depend on string hashing, which changes between runs.
        outputs = tuple(dict.fromkeys(tuple(self.default_outputs) + outputs))
        self._swarm = None
        self._param_dict = param_dict
        self._inputs = inputs
        self._outputs = outputs
        if swarm is not None:
            # Going through setup() lets child classes run their own side effects.
            self.setup(swarm)

    @property
    def swarm(self) -> "SwarmAPI":
        """Return a reference to the :class:`Swarm` that includes the current component."""
        return self._swarm

    @property
    def n_walkers(self) -> int:
        """Return the number of walkers of the :class:`Swarm`."""
        return self.swarm.n_walkers

    @property
    def inputs(self) -> InputDict:
        """Return a (shallow) copy of the dictionary describing the required inputs."""
        return dict(self._inputs)

    @property
    def outputs(self) -> Tuple[str, ...]:
        """Return a tuple containing the names of the data attributes the component outputs."""
        return tuple(self._outputs)

    @property
    def param_dict(self) -> StateDict:
        """Return a (shallow) copy of the dictionary defining the component's data attributes."""
        return dict(self._param_dict)

    def setup(self, swarm):
        """Prepare the component during the setup phase of the :class:`Swarm`."""
        self._swarm = swarm

    def get(
        self,
        name: str,
        default: Any = None,
        raise_error: bool = False,
        inactives: bool = False,
    ) -> Any:
        """
        Access attributes of the :class:`Swarm` and its children.

        The call is forwarded unchanged to ``self.swarm.state.get``.

        Args:
            name: Name of the attribute to fetch.
            default: Value forwarded as the ``default`` of the state lookup.
            raise_error: Forwarded to the state lookup.
            inactives: Forwarded to the state lookup.

        Returns:
            Whatever ``self.swarm.state.get`` returns.
        """
        return self.swarm.state.get(
            name=name,
            default=default,
            raise_error=raise_error,
            inactives=inactives,
        )

    def get_input_data(self) -> StateData:
        """
        Return a dictionary with all the data that this component requires as inputs.

        Inputs whose description does not set ``"optional"`` to a truthy value are
        requested with ``raise_error=True``.

        Returns:
            Dictionary containing all the data required by the :class:`SwarmComponent`,
            where each key is the name of a data attribute and its value a
            multi-dimensional array.
        """

        def get_one_input(name, values):
            raise_error = not values.get("optional", False)
            return self.get(name=name, default=values.get("default"), raise_error=raise_error)

        return {k: get_one_input(k, v) for k, v in self.inputs.items()}

    def update(
        self,
        other: Optional[Union["SwarmState", Dict[str, Tensor]]] = None,
        inactives: bool = False,
        **kwargs,
    ) -> None:
        """
        Modify the data stored in the SwarmState instance.

        Existing attributes will be updated, and new attributes will be created if needed.

        Args:
            other: State class that will be copied upon update.
            inactives: Whether to update the walkers marked as inactive.
            **kwargs: It is possible to specify the update as name value attributes,
                where name is the name of the attribute to be updated, and value
                is the new value for the attribute.

        Returns:
            The value returned by ``self.swarm.state.update`` (annotated as None).
        """
        return self.swarm.state.update(other=other, inactives=inactives, **kwargs)

    def _prepare_tensors(self, **kwargs):
        """Return ``kwargs`` when any were given, otherwise the data from ``get_input_data``."""
        if kwargs:
            return kwargs
        return self.get_input_data()

    def reset(
        self,
        inplace: bool = True,
        root_walker: Optional[StateData] = None,
        states: Optional[StateData] = None,
        **kwargs,
    ):
        """
        Reset the internal state of the :class:`SwarmComponent`.

        The base implementation does nothing.

        Args:
            inplace (bool, optional): Unused. Defaults to True.
            root_walker (Optional[StateData], optional): Set the internal state of the
                :class:`SwarmComponent` to this value. Defaults to None.
            states (Optional[StateData], optional): Set the internal state of the
                :class:`SwarmComponent` to this value. Defaults to None.
            kwargs: Other parameters required to reset the component.
        """
        pass
class EnvironmentAPI(SwarmComponent):
    """
    The Environment is in charge of stepping the walkers, acting as a state \
    transition function.

    For every different problem, a new Environment needs to be implemented
    following the :class:`EnvironmentAPI` interface.
    """

    default_inputs = {"actions": {}}
    default_outputs = ("observs", "rewards", "oobs")

    def __init__(
        self,
        action_shape: tuple,
        action_dtype,
        observs_shape: tuple,
        observs_dtype,
        swarm: "SwarmAPI" = None,
    ):
        """
        Initialize the :class:`EnvironmentAPI`.

        Args:
            action_shape (tuple): Value stored as the ``"shape"`` of the ``"actions"``
                entry of :attr:`param_dict`.
            action_dtype: Value stored as the ``"dtype"`` of the ``"actions"`` entry.
            observs_shape (tuple): Value stored as the ``"shape"`` of the ``"observs"``
                entry of :attr:`param_dict`.
            observs_dtype: Value stored as the ``"dtype"`` of the ``"observs"`` entry.
            swarm (SwarmAPI, optional): Swarm that incorporates this environment.
                Defaults to None.
        """
        # These must be assigned before calling the parent constructor because the
        # ``param_dict`` property reads them.
        self._action_shape = action_shape
        self._action_dtype = action_dtype
        self._observs_shape = observs_shape
        self._observs_dtype = observs_dtype
        super().__init__(swarm=swarm, param_dict=self.param_dict)

    @property
    def observs_shape(self) -> tuple:
        """Return the shape of the observations."""
        # Fixed: this used to return the observation dtype instead of the shape.
        return self._observs_shape

    @property
    def observs_dtype(self):
        """Return the dtype of the observations."""
        return self._observs_dtype

    @property
    def action_shape(self) -> tuple:
        """Return the shape of the actions."""
        return self._action_shape

    @property
    def action_dtype(self):
        """Return the dtype of the actions."""
        return self._action_dtype

    @property
    def param_dict(self) -> StateDict:
        """Return the dictionary defining all the data attributes that the component requires."""
        return {
            "observs": {"shape": self._observs_shape, "dtype": self._observs_dtype},
            "rewards": {"dtype": dtype.float32},
            "oobs": {"dtype": dtype.bool},
            "actions": {"shape": self._action_shape, "dtype": self._action_dtype},
        }

    def step(self, **kwargs) -> StateData:
        """
        Return the data corresponding to the new state of the environment after \
        using the input data to make the corresponding state transition.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def make_transitions(
        self,
        inplace: bool = True,
        inactives: bool = False,
        **kwargs,
    ) -> Union[None, StateData]:
        """
        Return the data corresponding to the new state of the environment after \
        using the input data to make the corresponding state transition.

        Args:
            inplace: If ``True``, also write the new data to the state of the Swarm.
            inactives: Whether to update the walkers marked as inactive.
            **kwargs: Input data passed to :meth:`step`. When empty, the inputs are
                read from the swarm state instead.

        Returns:
            Dictionary containing the data representing the state of the environment
            after the state transition. The keys of the dictionary are the names of
            the data attributes and its values are arrays representing a batch of new
            values for that attribute. It is returned regardless of ``inplace``.
        """
        input_data = self._prepare_tensors(**kwargs)
        out_data = self.step(**input_data)
        if inplace:
            self.update(**out_data, inactives=inactives)
        return out_data

    def reset(
        self,
        inplace: bool = True,
        root_walker: Optional[StateData] = None,
        states: Optional[StateData] = None,
        inactives: bool = True,
        **kwargs,
    ) -> Union[None, StateData]:
        """
        Reset the environment by running :meth:`make_transitions`.

        Args:
            inplace (bool, optional): If ``True``, update the state of the Swarm with
                the new data. Defaults to ``True``.
            root_walker (Optional[StateData], optional): Not used by this
                implementation. Defaults to None.
            states (Optional[StateData], optional): Not used by this implementation.
                Defaults to None.
            inactives (bool, optional): Whether to update the walkers marked as
                inactive. Defaults to ``True``.
            kwargs: Other arguments passed to make_transitions.

        Returns:
            Union[None, StateData]: The dictionary returned by :meth:`make_transitions`.
        """
        return self.make_transitions(inplace=inplace, inactives=inactives, **kwargs)
class PolicyAPI(SwarmComponent):
    """
    The policy is in charge of calculating the interactions with the :class:`Environment`.

    It decides which action every walker takes. This is an abstract base class:
    concrete policies inherit from it and implement :meth:`select_actions`.
    """

    default_outputs = ("actions",)

    def select_actions(self, **kwargs) -> Union[Tensor, StateData]:
        """
        Select actions for each walker in the swarm based on the current state.

        Subclasses must override this method.

        Args:
            **kwargs: Additional keyword arguments required for selecting actions.

        Returns:
            Union[Tensor, StateData]: The selected actions, either as a Tensor or as a
            StateData dictionary.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def act(self, inplace: bool = True, **kwargs) -> Union[None, StateData]:
        """
        Compute the actions and either store them in the swarm state or return them.

        Args:
            inplace (bool, optional): If True, write the selected actions to the swarm
                state. If False, return them instead. Defaults to True.
            **kwargs: Input data for :meth:`select_actions`. When empty, the inputs are
                read from the swarm state.

        Returns:
            Union[None, StateData]: None if inplace is True. Otherwise, a StateData
            dictionary containing the selected actions.
        """
        selected = self.select_actions(**self._prepare_tensors(**kwargs))
        # A bare tensor is wrapped so the result is always a dictionary.
        actions_data = selected if isinstance(selected, dict) else {"actions": selected}
        if not inplace:
            return actions_data
        self.update(**actions_data)
        return None

    def reset(
        self,
        inplace: bool = True,
        root_walker: Optional[StateData] = None,
        states: Optional[StateData] = None,
        **kwargs,
    ) -> Union[None, StateData]:
        """
        Reset the internal state of the :class:`PolicyAPI` by running :meth:`act`.

        Args:
            inplace (bool, optional): If True, write the selected actions to the swarm
                state. If False, return them instead. Defaults to True.
            root_walker (Optional[StateData], optional): Not used by this
                implementation. Defaults to None.
            states (Optional[StateData], optional): Not used by this implementation.
                Defaults to None.
            **kwargs: Forwarded to :meth:`act`.

        Returns:
            Union[None, StateData]: None if inplace is True. Otherwise, a StateData
            dictionary containing the selected actions.
        """
        # TODO: only run act when inputs are not present in root_walker/states
        # if root_walker is None and states is None:
        return self.act(inplace=inplace, **kwargs)
class Callback(SwarmComponent):
    """
    Base class for user-defined hooks that run at fixed points of the swarm loop.

    Every hook defined here is a no-op. :class:`SwarmAPI` invokes the hook with the
    matching name on each registered callback, and stores callbacks keyed by their
    ``name`` attribute.
    """

    name = None

    def before_reset(self):
        """Hook invoked from ``SwarmAPI.before_reset``. Does nothing by default."""

    def after_reset(self):
        """Hook invoked from ``SwarmAPI.after_reset``. Does nothing by default."""

    def before_evolve(self):
        """Hook invoked from ``SwarmAPI.before_evolve``. Does nothing by default."""

    def after_evolve(self):
        """Hook invoked from ``SwarmAPI.after_evolve``. Does nothing by default."""

    def before_policy(self):
        """Hook invoked from ``SwarmAPI.before_policy``. Does nothing by default."""

    def after_policy(self):
        """Hook invoked from ``SwarmAPI.after_policy``. Does nothing by default."""

    def before_env(self):
        """Hook invoked from ``SwarmAPI.before_env``. Does nothing by default."""

    def after_env(self):
        """Hook invoked from ``SwarmAPI.after_env``. Does nothing by default."""

    def before_walkers(self):
        """Hook invoked from ``SwarmAPI.before_walkers``. Does nothing by default."""

    def after_walkers(self):
        """Hook invoked from ``SwarmAPI.after_walkers``. Does nothing by default."""

    def reset(
        self,
        inplace: bool = True,
        root_walker: Optional[StateData] = None,
        states: Optional[StateData] = None,
        **kwargs,
    ):
        """Reset the callback. The base implementation ignores all arguments."""

    def evolution_end(self) -> bool:
        """Return whether this callback requests the evolution to stop (never, by default)."""
        return False

    def run_end(self):
        """Hook invoked from ``SwarmAPI.run_end``. Does nothing by default."""
class WalkersMetric(SwarmComponent):
    """Component that computes data through :meth:`calculate` when it is called."""

    def __call__(self, inplace: bool = True, **kwargs) -> Tensor:
        """
        Run :meth:`calculate` and optionally store its result in the swarm state.

        Args:
            inplace: If True, the calculated data is also written to the swarm state.
            **kwargs: Input data for :meth:`calculate`. When empty, the inputs are read
                from the swarm state.

        Returns:
            The data returned by :meth:`calculate`, regardless of ``inplace``.
        """
        metric_inputs = self._prepare_tensors(**kwargs)
        result = self.calculate(**metric_inputs)
        if inplace:
            self.update(**result)
        return result

    def calculate(self, **kwargs):
        """Compute the metric. Subclasses must override this method."""
        raise NotImplementedError()

    def reset(
        self,
        inplace: bool = True,
        root_walker: Optional[StateData] = None,
        states: Optional[StateData] = None,
        **kwargs,
    ):
        """Reset the metric. The base implementation ignores all arguments."""
class WalkersAPI(SwarmComponent):
    """
    Base functionality for managing the walkers of a swarm simulation.

    Unlike the generic :class:`SwarmComponent`, reads and writes performed through
    this class include the walkers marked as inactive by default.
    """

    def get_input_data(self) -> StateData:
        """
        Return a dictionary with all the data that this component requires as inputs,
        including the data for inactive walkers.

        Returns:
            StateData: A dictionary containing all the required data for the WalkersAPI
            component, where each key is the name of a data attribute and its value is
            a multi-dimensional array.
        """
        return {
            name: self.get(
                name=name,
                default=spec.get("default"),
                raise_error=not spec.get("optional", False),
                inactives=True,
            )
            for name, spec in self.inputs.items()
        }

    def update(
        self,
        other: Union["SwarmState", Dict[str, Tensor]] = None,
        inactives: bool = True,
        **kwargs,
    ) -> None:
        """Update the swarm state; same as the parent method but ``inactives`` defaults to True."""
        return super().update(other=other, inactives=inactives, **kwargs)

    def balance(self, inplace: bool = True, **kwargs) -> Union[None, StateData]:
        """
        Perform a balance operation on the swarm state.

        Args:
            inplace (bool, optional): If True, the data produced by :meth:`run_epoch`
                is also written to the swarm state. Defaults to True.
            **kwargs: Input data for :meth:`run_epoch`. When empty, the inputs are read
                from the swarm state.

        Returns:
            The data returned by :meth:`run_epoch`.
        """
        epoch_inputs = self._prepare_tensors(**kwargs)
        balanced = self.run_epoch(inplace=inplace, **epoch_inputs)
        if inplace:
            self.update(**balanced)
        return balanced

    def run_epoch(self, inplace: bool = True, **kwargs) -> StateData:
        """
        Implement the functionality for running an epoch in the derived class. This method is \
        called during the balance operation.

        Args:
            inplace (bool, optional): Forwarded from :meth:`balance`. Defaults to `True`.
            **kwargs: Additional keyword arguments required for running the epoch.

        Returns:
            StateData: A dictionary containing the data generated during the epoch.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def reset(self, inplace: bool = True, **kwargs):
        """
        Reset the internal state of the Walkers class.

        The base implementation does nothing; derived classes should override it.

        Args:
            inplace (bool, optional): Unused. Defaults to True.
            **kwargs: Additional keyword arguments required for resetting the component.

        Returns:
            None
        """

    def get_in_bounds_compas(self, oobs=None) -> Tensor:
        """
        Return an array of indexes corresponding to alive walkers chosen at random.

        Args:
            oobs (Optional[np.ndarray], optional): An optional boolean array indicating
                out-of-bounds walkers. If not provided, all walkers are considered
                alive. Defaults to None.

        Returns:
            Tensor: An array of indexes corresponding to randomly chosen alive walkers.
            When ``oobs`` is None or every walker is out of bounds, the plain index
            range is returned instead.
        """
        if oobs is None:
            return judo.arange(self.swarm.n_walkers, dtype=int)
        n_total = len(oobs)
        indexes = judo.arange(n_total, dtype=int)
        # No need to sample if all walkers are dead or terminal.
        if oobs.all():
            return indexes
        shuffled_alive = random_state.permutation(indexes[~oobs])
        compas = random_state.choice(shuffled_alive, n_total, replace=True)
        # Overwrite the leading entries so every alive walker appears at least once.
        compas[: len(shuffled_alive)] = shuffled_alive
        return compas

    def clone_walkers(self, will_clone=None, compas_clone=None, **kwargs) -> None:
        """
        Sample the clone probability distribution and clone the walkers accordingly.

        The call is delegated to ``self.swarm.state.clone`` together with the swarm's
        ``clone_names``.

        Args:
            will_clone (Optional[np.ndarray], optional): A boolean array indicating
                which walkers will be cloned. Defaults to None.
            compas_clone (Optional[np.ndarray], optional): An array of indexes
                indicating the walkers to be cloned. Defaults to None.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            None
        """
        state = self.swarm.state
        state.clone(
            will_clone=will_clone,
            compas_clone=compas_clone,
            clone_names=self.swarm.clone_names,
        )
class SwarmAPI:
    """
    The Swarm implements the iteration logic to make the :class:`Walkers` evolve.

    It contains the necessary logic to use an Environment, a Model, and a \
    Walkers instance to create the algorithm execution loop.

    This class defines a method called run() that receives two optional arguments, \
    root_walker and state, and has no return value. This method runs the fractal AI Swarm \
    evolution process until a stop condition is met. In its implementation, it calls \
    several other methods: \
    (before_reset(), reset(), after_reset(), evolve(), before_evolve(), after_evolve(), \
    evolution_end(), before_env(), after_env(), before_policy(), after_policy(), \
    before_walkers(), after_walkers(), run_end()) defined within the same class, \
    which are mainly used to manage different aspects of the search process or to \
    invoke user-defined callbacks.

    The evolve() method updates the states of the search environment and model, \
    makes the walkers undergo a perturbation process, and balances them. It also \
    invokes several other methods that trigger user-defined callbacks.

    The evolution_end() method returns ``True`` if any of the following conditions is met:
        1. The current epoch exceeds the maximum allowed epochs.
        2. All walkers are out of the problem domain.
        3. Any callback of the class has set the ``evolution_end`` flag to ``True``.

    Attributes:
        walkers_last (bool): If `True` indicates that the :class:`~fragile.Walkers` \
            class runs after acting on the environment. If `False`, the walkers run before \
            acting on the environment.

    Args:
        n_walkers (int): The number of walkers in the swarm.
        env (:class:`EnvironmentAPI`): An environment that simulates the objective function.
        policy (:class:`PolicyAPI`): A policy that defines how the individuals evolve.
        walkers (:class:`WalkersAPI`): A set of motion rules to control a population of walkers.
        callbacks (Optional[Iterable[Callback]]): Callbacks to register with the swarm.
        minimize (bool): If ``True``, take the minimum value of fitness, else take the maximum.
        max_epochs (int): Maximum number of epochs allowed before the swarm search is stopped.
    """

    walkers_last = True

    def __init__(
        self,
        n_walkers: int,
        env: EnvironmentAPI,
        policy: PolicyAPI,
        walkers: WalkersAPI,
        callbacks: Optional[Iterable[Callback]] = None,
        minimize: bool = False,
        max_epochs: Union[int, float] = 1e100,
    ):
        """Initialize a :class:`SwarmAPI`."""
        self.minimize = minimize
        # The default is a float literal, so it is normalized to an int here.
        self.max_epochs = int(max_epochs)
        self._n_walkers = n_walkers
        self._epoch = 0
        self._env = env
        self._policy = policy
        self._walkers = walkers
        self._state = None
        self._inputs = {}
        self._clone_names = set()
        self._callbacks = {}
        # Callbacks are only recorded here (setup=False): the state does not exist yet.
        for callback in callbacks if callbacks is not None else []:
            self.register_callback(callback, setup=False)
        self.setup()

    @property
    def n_walkers(self) -> int:
        """Return the number of walkers in the swarm."""
        return self._n_walkers

    @property
    def n_actives(self) -> int:
        """Return the number of active walkers in the swarm."""
        return self.state.n_actives

    @property
    def epoch(self) -> int:
        """Return the current epoch of the search algorithm."""
        return self._epoch

    @property
    def state(self) -> SwarmState:
        """Return the state instance describing the walkers of the swarm."""
        return self._state

    @property
    def env(self) -> EnvironmentAPI:
        """All the simulation code (problem specific) will be handled here."""
        return self._env

    @property
    def policy(self) -> PolicyAPI:
        """
        All the policy and random perturbation code (problem specific) will \
        be handled here.
        """
        return self._policy

    @property
    def walkers(self) -> WalkersAPI:
        """Access the :class:`Walkers` in charge of implementing the FAI evolution process."""
        return self._walkers

    @property
    def callbacks(self) -> Dict[str, Callback]:
        """Return the dictionary containing all the user-defined callbacks."""
        return self._callbacks

    @property
    def param_dict(self) -> StateDict:
        """Return a deep copy of the parameters dictionary describing the walkers' attributes."""
        return copy.deepcopy(self.state.param_dict)

    @property
    def clone_names(self) -> Set[str]:
        """Return the set of all the attributes that are cloned when iterating the Swarm."""
        return self._clone_names

    @property
    def inputs(self) -> dict:
        """Return the dictionary containing all the inputs of the search algorithm."""
        return self._inputs

    def __len__(self) -> int:
        """Return the number of walkers in the swarm."""
        return self.n_walkers

    def __getattr__(self, item):
        """
        Expose registered callbacks as attributes named after their ``name``.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: If ``item`` is neither an attribute nor a callback name.
        """
        # Read the registry straight from the instance dict. Going through the
        # ``callbacks`` property would call back into ``__getattr__`` whenever
        # ``_callbacks`` is not set yet (e.g. while the instance is being copied or
        # unpickled) and recurse until RecursionError.
        callbacks = self.__dict__.get("_callbacks")
        if callbacks is not None and item in callbacks:
            return callbacks[item]
        return super().__getattribute__(item)

    def setup_state(self, param_dict: StateDict, n_walkers: Optional[int] = None):
        """
        Set up :class:`SwarmState` instance for the search.

        Args:
            param_dict: Initial dictionary of parameters.
            n_walkers: Number of walkers. If not set, uses the previous number.

        Returns:
            None.
        """
        if n_walkers is not None:
            self._n_walkers = n_walkers
        self._state = SwarmState(n_walkers=self.n_walkers, param_dict=param_dict)
        self._state.reset()

    def setup(self) -> None:
        """
        Prepare :class:`Swarm` and internal components.

        Relies on ``_setup_components``, ``_setup_clone_names`` and ``_setup_inputs``,
        which are not defined in this class.

        Returns:
            None.
        """
        self._setup_components()
        self._setup_clone_names()
        self._setup_inputs()

    def register_callback(self, callback: Callback, setup: bool = True) -> None:
        """
        Register a ``Callback`` object with the :class:`Swarm`.

        When `setup=True`, calls setup() on the callback and rebuilds the swarm state
        with the callback's parameter dictionary merged in, so the state manages all
        the data the callback needs. The callback is always stored under its ``name``,
        and its inputs flagged with ``"clone"`` are added to the clone names.

        Args:
            callback: An instance of :class:`Callback`.
            setup: Indicates whether we should call `setup()` on the given ``callback``.

        Returns:
            None.
        """
        if setup:
            callback.setup(self)
            # NOTE(review): kept inside the ``if`` because ``__init__`` registers
            # callbacks with setup=False before any state exists - confirm against
            # the upstream source.
            new_param_dict = {**self.param_dict, **callback.param_dict}
            self.setup_state(n_walkers=self.n_walkers, param_dict=new_param_dict)
        self.callbacks[callback.name] = callback
        clone_names = [k for k, v in callback.inputs.items() if v.get("clone")]
        self._clone_names = set(list(self.clone_names) + clone_names)

    def get(self, name: str, default: Any = None, raise_error: bool = False) -> Any:
        """
        Access attributes of the :class:`Swarm` and its children.

        Args:
            name: Name of the attribute.
            default: Default value to return if attribute doesn't exist.
            raise_error: Raise an error if the named attribute cannot be found.

        Returns:
            The attribute if it exists, else the default.
        """
        return self.state.get(name=name, default=default, raise_error=raise_error)

    def reset(
        self,
        root_walker: Optional["OneWalker"] = None,
        state: Optional[SwarmState] = None,
    ) -> None:
        """
        Reset a :class:`~fragile.Swarm` and clear the internal data to start a \
        new search process.

        The state, the environment, the policy, the walkers and every callback are
        reset, and the epoch counter goes back to zero. The walkers are reset before
        the environment when ``walkers_last`` is ``False`` and after the policy
        otherwise.

        Args:
            root_walker (:class:`~fragile.OneWalker`, optional): A walker representing
                the initial state of the search. The walkers will be reset to this
                walker and it will be added to the root of the
                :class:`~fragile.StateTree`.
            state (:class:`~fragile.SwarmState`): Defines the initial state of the
                `Swarm`. Currently not used by this implementation.
        """
        self.state.reset(root_walker=root_walker)
        if not self.walkers_last:
            self.walkers.reset(root_walker=root_walker)
        self.env.reset(root_walker=root_walker)
        self.policy.reset(root_walker=root_walker)
        if self.walkers_last:
            self.walkers.reset(root_walker=root_walker)
        for callback in self.callbacks.values():
            callback.reset(root_walker=root_walker)
        self._epoch = 0

    def run(
        self,
        root_walker: Optional[StateData] = None,
        state: Optional[StateData] = None,
    ) -> None:
        """
        Run a new search process until the stop condition is met.

        The swarm is reset using ``root_walker`` and then evolved epoch by epoch until
        :meth:`evolution_end` returns ``True``:

        - Maximum number of epochs is reached (defined in the ``max_epochs`` attribute).
        - All the walkers are out of bounds (``oobs`` attribute of the state).
        - Any user-defined callback returns True for its ``evolution_end`` method.

        After the stop condition is met, the ``run_end`` method is called and all
        callbacks' ``run_end`` methods are executed.

        Args:
            root_walker (Optional[StateData]): Walker representing the initial state of
                the search. The walkers will be reset to this walker, and it will be
                added to the root of the :class:`StateTree` if any.
            state (Optional[StateData]): StateData dictionary that defines the initial
                state of the Swarm. Currently not used by this implementation.

        Returns:
            None.
        """
        self.before_reset()
        self.reset(root_walker=root_walker)
        self.after_reset()
        while not self.evolution_end():
            self.before_evolve()
            self.evolve()
            self.after_evolve()
        self.run_end()

    def evolution_end(self) -> bool:
        """
        Check if the :class:`Swarm`'s evolution process should stop.

        Checks whether the maximum number of epochs has been reached, all walkers are
        out of bounds or any of the callbacks have ended their evolution.

        Returns:
            bool: A boolean value indicating whether the evolution process has to stop.
        """
        oobs = self.get("oobs")
        return (
            (self.epoch >= self.max_epochs)
            or oobs.all()
            or any(c.evolution_end() for c in self.callbacks.values())
        )

    def evolve(self) -> None:
        """
        Make the walkers undergo a perturbation process in the swarm \
        :class:`Environment`.

        The evolution step runs in this order:

        1. If ``self.walkers_last`` is ``False``: call the 'before_walkers' callbacks, \
           balance the walkers by calling ``balance`` on the :class:`Walkers` \
           instance, and call the 'after_walkers' callbacks.
        2. Call the 'before_policy' callbacks.
        3. Make each walker select an action by calling the act() method of the policy.
        4. Call the 'after_policy' and 'before_env' callbacks.
        5. Perform the environment transition for the selected actions.
        6. Call the 'after_env' callbacks.
        7. If ``self.walkers_last`` is ``True``: run the walkers balance process and \
           its related callbacks.
        8. Increase the epoch counter by one.
        """
        if not self.walkers_last:
            self.before_walkers()
            self.walkers.balance()
            self.after_walkers()
        self.before_policy()
        self.policy.act()
        self.after_policy()
        self.before_env()
        self.env.make_transitions()
        self.after_env()
        if self.walkers_last:
            self.before_walkers()
            self.walkers.balance()
            self.after_walkers()
        self._epoch += 1

    def before_reset(self) -> None:
        """Called before resetting the search process back to its initial state."""
        for callback in self.callbacks.values():
            callback.before_reset()

    def after_reset(self) -> None:
        """Called after resetting the search process back to its initial state."""
        for callback in self.callbacks.values():
            callback.after_reset()

    def before_policy(self) -> None:
        """Called before using the policy to sample an action."""
        for callback in self.callbacks.values():
            callback.before_policy()

    def after_policy(self) -> None:
        """Called after using the policy to sample an action."""
        for callback in self.callbacks.values():
            callback.after_policy()

    def before_env(self) -> None:
        """Called before the environment has been updated."""
        for callback in self.callbacks.values():
            callback.before_env()

    def after_env(self) -> None:
        """Called after the environment has been updated."""
        for callback in self.callbacks.values():
            callback.after_env()

    def before_walkers(self) -> None:
        """Called before the walkers have been balanced."""
        for callback in self.callbacks.values():
            callback.before_walkers()

    def after_walkers(self) -> None:
        """Called after the walkers have been balanced."""
        for callback in self.callbacks.values():
            callback.after_walkers()

    def before_evolve(self) -> None:
        """Called before the evolve step has started."""
        for callback in self.callbacks.values():
            callback.before_evolve()

    def after_evolve(self) -> None:
        """Called after the evolve step has ended."""
        for callback in self.callbacks.values():
            callback.after_evolve()

    def run_end(self) -> None:
        """Called after the evolution process has ended."""
        for callback in self.callbacks.values():
            callback.run_end()