Source code for algorithms.AGen.critic.base

import torch.optim as optim
from torch.autograd import Variable
import torch

import numpy as np
import algorithms.AGen.critic.utils


class Critic(object):
    """
    Critic base class
    """

    def __init__(
            self,
            network,
            dataset,
            obs_dim,
            act_dim,
            optimizer=None,
            lr=0.0001,
            n_train_epochs=5,
            grad_norm_rescale=10000.,
            grad_norm_clip=10000.,
            summary_writer=None,
            debug_nan=False,
            verbose=0):
        self.network = network
        self.dataset = dataset
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        if optimizer is None:
            self.optimizer = optim.RMSprop(network.parameters(), lr=lr)
        else:
            self.optimizer = optimizer
        self.lr = lr
        self.n_train_epochs = n_train_epochs
        self.grad_norm_rescale = grad_norm_rescale
        self.grad_norm_clip = grad_norm_clip
        self.summary_writer = summary_writer
        self.debug_nan = debug_nan
        self.verbose = verbose
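A minimal construction sketch, assuming a reward network and dataset wrapper compatible with this class. RewardNetwork and CriticDataset are hypothetical stand-ins: the real network must be an nn.Module whose forward pass accepts (obs, acts), and the dataset must expose observation_normalizer, action_normalizer, and batches(samples_data, store=...).

    reward_net = RewardNetwork(obs_dim=66, act_dim=2)       # hypothetical nn.Module; dims are example values
    dataset = CriticDataset(expert_data, batch_size=1000)   # hypothetical dataset wrapper around expert data
    critic = Critic(
        network=reward_net,
        dataset=dataset,
        obs_dim=66,
        act_dim=2,
        lr=1e-4,
        n_train_epochs=5,
    )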
    def critique(self, itr, paths):
        """
        Compute and return rewards for the (obs, action) pairs in paths.
        The return value is a list of numpy arrays, one per path, each of the
        same length as that path's rewards.

        Args:
            itr: iteration count
            paths: list of dictionaries {'observations': obs (list), 'actions': act (list)}
        """
        # convert to a single batch and use the network to critique it
        obs = np.concatenate([d['observations'] for d in paths], axis=0)
        acts = np.concatenate([d['actions'] for d in paths], axis=0)

        # normalize
        if self.dataset.observation_normalizer:
            obs = self.dataset.observation_normalizer(obs)
        if self.dataset.action_normalizer:
            acts = self.dataset.action_normalizer(acts)

        # compute rewards
        rewards = self.network.forward(obs, acts)
        rewards = rewards.cpu().detach().numpy()
        if np.any(np.isnan(rewards)) and self.debug_nan:
            import ipdb
            ipdb.set_trace()

        # output as a list of numpy arrays, each of length equal to the rewards
        # of the corresponding trajectory
        path_lengths = [len(d['rewards']) for d in paths]
        path_rewards = algorithms.AGen.critic.utils.batch_to_path_rewards(
            rewards, path_lengths)

        self._log_critique(itr, paths, rewards)

        return path_rewards
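A hedged usage sketch for critique, assuming rllab-style path dictionaries; the shapes and call site below are illustrative only.

    paths = [
        {
            'observations': np.random.randn(100, 66),   # (T, obs_dim), example shapes
            'actions': np.random.randn(100, 2),         # (T, act_dim)
            'rewards': np.zeros(100),                   # length T, used only to recover path lengths
        },
    ]
    path_rewards = critic.critique(itr=0, paths=paths)
    # path_rewards is a list with one numpy array per path,
    # each of length len(path['rewards'])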
    def _log_critique(self, itr, paths, critic_rewards):
        """
        Log information about the critique and paths

        Args:
            itr: algorithm batch iteration
            paths: list of dictionaries containing trajectory information
            critic_rewards: critic rewards
        """
        # summaries should only be written when a summary_writer is available;
        # this is currently a placeholder implementation
        print("reward shape: ", critic_rewards.shape)
        print("Wait to complete")
    def train(self, itr, samples_data):
        """
        Train the critic using real and sampled data

        Args:
            itr: iteration count
            samples_data: dictionary containing generated data
        """
        for train_itr in range(self.n_train_epochs):
            for batch in self.dataset.batches(samples_data, store=(train_itr == 0)):
                self._train_batch(batch)
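A hedged sketch of how train is typically driven from an outer training loop. samples_data is assumed to be an rllab-style rollout dictionary (e.g. containing 'observations' and 'actions' arrays), but its exact contents depend on the dataset implementation; sampler and n_itr below are hypothetical.

    for itr in range(n_itr):
        samples_data = sampler.obtain_samples(itr)   # hypothetical rollout collection
        critic.train(itr, samples_data)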
    def _train_batch(self, batch):
        """
        Runs a single training batch

        Args:
            batch: dictionary with the values needed for training the network class member
        """
        self.rx = batch['rx']
        self.ra = batch['ra']
        self.gx = batch['gx']
        self.ga = batch['ga']
        # interpolation coefficients, currently unused (presumably intended
        # for a gradient penalty term)
        self.eps = np.random.uniform(0, 1, len(batch['rx'])).reshape(-1, 1)
        rx, ra, gx, ga, eps = self.rx, self.ra, self.gx, self.ga, self.eps

        # gradient penalty (not implemented, so it contributes nothing to the loss)
        gp_loss = 0

        # loss and train op: push critic scores up on real pairs and down on
        # generated pairs (Wasserstein-style critic loss)
        self.optimizer.zero_grad()
        self.real_loss = real_loss = -torch.mean(self.network(rx, ra))
        self.gen_loss = gen_loss = torch.mean(self.network(gx, ga))
        self.loss = loss = real_loss + gen_loss + gp_loss
        loss.backward()
        self.optimizer.step()

    def _build_summaries(
            self,
            loss,
            real_loss,
            gen_loss,
            gradients,
            clipped_gradients,
            gradient_penalty=None,
            batch_size=None):
        summaries = None
        return summaries

    def _build_input_summaries(self, rx, ra, gx, ga):
        summaries = None
        return summaries
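The unused eps coefficients and the gp_loss = 0 placeholder in _train_batch suggest a WGAN-style gradient penalty was intended. A minimal sketch of such a term, written as a standalone helper (not part of the original class) and assuming rx, ra, gx and ga are torch tensors, could look like:

    def gradient_penalty(network, rx, ra, gx, ga, eps, coeff=10.0):
        """Sketch of a WGAN-GP term: penalize critic gradient norms away from 1
        at points interpolated between real and generated (obs, act) pairs."""
        eps = torch.as_tensor(eps, dtype=rx.dtype, device=rx.device)
        x_hat = (eps * rx + (1 - eps) * gx).detach().requires_grad_(True)
        a_hat = (eps * ra + (1 - eps) * ga).detach().requires_grad_(True)
        scores = network(x_hat, a_hat)
        grads = torch.autograd.grad(
            outputs=scores.sum(), inputs=[x_hat, a_hat], create_graph=True)
        grad_norm = torch.cat(
            [g.view(g.size(0), -1) for g in grads], dim=1).norm(2, dim=1)
        return coeff * ((grad_norm - 1) ** 2).mean()

Under that assumption, the gp_loss placeholder in _train_batch would be replaced by something like gp_loss = gradient_penalty(self.network, rx, ra, gx, ga, eps).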