Source code for algorithms.policy.GRUCell
import torch
import torch.nn as nn
[docs]class GRUCell(nn.Module):
def __init__(self, input_size, hidden_size):
"""
A gated recurrent unit implements the following update mechanism:
Reset gate: r(t) = f_r(x(t) @ W_xr + h(t-1) @ W_hr + b_r)
Update gate: u(t) = f_u(x(t) @ W_xu + h(t-1) @ W_hu + b_u)
Cell gate: c(t) = f_c(x(t) @ W_xc + r(t) * (h(t-1) @ W_hc) + b_c)
New hidden state: h(t) = (1 - u(t)) * h(t-1) + u_t * c(t)
Note that the reset, update, and cell vectors must have the same dimension as the hidden state
"""
super(GRUCell, self).__init__()
# Weights for the reset gate
self.W_xr = nn.Parameter(torch.Tensor(input_size, hidden_size))
self.W_hr = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
self.b_r = nn.Parameter(torch.Tensor(hidden_size, 1))
# Weights for the update gate
self.W_xu = nn.Parameter(torch.Tensor(input_size, hidden_size))
self.W_hu = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
self.b_u = nn.Parameter(torch.Tensor(hidden_size, 1))
# Weights for the cell gate
self.W_xc = nn.Parameter(torch.Tensor(input_size, hidden_size))
self.W_hc = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
self.b_c = nn.Parameter(torch.Tensor(hidden_size, 1))
self.input_size = input_size
self.hidden_size = hidden_size
self.gate_nonlinearity = nn.Sigmoid()
self.nonlinearity = nn.Tanh()
[docs] def forward(self, x, h=None):
x = x.float()
if h is None:
h = x.new_zeros(x.size(0), self.hidden_size, requires_grad=False)
W_x_ruc = torch.cat([self.W_xr, self.W_xu, self.W_xc], dim=1)
W_h_ruc = torch.cat([self.W_hr, self.W_hu, self.W_hc], dim=1)
b_ruc = torch.cat([self.b_r, self.b_u, self.b_c], dim=1)
if torch.cuda.is_available():
xb_ruc = torch.matmul(x, W_x_ruc.cuda()) + torch.reshape(b_ruc.cuda(), (1, -1))
h_ruc = torch.matmul(h, W_h_ruc.cuda())
else:
xb_ruc = torch.matmul(x, W_x_ruc) + torch.reshape(b_ruc, (1, -1))
h_ruc = torch.matmul(h, W_h_ruc)
xb_r, xb_u, xb_c = torch.split(dim=1, split_size_or_sections=int(xb_ruc.shape[1]/3), tensor=xb_ruc)
h_r, h_u, h_c = torch.split(dim=1, split_size_or_sections=int(h_ruc.shape[1]/3), tensor=h_ruc)
r = self.gate_nonlinearity(xb_r + h_r)
u = self.gate_nonlinearity(xb_u + h_u)
c = self.nonlinearity(xb_c + r * h_c)
h = (1 - u) * h + u * c
return h