import logging
from typing import Dict, Iterable, List, Tuple, Union
import judo
from judo.functions import random_state
from judo.judo_backend import Backend
from fragile.core.api_classes import Callback
from fragile.core.typing import Tensor
class ReplayMemory(Callback):
    """Replay buffer that contains data collected from algorithm runs.

    Each attribute listed in ``names`` is stored as a tensor whose first
    dimension indexes the memorized samples. New data is appended after each
    environment step, and the memory reports itself ready once it holds at
    least ``min_size`` samples.
    """

    name = "memory"
    _log = logging.getLogger("Memory")

    def __init__(
        self,
        max_size: int,
        names: Union[List[str], Tuple[str], None] = None,
        min_size: Union[int, None] = None,
        **kwargs,
    ):
        """
        Initialize a :class:`ReplayMemory`.

        Args:
            max_size: Maximum number of experiences that will be stored.
            names: Names of the replay data attributes that will be stored.
            min_size: Minimum number of samples that need to be stored before the \
                replay memory is considered ready. If ``None`` it will be equal \
                to max_size.
        """
        super(ReplayMemory, self).__init__(**kwargs)
        self.max_size = max_size
        # Fix: the docstring promises that ``min_size`` defaults to
        # ``max_size``, but the code previously used the float ``1.0``
        # (which made the memory "ready" after a single sample).
        self.min_size = max_size if min_size is None else min_size
        self.names = names

    def __len__(self) -> int:
        """Return the number of samples currently stored in the memory."""
        # Before ``setup``/``reset`` run, ``names`` may be unset and the data
        # attributes may not exist yet; report an empty memory in that case.
        if not self.names:
            return 0
        first_attr = getattr(self, self.names[0], None)
        return 0 if first_attr is None else len(first_attr)

    def __repr__(self) -> str:
        text = "Memory with min_size %s max_size %s and length %s" % (
            self.min_size,
            self.max_size,
            len(self),
        )
        return text

    def setup(self, swarm):
        """Prepare the memory to track the data attributes exposed by ``swarm``."""
        super(ReplayMemory, self).setup(swarm)
        if self.names is None:
            # Default to tracking every attribute present in the swarm state.
            self.names = self.swarm.state.names
        self.reset()

    def reset(self, *args, **kwargs):
        """Delete all the data previously stored in the memory."""
        super(ReplayMemory, self).reset(*args, **kwargs)
        for name in self.names:
            setattr(self, name, None)

    def after_env(self):
        """Store the current swarm state after every environment step."""
        self.append(**dict(self.swarm.state))

    def get_value(self, name):
        """Get attributes of the memory (``"len"`` returns the sample count)."""
        if name == "len":
            return len(self)
        return getattr(self, name)

    def is_ready(self) -> bool:
        """
        Return ``True`` if the number of experiences in the memory is greater than ``min_size``.
        """
        return len(self) >= self.min_size

    def get_values(self) -> Tuple[Tensor, ...]:
        """Return a tuple containing the memorized data for all the saved data attributes."""
        return tuple(getattr(self, name) for name in self.names)

    def as_dict(self) -> Dict[str, Tensor]:
        """Return the stored data as a ``{attribute_name: tensor}`` mapping."""
        return dict(zip(self.names, self.get_values()))

    def iterate_batches(self, batch_size: int, as_dict: bool = True):
        """Yield shuffled mini-batches of the stored data.

        Args:
            batch_size: Number of samples per batch; the last batch may be smaller.
            as_dict: If ``True`` yield ``{name: tensor}`` dicts, otherwise
                tuples ordered like ``self.names``.
        """
        # NOTE(review): the flattened source is ambiguous about whether the
        # loop runs inside the numpy-backend context; kept inside — confirm
        # against upstream if backend state during iteration matters.
        with Backend.use_backend("numpy"):
            indexes = random_state.permutation(range(len(self)))
            for start in range(0, len(self), batch_size):
                batch_ix = indexes[start : start + batch_size]  # noqa: E203
                data = tuple(getattr(self, name)[batch_ix] for name in self.names)
                if as_dict:
                    yield dict(zip(self.names, data))
                else:
                    yield data

    def iterate_values(self, randomize: bool = False) -> Iterable[Tuple[Tensor]]:
        """
        Return a generator that yields a tuple containing the data of each state \
        stored in the memory.
        """
        indexes = range(len(self))
        if randomize:
            with Backend.use_backend("numpy"):
                indexes = random_state.permutation(indexes)
        for ix in indexes:
            yield tuple(getattr(self, name)[ix] for name in self.names)

    def append(self, **kwargs):
        """Concatenate the provided tensors to the stored data attributes.

        Args:
            kwargs: ``{name: value}`` pairs; every name must be in ``self.names``.

        Raises:
            KeyError: If a keyword is not one of the tracked attribute names.
        """
        for name, val in kwargs.items():
            if name not in self.names:
                raise KeyError("%s not in self.names: %s" % (name, self.names))
            val = judo.to_backend(val)
            # Scalar vectors are transformed to columns, so every sample
            # stacks along the first axis regardless of its original rank.
            if len(val.shape) == 0:
                val = judo.unsqueeze(val)
            if len(val.shape) == 1:
                val = val.reshape(-1, 1)
            try:
                stored = getattr(self, name)
                processed = val if stored is None else judo.concatenate([stored, val])
                # NOTE(review): truncation keeps the OLDEST samples; once the
                # memory is full, new experiences are silently discarded.
                if len(processed) > self.max_size:
                    processed = processed[: self.max_size]
            except Exception:
                # Log the shapes instead of printing to stdout (the old print
                # crashed on ``None.shape``); bare ``raise`` keeps the
                # original traceback intact.
                self._log.error(
                    "Failed to append attribute %s with shape %s (stored: %s)",
                    name,
                    val.shape,
                    getattr(self, name, None),
                )
                raise
            setattr(self, name, processed)
        # Lazy %-args: the message is only built when INFO logging is enabled.
        self._log.info("Memory now contains %s samples", len(self))