Source code for algorithms.AGen.critic.utils

def batch_to_path_rewards(rewards, path_lengths):
    """Split a flat batch of rewards into consecutive per-path chunks.

    Args:
        rewards: sliceable sequence supporting ``len`` (e.g. a numpy array of
            shape (batch_size, reward_dim), as the original docstring states).
        path_lengths: iterable of path lengths; consecutive groups of that
            many rows are cut from ``rewards`` in order. It may be a one-shot
            iterable such as a generator.

    Returns:
        list with one slice of ``rewards`` per entry of ``path_lengths``, in
        the same order. If ``rewards`` is a numpy array, each slice is a view
        that shares memory with ``rewards``, not a copy.

    Raises:
        ValueError: if ``sum(path_lengths)`` does not equal ``len(rewards)``.
    """
    # Materialize once: summing a generator would otherwise exhaust it and
    # leave nothing for the slicing loop below.
    lengths = list(path_lengths)
    total = sum(lengths)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(rewards) != total:
        raise ValueError(
            "len(rewards) ({}) must equal sum(path_lengths) ({})".format(
                len(rewards), total
            )
        )

    path_rewards = []
    start = 0
    for length in lengths:
        end = start + length
        path_rewards.append(rewards[start:end])
        start = end
    return path_rewards