algorithms.AGen.critic package

Submodules

algorithms.AGen.critic.base module

class algorithms.AGen.critic.base.Critic(network, dataset, obs_dim, act_dim, optimizer=None, lr=0.0001, n_train_epochs=5, grad_norm_rescale=10000.0, grad_norm_clip=10000.0, summary_writer=None, debug_nan=False, verbose=0)

Bases: object

Critic base class

critique(itr, paths)

Compute and return rewards based on the (observation, action) pairs in paths, where rewards is a list of NumPy arrays, each the same length as the corresponding path's rewards.

Args:
  • itr: iteration count

  • paths: list of dictionaries {'observations': obs (list), 'actions': act (list)}
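The input and output contract of critique can be illustrated with a small NumPy sketch. The function critique_sketch and the scorer toy_score are hypothetical stand-ins, not the library's code; the sketch only assumes that paths follows the dictionary layout described above and that the critic yields one reward per time step.

```python
import numpy as np


def critique_sketch(score_fn, paths):
    """Hypothetical stand-in for Critic.critique: one reward array per path.

    score_fn: hypothetical callable mapping (obs_batch, act_batch) -> (T,) rewards.
    paths: list of dicts with 'observations' and 'actions' lists.
    """
    rewards = []
    for path in paths:
        obs = np.asarray(path['observations'], dtype=np.float32)  # (T, obs_dim)
        act = np.asarray(path['actions'], dtype=np.float32)       # (T, act_dim)
        r = np.asarray(score_fn(obs, act)).reshape(-1)            # (T,)
        if len(r) != len(obs):
            raise ValueError("expected one reward per time step")
        rewards.append(r)
    return rewards


# Toy scorer for illustration only: negative squared norm of the action.
def toy_score(obs, act):
    return -np.sum(act ** 2, axis=1)


paths = [
    {'observations': [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], 'actions': [[0.5], [0.0], [-1.0]]},
    {'observations': [[3.0, 0.0]], 'actions': [[2.0]]},
]
rewards = critique_sketch(toy_score, paths)
```

Each returned array lines up with its path, so a three-step path gets three rewards and a one-step path gets one.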

train(itr, samples_data)

Train the critic using real and sampled data

Args:
  • itr: iteration count

  • samples_data: dictionary containing generated data
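The documentation does not state which loss train optimizes. The following PyTorch sketch assumes a Wasserstein-style objective (score real pairs higher than generated pairs) purely for illustration; ToyCritic and train_critic_sketch are hypothetical names. It shows how the constructor arguments n_train_epochs and grad_norm_clip could plausibly map onto a training loop via torch.nn.utils.clip_grad_norm_; grad_norm_rescale is not modeled.

```python
import torch
import torch.nn as nn


class ToyCritic(nn.Module):
    """Hypothetical stand-in for the critic network: one linear layer on [obs, act]."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.linear = nn.Linear(obs_dim + act_dim, 1)

    def forward(self, obs, act):
        # One scalar score per (observation, action) pair.
        return self.linear(torch.cat([obs, act], dim=-1)).squeeze(-1)


def train_critic_sketch(critic, optimizer, real, generated,
                        n_train_epochs=5, grad_norm_clip=10000.0):
    """Hypothetical critic update with a Wasserstein-style loss (an assumption)."""
    real_obs, real_act = real
    gen_obs, gen_act = generated
    losses = []
    for _ in range(n_train_epochs):
        optimizer.zero_grad()
        # Push scores of real pairs up and scores of generated pairs down.
        loss = critic(gen_obs, gen_act).mean() - critic(real_obs, real_act).mean()
        loss.backward()
        # Cap the global gradient norm, mirroring the grad_norm_clip argument.
        nn.utils.clip_grad_norm_(critic.parameters(), grad_norm_clip)
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
critic = ToyCritic(obs_dim=4, act_dim=2)
optimizer = torch.optim.Adam(critic.parameters(), lr=1e-4)  # matches the default lr
real = (torch.randn(32, 4) + 1.0, torch.randn(32, 2) + 1.0)
generated = (torch.randn(32, 4), torch.randn(32, 2))
losses = train_critic_sketch(critic, optimizer, real, generated)
```

Because real and generated samples are drawn from shifted distributions here, the loss should fall over the five epochs; with real data, the choice of loss depends on the actual critic subclass.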

algorithms.AGen.critic.model module

class algorithms.AGen.critic.model.Block(input_size, hidden_layer_dims, activation_fn, drop_out_fn)

Bases: torch.nn.modules.module.Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
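The note above can be checked directly. BlockSketch is a hypothetical reconstruction of Block (a Linear layer, activation, and dropout for each entry of hidden_layer_dims), inferred only from the constructor arguments; the real layer layout may differ. The point of the sketch is the hook behavior: calling the instance runs a registered forward hook, while calling forward() directly skips it.

```python
import torch
import torch.nn as nn


class BlockSketch(nn.Module):
    """Hypothetical Block: Linear -> activation -> dropout for each hidden size."""

    def __init__(self, input_size, hidden_layer_dims, activation_fn, drop_out_fn):
        super().__init__()
        layers, in_dim = [], input_size
        for out_dim in hidden_layer_dims:
            layers += [nn.Linear(in_dim, out_dim), activation_fn, drop_out_fn]
            in_dim = out_dim
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


torch.manual_seed(0)
block = BlockSketch(3, [8, 4], nn.ReLU(), nn.Dropout(p=0.0)).eval()
seen_shapes = []
# Forward hooks receive (module, inputs, output) after each forward pass.
block.register_forward_hook(
    lambda module, inputs, output: seen_shapes.append(tuple(output.shape)))

x = torch.randn(5, 3)
y_call = block(x)            # __call__ runs forward() and then the hook
y_direct = block.forward(x)  # same result, but the hook is skipped
```

After both calls the hook has fired exactly once, even though the two outputs are identical.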

class algorithms.AGen.critic.model.ObservationActionMLP(hidden_layer_dims, obs_size, act_size, output_dim=1, obs_hidden_layer_dims=[], act_hidden_layer_dims=[], activation_fn=ReLU(), dropout_keep_prob=1.0, l2_reg=0.0, return_features=False)

Bases: torch.nn.modules.module.Module

forward(obs, act)

Parameters
  • obs – batch of observations

  • act – batch of actions

Returns

Rewards for the batch of (observation, action) pairs.
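The signature suggests separate observation and action branches that are merged before shared hidden layers. The sketch below is a hypothetical reading of that signature only (ObsActMLPSketch and mlp are illustrative names): each branch is an optional MLP, the branch outputs are concatenated, and a final linear layer produces output_dim values per pair. Dropout, l2_reg, and return_features are deliberately left out, and tuples are used as defaults to avoid shared mutable list defaults.

```python
import torch
import torch.nn as nn


def mlp(in_dim, dims, activation_fn):
    """Stack Linear + activation for each size in dims; empty dims is the identity."""
    layers = []
    for d in dims:
        layers += [nn.Linear(in_dim, d), activation_fn]
        in_dim = d
    return nn.Sequential(*layers), in_dim


class ObsActMLPSketch(nn.Module):
    """Hypothetical observation/action MLP; not the library's implementation."""

    def __init__(self, hidden_layer_dims, obs_size, act_size, output_dim=1,
                 obs_hidden_layer_dims=(), act_hidden_layer_dims=(), activation_fn=None):
        super().__init__()
        activation_fn = activation_fn if activation_fn is not None else nn.ReLU()
        self.obs_net, obs_out = mlp(obs_size, obs_hidden_layer_dims, activation_fn)
        self.act_net, act_out = mlp(act_size, act_hidden_layer_dims, activation_fn)
        self.joint_net, joint_out = mlp(obs_out + act_out, hidden_layer_dims, activation_fn)
        self.head = nn.Linear(joint_out, output_dim)

    def forward(self, obs, act):
        # Encode each input separately, then score the concatenated features.
        h = torch.cat([self.obs_net(obs), self.act_net(act)], dim=-1)
        return self.head(self.joint_net(h))


torch.manual_seed(0)
model = ObsActMLPSketch([16], obs_size=4, act_size=2, obs_hidden_layer_dims=[8])
out = model(torch.randn(10, 4), torch.randn(10, 2))  # shape (10, 1)
```

With output_dim=1 this yields one scalar per pair, which matches the "rewards for the batch" return value described above.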

algorithms.AGen.critic.utils module

algorithms.AGen.critic.utils.batch_to_path_rewards(rewards, path_lengths)
Args:
  • rewards: NumPy array of shape (batch_size, reward_dim)

  • path_lengths: list of path lengths; the rows of rewards are split, in order, into consecutive groups of these lengths, one group per path
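Splitting batched rows back into per-path groups can be sketched with numpy.split on the running totals of path_lengths. This is a hypothetical equivalent (batch_to_path_rewards_sketch is an illustrative name); whether the real function also drops the reward_dim axis is not documented, so this version keeps it.

```python
import numpy as np


def batch_to_path_rewards_sketch(rewards, path_lengths):
    """Split (batch_size, reward_dim) rows into consecutive per-path groups."""
    rewards = np.asarray(rewards)
    if sum(path_lengths) != len(rewards):
        raise ValueError("path_lengths must sum to the number of reward rows")
    # Split points are the running totals of every length except the last.
    boundaries = np.cumsum(path_lengths)[:-1]
    return np.split(rewards, boundaries, axis=0)


rewards = np.arange(6, dtype=np.float32).reshape(6, 1)
per_path = batch_to_path_rewards_sketch(rewards, [2, 3, 1])
# per_path[1][:, 0] holds rows 2, 3 and 4 of the batch
```

The length check catches a mismatch early; without it, numpy.split would silently put any leftover rows into the last group.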

Module contents