import torch
import torch.nn as nn
import numpy as np
from algorithms.policy.GRUNetwork import GRUNetwork
from algorithms.distribution.recurrent_diagonal_gaussian import RecurrentDiagonalGaussian
from algorithms.RL_Algorithm.optimizers.utils.math import normal_log_density
from algorithms.policy.GRUCell import GRUCell


class GaussianGRUPolicy(nn.Module):
    def __init__(self,
                 env_spec,
                 hidden_dim=32,
                 feature_network=None,
                 state_include_action=True,
                 gru_layer=GRUCell,
                 output_nonlinearity=None,
                 mode: int = 0,
                 log_std=0,
                 cuda_enable=True):
        '''
        :param env_spec: environment specification providing the observation and action spaces
        :param hidden_dim: hidden layer dimension
        :param feature_network: optional feature extraction network applied to the input before the GRU
        :param state_include_action: whether the previous action is concatenated to the observation as policy input
        :param gru_layer: GRU cell class used by the mean network
        :param output_nonlinearity: output activation function
        :param mode: 0 for training, 1 for testing
        :param log_std: initial log standard deviation of the action distribution
        :param cuda_enable: whether to run on GPU (takes effect only if CUDA is available)
        '''
        super().__init__()
        obs_dim = env_spec.observation_space.flat_dim
        action_dim = env_spec.action_space.flat_dim
        if state_include_action:
            input_dim = obs_dim + action_dim
        else:
            input_dim = obs_dim
        # if feature_network is None:
        feature_dim = input_dim
        self._env_spec = env_spec
        self.mean_network = GRUNetwork(
            input_dim=feature_dim,
            output_dim=action_dim,
            hidden_dim=hidden_dim,
            gru_layer=gru_layer,
            output_nonlinearity=output_nonlinearity
        )
        self.feature_network = feature_network
        # self.fc_std = nn.Linear(hidden_dim, action_dim).double()
        # self.fc_std.weight.data.fill_(np.log(init_std))
        self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)
        # TODO: check if need to initialize bias
        self.input_dim = input_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.prev_actions = None
        self.prev_hiddens = None
        self.dist = RecurrentDiagonalGaussian(action_dim)
        self.state_include_action = state_include_action
        self.mode = mode
        self.cuda_enable = cuda_enable and torch.cuda.is_available()
        self.is_disc_action = False
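    # Construction sketch (illustrative, not from the original source): ``env_spec`` is
    # assumed to expose ``observation_space.flat_dim`` and ``action_space.flat_dim`` as
    # used above. With ``state_include_action=True`` the GRU input at step t is the
    # concatenation [obs_t, a_{t-1}].
    #
    #   policy = GaussianGRUPolicy(env_spec, hidden_dim=64, state_include_action=True,
    #                              mode=0, log_std=-1.0, cuda_enable=False)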

    def forward(self, x, h=None):
        '''
        :param x: input feature
        :param h: hidden state
        :return: action mean, action log std, and the hidden state for the next step
        '''
        action_mean, h = self.mean_network.forward(x, h)
        # action_log_std = self.fc_std(h)
        action_log_std = self.action_log_std.expand_as(action_mean)
        return action_mean, action_log_std, h
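    # Note on the distribution head: the mean is state-dependent (GRU output), while the
    # log std is the single learned parameter ``action_log_std`` broadcast to the mean's
    # shape, i.e. a state-independent std (unlike the commented-out per-state ``fc_std``).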

    def load_param(self, param_path: str):
        '''
        :param param_path: path of the saved parameter file
        :return: no return; loads the parameters into the current model
        '''
        self.load_state_dict(torch.load(param_path))
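    # Usage sketch (assumed workflow; the save side is not part of this class, and the
    # file name is hypothetical):
    #   torch.save(policy.state_dict(), "policy_params.pt")
    #   policy.load_param("policy_params.pt")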

    def dist_info_sym(self, obs_var, state_info_vars):
        '''
        :param obs_var: batch of observation trajectories, reshaped internally to (n_batches, n_steps, -1)
        :param state_info_vars: dict of extra per-step inputs; must contain "prev_action" when state_include_action is True
        :return: dict with the per-step action means and log stds
        '''
        n_batches = np.array(obs_var).shape[0]
        n_steps = np.array(obs_var).shape[1]
        obs_var = torch.tensor(obs_var)
        obs_var = torch.reshape(obs_var, (n_batches, n_steps, -1))
        if self.state_include_action:
            prev_action_var = state_info_vars["prev_action"]
            all_input_var = torch.cat((obs_var, prev_action_var), dim=2)
        else:
            all_input_var = obs_var
        if self.feature_network is None:
            means, log_stds, _ = self.forward(all_input_var)
        else:
            flat_input_var = torch.reshape(all_input_var, (-1, self.input_dim))
            feature_batch = self.feature_network(flat_input_var)
            means, log_stds, _ = self.forward(feature_batch)
        return dict(mean=means, log_std=log_stds)

    def get_kl(self, x, actions, h=None):
        '''
        :param x: input feature
        :param actions: actions
        :param h: hidden state
        :return: KL divergence between the current policy and a detached (old) copy of it, summed over action dimensions
        '''
        if self.state_include_action:
            prev_act = np.concatenate([np.zeros((actions.shape[0], 1, actions.shape[2])), actions], axis=1)[:, :-1, :]
            x = np.concatenate([x, prev_act], axis=-1)
        mean1, log_std1, _ = self.forward(x, h)
        std1 = torch.exp(log_std1)
        mean0 = mean1.detach()
        log_std0 = log_std1.detach()
        std0 = std1.detach()
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)
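    # Per-dimension KL between two diagonal Gaussians, matching the expression above with
    # the detached (old) parameters (mu0, sigma0) and the current ones (mu1, sigma1):
    #   KL(N(mu0, sigma0) || N(mu1, sigma1))
    #       = log(sigma1 / sigma0) + (sigma0^2 + (mu0 - mu1)^2) / (2 * sigma1^2) - 1/2
    # Because the old parameters are detached copies of the current ones, the value itself
    # is zero; what matters is its gradient/Hessian with respect to the current parameters,
    # which is what TRPO-style conjugate-gradient solvers typically use this quantity for.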

    def get_log_prob(self, x, actions):
        '''
        :param x: input obs feature
        :param actions: input actions
        :return: log likelihood of the actions given the distribution output by the network
        '''
        if self.state_include_action:
            prev_act = np.concatenate([np.zeros((actions.shape[0], 1, actions.shape[2])), actions], axis=1)[:, :-1, :]
            x = np.concatenate([x, prev_act], axis=-1)
        x = x.reshape((-1, self.input_dim))
        actions = actions.reshape((-1, self.action_dim))
        if torch.cuda.is_available():
            x = torch.tensor(x).cuda()
            actions = torch.tensor(actions).float().cuda()
        else:
            x = torch.tensor(x)
            actions = torch.tensor(actions).float()
        action_mean, action_log_std, _ = self.forward(x)
        action_std = torch.exp(action_log_std)
        return normal_log_density(actions, action_mean, action_log_std, action_std)
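    # ``normal_log_density`` is expected to evaluate the diagonal Gaussian log-likelihood
    #   log p(a) = sum_i [ -(a_i - mu_i)^2 / (2 * sigma_i^2) - log(sigma_i) - 0.5 * log(2 * pi) ]
    # summed over action dimensions for each flattened timestep (its exact signature lives
    # in algorithms.RL_Algorithm.optimizers.utils.math).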

    def get_fim(self, x, actions):
        '''
        :param x: input observation feature
        :param actions: input actions
        :return: diagonal of the inverse action covariance (Fisher information), the action means, and the position of the log-std parameter
        '''
        if self.state_include_action:
            prev_act = np.concatenate([np.zeros((actions.shape[0], 1, actions.shape[2])), actions], axis=1)[:, :-1, :]
            x = np.concatenate([x, prev_act], axis=-1)
        if torch.cuda.is_available():
            x = torch.tensor(x).reshape((-1, self.input_dim)).cuda()
        else:
            x = torch.tensor(x).reshape((-1, self.input_dim))
        mean, action_log_std, _ = self.forward(x)
        cov_inv = self.action_log_std.exp().pow(-2).squeeze(0).repeat(x.size(0))
        param_count = 0
        std_index = 0
        std_id = 0
        for param_id, (name, param) in enumerate(self.named_parameters()):
            if name == "action_log_std":
                std_id = param_id
                std_index = param_count
            param_count += param.view(-1).shape[0]
        return cov_inv.detach(), mean, {'std_id': std_id, 'std_index': std_index}
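    # For a diagonal Gaussian policy with a state-independent std, the Fisher information of
    # the action distribution with respect to the mean is diag(1 / sigma^2); ``cov_inv`` above
    # is exactly exp(action_log_std)^-2 tiled once per sample, and std_id / std_index locate
    # the ``action_log_std`` parameter (by parameter order and by offset in the flattened
    # parameter vector) so an optimizer can treat it separately from the mean-network weights.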

    @property
    def vectorized(self):
        return True

    def reset(self, dones=None):
        '''
        :param dones: per-agent indicators of whether the episode has finished
        :return: no return; resets the stored previous actions and hidden states for the agents that are done
        '''
        if dones is None:
            dones = [True]
        dones = np.asarray(dones)
        if self.prev_actions is None or len(dones) != len(self.prev_actions):
            self.prev_actions = np.zeros((len(dones), self.action_dim))
            self.prev_hiddens = np.zeros((len(dones), self.hidden_dim))
        self.prev_actions[dones] = 0.
        if all(dones):
            self.prev_hiddens = None
        elif any(dones):
            self.prev_hiddens[dones] = 0.

    def get_action(self, observation):
        '''
        :param observation: input observation
        :return: the action for the given observation together with its agent info
        '''
        res = self.get_actions([observation])
        actions, agent_infos = res[0], res[1]
        return actions[0], {k: v[0] for k, v in agent_infos.items()}

    def get_actions(self, observations):
        '''
        :param observations: a batch of observations
        :return: the corresponding batch of sampled actions and their agent info
        '''
        # mode: 0 stands for training, 1 for testing
        flat_obs = self.observation_space.flatten_n(observations)
        # self.prev_actions.shape = np.zeros([1,2], dtype=float)
        if self.state_include_action:
            assert self.prev_actions is not None
            all_input = np.concatenate([
                flat_obs,
                self.prev_actions
            ], axis=-1)
        else:
            all_input = flat_obs
        all_input = torch.tensor(all_input)
        # prev_hiddens may still be a NumPy array right after reset(); convert it before use
        if self.prev_hiddens is not None and not torch.is_tensor(self.prev_hiddens):
            self.prev_hiddens = torch.tensor(self.prev_hiddens).float()
        if self.cuda_enable:
            all_input = all_input.cuda()
            if self.prev_hiddens is not None:
                self.prev_hiddens = self.prev_hiddens.cuda()
        means, log_stds, hidden_vec = self.forward(all_input, self.prev_hiddens)
        rnd = np.random.normal(size=means.shape)
        means = means.cpu().detach().numpy()
        log_stds = log_stds.cpu().detach().numpy()
        actions = rnd * np.exp(log_stds) + means
        prev_actions = self.prev_actions
        self.prev_actions = self.action_space.flatten_n(actions)
        self.prev_hiddens = hidden_vec
        agent_info = dict(mean=means, log_std=log_stds)
        if self.state_include_action:
            agent_info["prev_action"] = np.copy(prev_actions)
        if self.mode == 1:
            return actions, agent_info, hidden_vec.cpu().detach().numpy()
        elif self.mode == 0:
            return actions, agent_info
        else:
            raise NotImplementedError
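    # Sampling above follows a = mean + exp(log_std) * eps with eps ~ N(0, I), computed in
    # NumPy on the detached network outputs; prev_actions / prev_hiddens are then cached so
    # the next call can feed [obs_t, a_{t-1}] and the recurrent state back into the GRU.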

    def get_actions_with_prev(self, observations, prev_actions, prev_hiddens):
        '''
        :param observations: input batch of observations
        :param prev_actions: previous batch of actions
        :param prev_hiddens: previous hidden states
        :return: actions for the current batch of observations
        '''
        # for rolling back to a given hidden state and previous action before predicting
        if prev_actions is None or prev_hiddens is None:
            return self.get_actions(observations)
        flat_obs = self.observation_space.flatten_n(observations)
        # print(flat_obs.shape, prev_actions.shape)
        if self.state_include_action:
            h, w = flat_obs.shape
            all_input = np.concatenate([
                flat_obs,
                np.reshape(prev_actions, [h, self.action_dim])
                # np.zeros([h,2], dtype=float)
            ], axis=-1)
        else:
            all_input = flat_obs
        if torch.cuda.is_available():
            all_input = torch.tensor(all_input).cuda()
            if not torch.is_tensor(prev_hiddens):
                prev_hiddens = torch.tensor(prev_hiddens).float().cuda()
        else:
            all_input = torch.tensor(all_input)
            if not torch.is_tensor(prev_hiddens):
                prev_hiddens = torch.tensor(prev_hiddens).float()
        means, log_stds, hidden_vec = self.forward(all_input, prev_hiddens)
        means = means.cpu().detach().numpy()
        log_stds = log_stds.cpu().detach().numpy()
        rnd = np.random.normal(size=means.shape)
        actions = rnd * np.exp(log_stds) + means
        self.prev_actions = self.action_space.flatten_n(actions)
        self.prev_hiddens = hidden_vec
        agent_info = dict(mean=means, log_std=log_stds)
        if self.state_include_action:
            agent_info["prev_action"] = np.copy(prev_actions)
        if self.mode == 1:
            return actions, agent_info, hidden_vec.cpu().detach().numpy()
        elif self.mode == 0:
            return actions, agent_info
        else:
            raise NotImplementedError

    @property
    def recurrent(self):
        return True

    @property
    def distribution(self):
        return self.dist

    @property
    def state_info_specs(self):
        if self.state_include_action:
            return [
                ("prev_action", (self.action_dim,)),
            ]
        else:
            return []

    @property
    def observation_space(self):
        return self._env_spec.observation_space

    @property
    def action_space(self):
        return self._env_spec.action_space
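

if __name__ == "__main__":
    # Minimal smoke-test sketch, illustrative only. The _Space/_Spec stand-ins below are
    # assumptions: they provide just the attributes this module actually touches
    # (``flat_dim`` and ``flatten_n``); a real run would pass the environment's own spec.
    # It also assumes GRUNetwork builds its parameters in torch's default dtype, so the
    # default dtype is set to float64 to match the float64 observation arrays used here.
    torch.set_default_dtype(torch.float64)

    class _Space:
        def __init__(self, flat_dim):
            self.flat_dim = flat_dim

        def flatten_n(self, xs):
            # Observations / actions are already flat vectors in this sketch.
            return np.asarray(xs, dtype=np.float64).reshape(len(xs), self.flat_dim)

    class _Spec:
        def __init__(self, obs_dim, act_dim):
            self.observation_space = _Space(obs_dim)
            self.action_space = _Space(act_dim)

    policy = GaussianGRUPolicy(_Spec(obs_dim=4, act_dim=2), hidden_dim=8, cuda_enable=False)
    policy.reset()
    action, info = policy.get_action(np.zeros(4))
    print(action.shape, info["mean"].shape)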