from typing import Callable
import judo
from judo.data_types import dtype
from judo.functions.random import random_state
from judo.judo_tensor import tensor
from judo.typing import Tensor
import numpy
# Names of the functions this module lists as available.
AVAILABLE_FUNCTIONS = {
    # Single-swarm primitives.
    "calculate_clone",
    "calculate_distance",
    "calculate_virtual_reward",
    "fai_iteration",
    "get_alive_indexes",
    "l2_norm",
    "relativize",
    # Operations between a host swarm and an external swarm.
    "cross_clone",
    "cross_fai_iteration",
    "cross_virtual_reward",
}
def l2_norm(x: Tensor, y: Tensor) -> Tensor:
    """
    Compute the row-wise Euclidean distance between two batches of points.

    Args:
        x: First batch of points, one point per row.
        y: Second batch of points, paired row by row with ``x``.

    Returns:
        The norm of ``x - y`` taken along axis 1 (one value per row).
    """
    difference = x - y
    return judo.norm(difference, axis=1)
def relativize(x: Tensor) -> Tensor:
    """
    Rescale ``x`` relative to its own mean and standard deviation.

    The values are standardized to scores ``z`` and then squashed with
    ``log(1 + z) + 1`` for ``z > 0`` and ``exp(z)`` otherwise, so every output
    is strictly positive.

    Args:
        x: Values to rescale.

    Returns:
        Tensor with the same shape as ``x``. If the standard deviation of ``x``
        is zero, NaN or infinite, a 1D tensor of ones of length ``len(x)`` with
        the dtype of the original input is returned instead.
    """
    original_input = x
    values = judo.astype(x, dtype.float)
    spread = values.std()
    spread_value = float(spread)
    if spread_value == 0 or not numpy.isfinite(spread_value):
        # No usable scale: weight every entry equally.
        return judo.ones(len(values), dtype=original_input.dtype)
    z_scores = (values - values.mean()) / spread
    with numpy.errstate(invalid="ignore", divide="ignore"):
        # Both branches are evaluated before `where` selects, so log(1 + z) is
        # also computed for z <= -1; silence the warnings from that discarded branch.
        squashed = judo.where(
            z_scores > 0.0,
            judo.log(1.0 + z_scores) + 1.0,
            judo.exp(z_scores),
        )
    return squashed
def get_alive_indexes(oobs: Tensor):
    """
    Sample one alive walker index for every walker.

    Args:
        oobs: Death flags of the walkers (truthy entries mark dead walkers).

    Returns:
        ``len(oobs)`` indexes drawn from the alive walkers. Sampling is without
        replacement when every walker is alive and with replacement otherwise.
        If every walker is dead, ``arange(len(oobs))`` is returned unchanged.
    """
    if not judo.all(oobs):
        alive_mask = judo.logical_not(oobs).flatten()
        n_walkers = len(alive_mask)
        alive_ix = judo.arange(n_walkers)[alive_mask]
        # Fewer alive walkers than slots to fill -> indexes must be reused.
        some_dead = alive_mask.sum() < n_walkers
        return random_state.choice(alive_ix, size=n_walkers, replace=some_dead)
    return judo.arange(len(oobs))
def calculate_distance(
    observs: Tensor,
    distance_function: Callable = l2_norm,
    return_compas: bool = False,
    oobs: Tensor = None,
    compas: Tensor = None,
):
    """
    Calculate a distance metric for each walker with respect to a random companion.

    Args:
        observs: Observations of the walkers stacked along the first dimension.
            Trailing dimensions are flattened before measuring distances.
        distance_function: Callable receiving the flattened observations and
            the companions' flattened observations; expected to return one
            distance per walker.
        return_compas: If ``True``, also return the companion indexes used.
        oobs: Optional death flags. When given (and ``compas`` is not),
            companions are sampled among alive walkers.
        compas: Optional precomputed companion indexes, used as provided.

    Returns:
        The relativized distances, or ``(distances, compas)`` when
        ``return_compas`` is ``True``.
    """
    if compas is None:
        compas = get_alive_indexes(oobs) if oobs is not None else judo.arange(observs.shape[0])
        compas = random_state.permutation(compas)
    # `reshape` instead of `view`: numpy's `ndarray.view` takes a dtype, not a
    # shape, and torch's `view` fails on non-contiguous tensors.
    flattened_observs = observs.reshape(observs.shape[0], -1)
    distance = distance_function(flattened_observs, flattened_observs[compas])
    distance_norm = relativize(distance.flatten())
    return distance_norm if not return_compas else (distance_norm, compas)
def calculate_virtual_reward(
    observs: Tensor,
    rewards: Tensor,
    oobs: Tensor = None,
    dist_coef: float = 1.0,
    reward_coef: float = 1.0,
    other_reward: Tensor = 1.0,
    return_compas: bool = False,
    return_distance: bool = False,
    distance_function: Callable = l2_norm,
):
    """
    Compute the virtual reward of every walker.

    The virtual reward is
    ``relativize(distance) ** dist_coef * relativize(rewards) ** reward_coef * other_reward``,
    where ``distance`` is measured between each walker and a random companion.

    Args:
        observs: Observations of the walkers; flattened to one row per walker.
        rewards: Rewards of the walkers; flattened before relativizing.
        oobs: Optional death flags. When given, companions are sampled among
            alive walkers.
        dist_coef: Exponent applied to the relativized distances.
        reward_coef: Exponent applied to the relativized rewards.
        other_reward: Extra multiplicative factor; tensors are flattened first.
        return_compas: If ``True``, also return the companion indexes.
        return_distance: If ``True``, also return the raw (non-relativized)
            distances.
        distance_function: Callable returning one distance per row of its two
            arguments.

    Returns:
        The virtual rewards alone when no extra output is requested. Otherwise a
        tuple starting with the virtual rewards, followed by the companion
        indexes (if ``return_compas``) and the raw distances (if
        ``return_distance``), in that order.
    """
    if oobs is None:
        companions = judo.arange(len(rewards))
    else:
        companions = get_alive_indexes(oobs)
    companions = random_state.permutation(companions)
    points = observs.reshape(len(companions), -1)
    if dtype.is_tensor(other_reward):
        other_reward = other_reward.flatten()
    raw_distance = distance_function(points, points[companions])
    scaled_distance = relativize(raw_distance.flatten())
    scaled_rewards = relativize(rewards.flatten())
    virtual_reward = scaled_distance**dist_coef * scaled_rewards**reward_coef * other_reward
    extras = []
    if return_compas:
        extras.append(companions)
    if return_distance:
        extras.append(raw_distance)
    if not extras:
        return virtual_reward
    return (virtual_reward, *extras)
def calculate_clone(virtual_rewards: Tensor, oobs: Tensor = None, eps=1e-8):
    """
    Decide which walkers clone and which companion each one would copy.

    Each walker ``i`` is paired with a random companion ``c`` (sampled among
    alive walkers when ``oobs`` is given). Its clone probability is
    ``(vr[c] - vr[i]) / vr[i]``, where ``vr[i]`` is replaced by ``eps`` when it
    is not greater than ``eps``. The walker clones when that probability
    exceeds a uniform random draw.

    Args:
        virtual_rewards: Virtual reward of every walker (flattened internally).
        oobs: Optional death flags used to sample companions.
        eps: Lower bound applied to the denominator.

    Returns:
        ``(compas_ix, will_clone)``: companion index and boolean clone flag of
        every walker.
    """
    if oobs is None:
        candidates = judo.arange(len(virtual_rewards))
    else:
        candidates = get_alive_indexes(oobs)
    companions = random_state.permutation(candidates)
    scores = virtual_rewards.flatten()
    # Avoid dividing by zero / tiny values.
    denominator = judo.where(scores > eps, scores, tensor(eps))
    clone_probs = (scores[companions] - scores) / denominator
    thresholds = random_state.random(len(clone_probs))
    will_clone = clone_probs.flatten() > thresholds
    return companions, will_clone
def fai_iteration(
    observs: Tensor,
    rewards: Tensor,
    oobs: Tensor = None,
    dist_coef: float = 1.0,
    reward_coef: float = 1.0,
    eps=1e-8,
    other_reward: Tensor = 1.0,
    return_compas_dist: bool = False,
    return_distance: bool = False,
    distance_function: Callable = l2_norm,
):
    """
    Perform a FAI iteration: compute virtual rewards, then clone decisions.

    Args:
        observs: Observations of the walkers.
        rewards: Rewards of the walkers.
        oobs: Optional death flags. Defaults to "all alive" (a boolean zeros
            tensor shaped like ``rewards``).
        dist_coef: Exponent applied to the relativized distances.
        reward_coef: Exponent applied to the relativized rewards.
        eps: Lower bound for the clone-probability denominator.
        other_reward: Extra multiplicative factor for the virtual reward.
        return_compas_dist: If ``True``, also return the companion indexes
            used to measure distances.
        return_distance: If ``True``, also return the raw distances.
        distance_function: Distance metric forwarded to
            ``calculate_virtual_reward`` (defaults to ``l2_norm``, the previous
            hard-wired behaviour).

    Returns:
        ``(compas_ix, will_clone)`` from the clone step. When requested, these
        are followed by the distance companions (``return_compas_dist``) and the
        raw distances (``return_distance``), in that order.
    """
    if oobs is None:
        oobs = judo.zeros(rewards.shape, dtype=dtype.bool)
    virtual_reward = calculate_virtual_reward(
        observs,
        rewards,
        oobs,
        dist_coef=dist_coef,
        reward_coef=reward_coef,
        other_reward=other_reward,
        return_distance=return_distance,
        return_compas=return_compas_dist,
        distance_function=distance_function,
    )
    # calculate_virtual_reward returns a tuple only when extra outputs were requested.
    if isinstance(virtual_reward, tuple):
        virtual_reward, *rest_data = virtual_reward
    else:
        rest_data = ()
    compas_ix, will_clone = calculate_clone(virtual_rewards=virtual_reward, oobs=oobs, eps=eps)
    return (compas_ix, will_clone, *rest_data)
def clone_tensor(x, compas_ix, will_clone):
    """
    Overwrite the cloning entries of ``x`` with their companions' entries.

    Args:
        x: Array/tensor indexed by walker along its first dimension. It is
            modified in place.
        compas_ix: Companion index of every walker.
        will_clone: Boolean mask selecting the walkers that clone.

    Returns:
        ``x`` itself, after the in-place update.
    """
    # Gather first (fancy indexing copies), so the assignment cannot read
    # values it has already overwritten.
    source_values = x[compas_ix][will_clone]
    x[will_clone] = source_values
    return x
def cross_virtual_reward(
    host_observs: Tensor,
    host_rewards: Tensor,
    ext_observs: Tensor,
    ext_rewards: Tensor,
    dist_coef: float = 1.0,
    reward_coef: float = 1.0,
    return_compas: bool = False,
    distance_function: Callable = l2_norm,
):
    """
    Calculate the virtual rewards of two clouds of points relative to each other.

    Every host point is paired with a random external point, and every external
    point with a random host point. The distance to that companion and the
    point's own reward are relativized within their cloud and combined as
    ``distance ** dist_coef * reward ** reward_coef``.

    Args:
        host_observs: Observations of the host cloud; reshaped to one row per
            host reward.
        host_rewards: Rewards of the host cloud; flattened before relativizing.
        ext_observs: Observations of the external cloud; reshaped to one row per
            external reward.
        ext_rewards: Rewards of the external cloud; flattened before relativizing.
        dist_coef: Exponent applied to the relativized distances.
        reward_coef: Exponent applied to the relativized rewards.
        return_compas: If ``True``, also return the companion indexes.
        distance_function: Callable returning one distance per row of its two
            arguments.

    Returns:
        ``(host_vr, ext_vr)``, or ``((host_vr, compas_host), (ext_vr, compas_ext))``
        when ``return_compas`` is ``True``. ``compas_host`` indexes the external
        cloud (one entry per host point) and ``compas_ext`` indexes the host
        cloud (one entry per external point).
    """
    n_host, n_ext = len(host_rewards), len(ext_rewards)
    host_observs = host_observs.reshape(n_host, -1)
    ext_observs = ext_observs.reshape(n_ext, -1)
    if n_host == n_ext:
        # Same-size clouds: a random permutation gives every point a companion.
        compas_host = random_state.permutation(judo.arange(n_host))
        compas_ext = random_state.permutation(judo.arange(n_ext))
    else:
        # Different sizes: companions must come from the *other* cloud's index
        # range, reusing indexes when that cloud is smaller.
        compas_host = random_state.choice(judo.arange(n_ext), size=n_host, replace=n_ext < n_host)
        compas_ext = random_state.choice(judo.arange(n_host), size=n_ext, replace=n_host < n_ext)
    # TODO: check if it's better for the distances to be the same for host and ext
    h_dist = distance_function(host_observs, ext_observs[compas_host])
    e_dist = distance_function(ext_observs, host_observs[compas_ext])
    host_distance = relativize(h_dist.flatten())
    ext_distance = relativize(e_dist.flatten())
    # Flatten rewards (as calculate_virtual_reward does); otherwise (n, 1)
    # rewards would broadcast against (n,) distances into an (n, n) matrix.
    host_rewards_norm = relativize(host_rewards.flatten())
    ext_rewards_norm = relativize(ext_rewards.flatten())
    host_vr = host_distance**dist_coef * host_rewards_norm**reward_coef
    ext_vr = ext_distance**dist_coef * ext_rewards_norm**reward_coef
    if return_compas:
        return (host_vr, compas_host), (ext_vr, compas_ext)
    return host_vr, ext_vr
def cross_clone(
    host_virtual_rewards: Tensor,
    ext_virtual_rewards: Tensor,
    host_oobs: Tensor = None,
    eps=1e-3,
):
    """
    Perform a clone operation from an external group of points into a host group.

    Each host walker ``i`` is paired with a random external walker ``c`` and
    clones when ``(ext_vr[c] - host_vr[i]) / host_vr[i]`` exceeds a uniform
    random draw. ``host_vr[i]`` is replaced by ``eps`` when it is not greater
    than ``eps``. Dead host walkers always clone.

    Args:
        host_virtual_rewards: Virtual rewards of the host walkers.
        ext_virtual_rewards: Virtual rewards of the external walkers.
        host_oobs: Optional death flags of the host walkers; flagged walkers
            are forced to clone.
        eps: Lower bound applied to the denominator.

    Returns:
        ``(compas_ix, will_clone)``: for every host walker, the index of its
        external companion and whether it clones from it.
    """
    host_vr = judo.astype(host_virtual_rewards.flatten(), dtype=dtype.float32)
    ext_vr = judo.astype(ext_virtual_rewards.flatten(), dtype=dtype.float32)
    n_host, n_ext = len(host_vr), len(ext_vr)
    if n_host == n_ext:
        compas_ix = random_state.permutation(judo.arange(n_ext))
    else:
        # One external companion per host walker, reusing indexes when the
        # external group is smaller.
        compas_ix = random_state.choice(judo.arange(n_ext), size=n_host, replace=n_ext < n_host)
    # Normalize by the host walker's own virtual reward, as calculate_clone
    # does. ext_vr[i] belongs to an unrelated walker, not to host walker i's
    # companion.
    clone_probs = (ext_vr[compas_ix] - host_vr) / judo.where(
        host_vr > eps,
        host_vr,
        tensor(eps, dtype=dtype.float32),
    )
    will_clone = clone_probs.flatten() > random_state.random(len(clone_probs))
    if host_oobs is not None:
        # Flatten so column-shaped (n, 1) masks index the 1D clone mask.
        will_clone[host_oobs.flatten()] = True
    return compas_ix, will_clone
def cross_fai_iteration(
    host_observs: Tensor,
    host_rewards: Tensor,
    ext_observs: Tensor,
    ext_rewards: Tensor,
    host_oobs: Tensor = None,
    dist_coef: float = 1.0,
    reward_coef: float = 1.0,
    distance_function: Callable = l2_norm,
    eps: float = 1e-8,
):
    """
    Run one FractalAI cloning step between a host cloud and an external cloud.

    First the cross virtual rewards of both clouds are computed with
    ``cross_virtual_reward``. Then ``cross_clone`` decides which host walkers
    copy an external companion.

    Args:
        host_observs: Observations of the host cloud.
        host_rewards: Rewards of the host cloud.
        ext_observs: Observations of the external cloud.
        ext_rewards: Rewards of the external cloud.
        host_oobs: Optional death flags of the host walkers.
        dist_coef: Exponent applied to the relativized distances.
        reward_coef: Exponent applied to the relativized rewards.
        distance_function: Distance metric used between the clouds.
        eps: Lower bound forwarded to ``cross_clone``.

    Returns:
        ``(compas_ix, will_clone)`` as produced by ``cross_clone``.
    """
    reward_options = dict(
        dist_coef=dist_coef,
        reward_coef=reward_coef,
        distance_function=distance_function,
        return_compas=False,
    )
    host_vr, ext_vr = cross_virtual_reward(
        host_observs, host_rewards, ext_observs, ext_rewards, **reward_options
    )
    compas_ix, will_clone = cross_clone(host_vr, ext_vr, host_oobs=host_oobs, eps=eps)
    return compas_ix, will_clone