Source code for algorithms.dataset.utils

import numpy as np
import collections
import h5py


[docs]def pad_tensor(x, max_len, axis):
    '''
    Zero-pad an array along a single axis up to a target length.

    :param x: array to pad
    :param max_len: target length along `axis`
    :param axis: the axis along which to pad
    :return: a copy of `x` zero-padded to `max_len` along `axis`
    '''
    # Pad only at the end of the chosen axis; all other axes are left untouched.
    pad_widths = [(0, 0) for _ in range(len(x.shape))]
    pad_widths[axis] = (0, max_len - x.shape[axis])
    return np.pad(x, pad_widths, mode='constant')
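
A quick usage sketch (the array values below are invented for illustration): padding a (2, 3) array to length 5 along axis 1 appends two all-zero columns.

# Illustrative example, not part of the module.
x = np.array([[1, 2, 3],
              [4, 5, 6]])
padded = pad_tensor(x, max_len=5, axis=1)
# padded has shape (2, 5); the last two columns are zeros.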
[docs]def compute_n_batches(n_samples, batch_size):
    '''
    Compute how many batches are needed to cover a dataset.

    :param n_samples: total number of samples
    :param batch_size: number of samples per batch
    :return: number of batches required to cover all samples
    '''
    n_batches = n_samples // batch_size
    # A final, partially filled batch still counts as a batch.
    if n_samples % batch_size != 0:
        n_batches += 1
    return n_batches
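
For instance (illustrative arithmetic only): 10 samples with a batch size of 4 need 3 batches, because the last batch holds the remaining 2 samples.

# Illustrative check of the ceiling behaviour, not part of the module.
assert compute_n_batches(10, 4) == 3
assert compute_n_batches(8, 4) == 2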
[docs]class KeyValueReplayMemory(object):
    '''
    Replay memory that stores aligned lists of samples under string keys.
    '''

    def __init__(self, maxsize=None):
        # When maxsize is set, only the most recent `maxsize` samples are kept per key.
        self.maxsize = maxsize
        self.mem = collections.defaultdict(list)
[docs]    def add(self, keys, values):
        '''
        Add the values for each key to memory.

        Args:
            - keys: the keys to add, a list of hashables
            - values: dict containing, for each key in keys, an equal-length sequence of samples
        '''
        n_samples = len(values[keys[0]])
        for key in keys:
            assert len(values[key]) == n_samples, 'n_samples from each key must match'
            self.mem[key].extend(values[key])
            # Evict the oldest samples once the memory exceeds its capacity.
            if self.maxsize:
                self.mem[key] = self.mem[key][-self.maxsize:]
[docs]    def sample(self, keys, size):
        '''
        Sample a batch of `size` samples for each key and return them as a dict.

        Args:
            - keys: list of keys to sample from
            - size: number of samples to select
        '''
        sample = dict()
        n_samples = len(self.mem[keys[0]])
        # Draw indices with replacement and reuse them for every key so the
        # returned values stay aligned across keys.
        idxs = np.random.randint(0, n_samples, size)
        for key in keys:
            sample[key] = np.take(self.mem[key], idxs, axis=0)
        return sample
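
A minimal usage sketch of the memory, assuming two keys named 'obs' and 'act' (both invented for illustration):

# Illustrative usage, not part of the module.
mem = KeyValueReplayMemory(maxsize=1000)
mem.add(['obs', 'act'], {'obs': [[0.0], [1.0]], 'act': [0, 1]})
batch = mem.sample(['obs', 'act'], size=2)
# batch['obs'] and batch['act'] are index-aligned arrays of length 2.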
[docs]def load_dataset(filepath, maxsize=None):
    '''
    Load an HDF5 dataset into a dict of numpy arrays.

    :param filepath: file path of the dataset
    :param maxsize: maximum number of samples to load per key (default None, i.e. load everything)
    :return: dict mapping each key in the file to its data
    '''
    d = dict()
    with h5py.File(filepath, 'r') as f:
        for key in f.keys():
            if maxsize is None:
                d[key] = f[key][()]
            else:
                d[key] = f[key][:maxsize]
    return d
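
A minimal round-trip sketch, assuming a throwaway file name 'example.h5' and a key 'x' (both invented for illustration):

# Illustrative example, not part of the module: write a tiny HDF5 file, then load it back.
with h5py.File('example.h5', 'w') as f:
    f.create_dataset('x', data=np.arange(10))
data = load_dataset('example.h5', maxsize=5)
# data['x'] holds the first 5 elements: array([0, 1, 2, 3, 4])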