import copy
import numpy as np
from algorithms.dataset.utils import pad_tensor, compute_n_batches


def select_batch_idxs(start_idx, batch_size, min_idx, max_idx):
    '''
    Select a contiguous block of indexes starting at start_idx; if fewer than
    batch_size indexes remain before max_idx, the batch is filled out with
    indexes drawn uniformly at random from [min_idx, max_idx).

    :param start_idx: starting index
    :param batch_size: number of indexes to select
    :param min_idx: minimum index (inclusive)
    :param max_idx: maximum index (exclusive)
    :return: tuple of (selected indexes, end index of the contiguous block)
    '''
    end_idx = start_idx + batch_size
    end_idx = min(end_idx, max_idx)
    idxs = np.arange(start_idx, end_idx, dtype=int)
    # if too few samples were selected, randomly select the rest from the full range
    if len(idxs) < batch_size:
        n_additional = batch_size - len(idxs)
        additional_idxs = np.random.randint(low=min_idx, high=max_idx, size=n_additional)
        idxs = np.hstack((idxs, additional_idxs))
    return idxs, end_idx
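
# Illustrative example of the wraparound behavior (a sketch, not executed here):
# with min_idx=0, max_idx=10, start_idx=8, batch_size=4, the contiguous slice is
# [8, 9], two more indexes are drawn uniformly from [0, 10) to fill the batch,
# and end_idx == 10 tells the caller that the real data has been exhausted.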


class Dataset(object):

    def __init__(
            self,
            data,
            batch_size,
            action_normalizer=None,
            observation_normalizer=None,
            replay_memory=None,
            recurrent=False,
            flat_recurrent=False,
            use_random_scaling=False,
            random_scale_factor=.2,
            use_random_noise=False,
            random_noise_factor=.003):
        '''
        :param data: dict of real (expert) data containing at least 'observations' and 'actions'
        :param batch_size: batch size
        :param action_normalizer: optional callable applied to actions
        :param observation_normalizer: optional callable applied to observations
        :param replay_memory: optional replay memory buffer
        :param recurrent: whether the model is recurrent (data kept as padded sequences)
        :param flat_recurrent: whether sequence data should be flattened into independent timesteps
        :param use_random_scaling: whether to apply random scaling as data augmentation
        :param random_scale_factor: scale factor to use when use_random_scaling is True
        :param use_random_noise: whether to add random noise as data augmentation
        :param random_noise_factor: noise standard deviation to use when use_random_noise is True
        '''
        assert 'observations' in data.keys()
        assert 'actions' in data.keys()
        assert not (flat_recurrent and recurrent)
        if recurrent:
            assert 'valids' in data.keys()
        # expert data has already been padded to the max sequence length
        self.max_seq_len = data['observations'].shape[1]
        self.data = data
        self.batch_size = batch_size
        self.action_normalizer = action_normalizer
        self.observation_normalizer = observation_normalizer
        self.replay_memory = replay_memory
        self.recurrent = recurrent
        self.flat_recurrent = flat_recurrent
        self.use_random_scaling = use_random_scaling
        self.random_scale_factor = random_scale_factor
        self.use_random_noise = use_random_noise
        self.random_noise_factor = random_noise_factor
        self.n_samples = len(data['observations'])  # number of real samples
        self.next_idx = 0  # read position into the real data for sequential batches

    def _normalize(self, data):
        # normalize actions in the dataset to ensure consistency with generated actions
        if self.action_normalizer:
            data['actions'] = self.action_normalizer(data['actions'])
        # typically obs will be normalized through an environment wrapper, but in some
        # cases it is more convenient to do it in the dataset
        if self.observation_normalizer:
            data['observations'] = self.observation_normalizer(data['observations'])

    def _apply_random_scale(self, x):
        '''
        Randomly rescales the feature dimensions of x, which is assumed to have shape
        (batch_size, timesteps, input_dim); the same rescale factor is shared across
        timesteps for a given input_dim.
        '''
        random = np.random.uniform(size=(x.shape[0], 1, x.shape[-1]))
        scales = (random - .5) * 2 * self.random_scale_factor + 1.
        x *= scales
        return x

    def _apply_random_scale_to_batch(self, batch):
        if self.use_random_scaling and self.recurrent:
            batch['rx'][..., :2] = self._apply_random_scale(batch['rx'][..., :2])
            batch['ra'][..., :2] = self._apply_random_scale(batch['ra'][..., :2])
        return batch

    def _apply_random_noise(self, x):
        noise = np.random.randn(*x.shape) * self.random_noise_factor
        x += noise
        return x

    def _apply_random_noise_to_batch(self, batch):
        if self.use_random_noise:
            batch['rx'][..., :2] = self._apply_random_noise(batch['rx'][..., :2])
            batch['ra'][..., :2] = self._apply_random_noise(batch['ra'][..., :2])
        return batch

    def _apply_randomness_to_batch(self, batch):
        batch = self._apply_random_scale_to_batch(batch)
        batch = self._apply_random_noise_to_batch(batch)
        return batch
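
    # Note on the augmentation above: only the first two feature dimensions of the
    # real data ('rx' and 'ra') are perturbed. As an illustrative sketch with the
    # default factors, random scaling multiplies each perturbed feature by a value
    # drawn from [0.8, 1.2] (random_scale_factor=.2), and random noise adds zero-mean
    # Gaussian noise with standard deviation 0.003 (random_noise_factor=.003).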

    def _format(self, data):
        if self.recurrent:
            assert 'valids' in data.keys()
            # pad to the max sequence length
            data['actions'] = pad_tensor(data['actions'], self.max_seq_len, axis=1)
            data['observations'] = pad_tensor(data['observations'], self.max_seq_len, axis=1)
            data['valids'] = pad_tensor(data['valids'], self.max_seq_len, axis=1)
        elif self.flat_recurrent:
            # collapse the batch and time dimensions, e.g. (n_seqs, timesteps, act_dim) -> (n_seqs * timesteps, act_dim)
            act_dim = data['actions'].shape[-1]
            data['actions'] = np.reshape(data['actions'], (-1, act_dim))
            obs_dim = data['observations'].shape[-1]
            data['observations'] = np.reshape(data['observations'], (-1, obs_dim))

    def batches(self, samples_data, store=True):
        '''
        Yield training batches built from samples_data; implemented by subclasses.
        '''
        raise NotImplementedError()


class CriticDataset(Dataset):

    def __init__(self, data, shuffle=True, **kwargs):
        '''
        :param data: dict of real (expert) data passed on to Dataset
        :param shuffle: whether to shuffle the real data after each pass through it
        :param kwargs: additional keyword arguments forwarded to Dataset
        '''
        super(CriticDataset, self).__init__(data, **kwargs)
        self.shuffle = shuffle

    def _shuffle(self):
        '''
        Shuffle the real data in place.
        '''
        # optionally shuffle when wrapping around
        if self.shuffle:
            idxs = np.random.permutation(self.n_samples)
            self.data['observations'] = self.data['observations'][idxs]
            self.data['actions'] = self.data['actions'][idxs]
            if self.recurrent:
                self.data['valids'] = self.data['valids'][idxs]

    def batches(self, samples_data, store=True):
        '''
        :param samples_data: dict of generated observation-action data to train against
        :param store: whether to store the sampled data in the replay buffer
        :return: yields batches of real and generated data
        '''
        assert 'observations' in samples_data.keys()
        assert 'actions' in samples_data.keys()
        # copy in order to avoid mutating data used elsewhere
        sd = copy.deepcopy(samples_data)
        # format incoming data if necessary
        self._format(sd)
        # normalize
        self._normalize(sd)
        # n_samples will determine the total number of samples on which to train
        n_samples = len(sd['observations'])
        # if using replay memory, store info from this samples_data
        # and then sample a batch from the previously stored data
        if self.replay_memory:
            keys = ['observations', 'actions']
            keys += ['valids'] if self.recurrent else []
            if store:
                self.replay_memory.add(keys, sd)
            sd = self.replay_memory.sample(keys, n_samples)
        # compute and yield batches
        n_batches = compute_n_batches(n_samples, self.batch_size)
        for bidx in range(n_batches):
            # batch of generated data
            gidxs, _ = select_batch_idxs(bidx * self.batch_size, self.batch_size, 0, n_samples)
            gx = sd['observations'][gidxs]
            ga = sd['actions'][gidxs]
            # batch of real data
            ridxs, self.next_idx = select_batch_idxs(self.next_idx, self.batch_size, 0, self.n_samples)
            rx = self.data['observations'][ridxs]
            ra = self.data['actions'][ridxs]
            # build the batch
            batch = dict(rx=rx, ra=ra, gx=gx, ga=ga)
            # valids if using a recurrent critic
            if self.recurrent:
                batch['g_valids'] = sd['valids'][gidxs]
                batch['r_valids'] = self.data['valids'][ridxs]
            # wrap around the real data once the end of it is reached
            if self.next_idx >= self.n_samples:
                # optionally shuffle when wrapping around
                self._shuffle()
                self.next_idx = 0
            # optional random scaling and noise
            batch = self._apply_randomness_to_batch(batch)
            # yield a batch of data
            yield batch
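

# Example usage (an illustrative sketch, not part of the library): builds a small
# non-recurrent dataset from random arrays and iterates one epoch of critic batches.
# The array shapes below are assumptions chosen for illustration only.
if __name__ == '__main__':
    expert_data = dict(
        observations=np.random.randn(100, 8),
        actions=np.random.randn(100, 2),
    )
    dataset = CriticDataset(expert_data, batch_size=32)
    # generated (policy) samples to contrast against the real data
    samples_data = dict(
        observations=np.random.randn(64, 8),
        actions=np.random.randn(64, 2),
    )
    for batch in dataset.batches(samples_data, store=False):
        print({key: value.shape for key, value in batch.items()})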