Source code for algorithms.policy.GRUNetwork
import torch.nn as nn
import numpy as np
from algorithms.policy.GRUCell import GRUCell
class GRUNetwork(nn.Module):
    """Recurrent cell followed by a linear read-out layer.

    The cell turns the input (plus an optional previous hidden state) into a
    new hidden state; the linear layer maps that hidden state to the output,
    which is optionally passed through an activation function.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, gru_layer=GRUCell,
                 output_nonlinearity=None):
        """
        :param input_dim: input feature dimension
        :param output_dim: output dimension
        :param hidden_dim: hidden layer dimension
        :param gru_layer: GRU cell class; it is instantiated as
            ``gru_layer(input_size=input_dim, hidden_size=hidden_dim)``
        :param output_nonlinearity: optional callable applied to the result of
            the linear layer; ``None`` leaves that result unchanged
        """
        super(GRUNetwork, self).__init__()
        self.gru = gru_layer(input_size=input_dim, hidden_size=hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.output_activation_fn = output_nonlinearity

    def forward(self, x, h=None):
        """Run the cell once and project the new hidden state to the output.

        :param x: input passed to the cell
            (presumably shaped ``(batch, input_dim)`` -- TODO confirm)
        :param h: previous hidden state, or ``None`` to let the cell use its
            own default initial state
        :return: tuple ``(output, new_hidden_state)``
        """
        # Invoke the cell through ``__call__`` rather than ``.forward`` so that
        # hooks registered on it are executed instead of silently skipped.
        # The state argument is omitted (not passed as ``None``) when absent,
        # because the cell implementation may not accept an explicit ``None``.
        if h is None:
            h = self.gru(x)
        else:
            h = self.gru(x, h)

        out = self.fc(h)
        if self.output_activation_fn is not None:
            out = self.output_activation_fn(out)
        return out, h