Source code for algorithms.dataset.utils
import collections

import h5py
import numpy as np

def pad_tensor(x, max_len, axis):
    '''
    Pad x with trailing zeros along the given axis so its size along that
    axis equals max_len.

    :param x: array to pad
    :param max_len: target length along the padded axis
    :param axis: the axis along which to pad
    :return: padded copy of x
    '''
    pad_widths = [(0, 0) for _ in range(len(x.shape))]
    pad_widths[axis] = (0, max_len - x.shape[axis])
    return np.pad(x, pad_widths, mode='constant')
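
# Usage sketch (illustrative, not part of the original module): pad a
# (3, 5) array to width 8 along axis 1; the added entries are zero-filled.
#
# >>> x = np.ones((3, 5))
# >>> pad_tensor(x, max_len=8, axis=1).shape
# (3, 8)
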
def compute_n_batches(n_samples, batch_size):
    '''
    :param n_samples: total number of samples
    :param batch_size: number of samples per batch
    :return: number of batches needed to cover all samples
    '''
    n_batches = n_samples // batch_size
    if n_samples % batch_size != 0:
        n_batches += 1
    return n_batches
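
# Usage sketch (assumed values, not from the original source): a trailing
# partial batch counts as one extra batch.
#
# >>> compute_n_batches(100, 32)
# 4
# >>> compute_n_batches(96, 32)
# 3
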
class KeyValueReplayMemory(object):
    '''
    Replay memory that stores parallel lists of values keyed by name.
    '''

    def __init__(self, maxsize=None):
        '''
        :param maxsize: maximum number of samples retained per key;
            None means unbounded
        '''
        self.maxsize = maxsize
        self.mem = collections.defaultdict(list)

    def add(self, keys, values):
        '''
        Adds the samples stored under each key in values to memory.

        Args:
            - keys: the keys to add, list of hashable
            - values: dict containing each key in keys
        '''
        n_samples = len(values[keys[0]])
        for key in keys:
            assert len(values[key]) == n_samples, 'n_samples from each key must match'
            self.mem[key].extend(values[key])
            if self.maxsize:
                self.mem[key] = self.mem[key][-self.maxsize:]

    def sample(self, keys, size):
        '''
        Sample a batch of the given size for each key and return it as a dict.
        Indices are drawn with replacement and shared across keys.

        Args:
            - keys: list of keys
            - size: number of samples to select
        '''
        sample = dict()
        n_samples = len(self.mem[keys[0]])
        idxs = np.random.randint(0, n_samples, size)
        for key in keys:
            sample[key] = np.take(self.mem[key], idxs, axis=0)
        return sample
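
# Usage sketch (illustrative, with made-up data): store two parallel fields
# and draw a random batch; the same row indices are used for every key.
#
# >>> memory = KeyValueReplayMemory(maxsize=1000)
# >>> memory.add(['obs', 'act'], {'obs': [[0., 1.], [2., 3.]], 'act': [0, 1]})
# >>> batch = memory.sample(['obs', 'act'], size=4)
# >>> batch['obs'].shape, batch['act'].shape
# ((4, 2), (4,))
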
def load_dataset(filepath, maxsize=None):
    '''
    :param filepath: path of the HDF5 dataset file
    :param maxsize: maximum number of samples to load per key; None loads all
    :return: dict mapping each key in the file to a numpy array of its data
    '''
    d = dict()
    with h5py.File(filepath, 'r') as f:
        for key in f.keys():
            if maxsize is None:
                d[key] = f[key][()]
            else:
                d[key] = f[key][:maxsize]
    return d
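
# Usage sketch (the file name and keys below are assumptions, not from the
# original source): load at most 10000 samples per key from an HDF5 file.
#
# >>> data = load_dataset('dataset.h5', maxsize=10000)
# >>> sorted(data.keys())  # e.g. ['actions', 'observations']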