Source code for fragile.callbacks.data_tracking

import copy

import einops
import judo
from judo.data_types import dtype

from fragile.core.api_classes import Callback
from fragile.core.typing import StateDict


[docs]class StoreInitAction(Callback): name = "store_init_action" default_inputs = {"init_actions": {"clone": True}} default_outputs = ("init_actions",) @property def param_dict(self) -> StateDict: return {"init_actions": dict(self.swarm.param_dict["actions"])}
[docs] def before_env(self): if self.swarm.epoch == 0: self.update(init_actions=judo.copy(self.get("actions")))
[docs]class TrackWalkersId(Callback): default_inputs = {"id_walkers": {"clone": True}, "parent_ids": {"clone": True}} default_param_dict = { "id_walkers": {"dtype": dtype.hash_type}, "parent_ids": {"dtype": dtype.hash_type}, }
[docs] def update_ids(self, inactives: bool = True): with judo.Backend.use_backend("numpy"): name = "states" if "states" in self.swarm.state.names else "observs" actives = judo.to_numpy(self.swarm.state.actives) new_ids_all = self.swarm.state.hash_batch(name) parent_ids = judo.copy(judo.to_numpy(self.get("parent_ids", inactives=inactives))) new_ids = judo.copy(judo.to_numpy(self.get("id_walkers", inactives=True))) parent_ids[actives] = judo.copy(new_ids[actives]) new_ids[actives] = new_ids_all[actives] self.update( parent_ids=judo.to_backend(parent_ids), id_walkers=judo.to_backend(new_ids), inactives=inactives, )
[docs] def after_env(self): self.update_ids()
[docs]class KeepBestN(Callback): name = "keep_best_n" def __init__( self, n_keep: int, always_update: bool = False, fix_walkers: bool = True, **kwargs, ): super(KeepBestN, self).__init__(**kwargs) self.minimize = None self.always_update = always_update self.n_keep = n_keep self.fix_walkers = fix_walkers self.top_walkers = [] self.top_scores = []
[docs] def setup(self, swarm): super(KeepBestN, self).setup(swarm) self.minimize = swarm.minimize
[docs] def reset(self, *args, **kwargs): self.top_walkers = [] self.top_scores = []
[docs] def before_walkers(self): self.update_top_walkers()
[docs] def get_topk_walkers(self): scores, oobs, terminals = self.get("scores"), self.get("oobs"), self.get("terminals") index = judo.arange(len(scores)) bool_ix = ~oobs if terminals is None else judo.logical_or(~oobs, terminals) alive_scores = judo.copy(scores[bool_ix]) if len(alive_scores) == 0: return None, None # print(alive_scores) _, topk = judo.topk(alive_scores, k=self.n_keep, largest=not self.minimize) # topk = judo.astype(judo.clip(topk, 0, judo.inf), judo.int) top_indexes = judo.copy(index[bool_ix][topk]) top_scores = judo.copy(scores[bool_ix][topk]) top_walkers = [self.swarm.state.export_walker(i, copy=True) for i in top_indexes] try: return top_walkers, top_scores except Exception as e: print(topk, bool_ix) raise e
[docs] def update_top_walkers(self): last_top_walkers, last_top_scores = self.get_topk_walkers() if last_top_walkers is None: return elif self.always_update: new_top_walkers = last_top_walkers new_top_scores = last_top_scores else: all_top_walkers = self.top_walkers + last_top_walkers if judo.is_tensor(self.top_scores): curr_top_list = einops.asnumpy(self.top_scores).tolist() else: curr_top_list = self.top_scores last_top_list = einops.asnumpy(last_top_scores).tolist() top_scores = judo.tensor(curr_top_list + last_top_list) new_top_scores, topk = judo.topk(top_scores, k=self.n_keep, largest=not self.minimize) new_top_walkers = [all_top_walkers[i] for i in einops.asnumpy(topk).astype(int)] self.top_walkers = new_top_walkers self.top_scores = new_top_scores return # if not judo.Backend.is_numpy():# and judo.dtype.is_tensor(self.score): # scores = self._data["scores"] # score = scores[0] if isinstance(scores, list) else scores.item() # else: score = self.score score_improves = (best_score < score) if self.minimize else (best_score > score) if self.always_update or score_improves: # or numpy.isinf(score): # new_best = {k: copy.deepcopy(v) for k, v in best.items()} self._data = copy.deepcopy(best)
[docs] def fix_topk_walkers(self): if self.fix_walkers: terminals = self.swarm.get("terminals") for i, walker in enumerate(self.top_walkers): self.swarm.state.import_walker(copy.deepcopy(walker), index=i) if not self.swarm.state.actives[i] and terminals is not None and not terminals[i]: self.swarm.state.actives[i] = True self.swarm.state._n_actives += 1
[docs] def after_walkers(self): self.fix_topk_walkers()