# Source code for algorithms.RL_Algorithm.optimizers.utils.torch
import torch
import numpy as np
tensor = torch.tensor
DoubleTensor = torch.DoubleTensor
FloatTensor = torch.FloatTensor
LongTensor = torch.LongTensor
ByteTensor = torch.ByteTensor
ones = torch.ones
zeros = torch.zeros

def to_device(device, *args):
    '''Move every argument to the given device and return them as a list.'''
    return [x.to(device) for x in args]
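
# A minimal usage sketch (hypothetical policy and value networks):
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     policy_net, value_net = to_device(device, policy_net, value_net)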

def get_flat_params_from(model):
    '''
    :param model: model whose parameters are read
    :return: a single 1-D tensor holding all of the model's parameters,
        concatenated in ``model.parameters()`` order
    '''
    params = []
    for param in model.parameters():
        params.append(param.view(-1))

    flat_params = torch.cat(params)
    return flat_params

def set_flat_params_to(model, flat_params):
    '''
    :param model: model whose parameters are overwritten
    :param flat_params: 1-D tensor of new parameter values, in the order
        produced by get_flat_params_from
    :return: no return; the values are copied into the model in place
    '''
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(
            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size
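
# A minimal round-trip sketch (assuming an arbitrary torch.nn.Module, here a
# hypothetical torch.nn.Linear): read all parameters into one flat vector and
# copy an updated vector back, as trust-region style optimizers do for a full step.
#     model = torch.nn.Linear(4, 2)
#     flat = get_flat_params_from(model)      # 1-D tensor with 4 * 2 + 2 entries
#     set_flat_params_to(model, flat + 0.1)   # write the perturbed values back in place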

def get_flat_grad_from(inputs, grad_grad=False):
    '''
    :param inputs: iterable of parameters (e.g. model.parameters())
    :param grad_grad: if True, flatten the gradient of the gradient instead
    :return: a single 1-D tensor of the concatenated gradients; parameters
        whose grad is None contribute zeros
    '''
    grads = []
    for param in inputs:
        if grad_grad:
            grads.append(param.grad.grad.view(-1))
        else:
            if param.grad is None:
                grads.append(zeros(param.view(-1).shape))
            else:
                grads.append(param.grad.view(-1))

    flat_grad = torch.cat(grads)
    return flat_grad
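
# A minimal sketch (hypothetical model, input x, and scalar loss): after
# backward(), collect all parameter gradients as one flat vector.
#     loss = model(x).sum()
#     loss.backward()
#     flat_grad = get_flat_grad_from(model.parameters())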

def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False):
    '''
    :param output: scalar tensor to differentiate
    :param inputs: iterable of tensors/parameters to differentiate against
    :param filter_input_ids: indices into inputs to skip; their slots in the
        flat gradient are filled with zeros
    :param retain_graph: forwarded to torch.autograd.grad
    :param create_graph: if True, build a graph of the derivative so the result
        is itself differentiable; implies retain_graph=True
    :return: 1-D tensor of the concatenated gradients
    '''
    if create_graph:
        retain_graph = True

    inputs = list(inputs)
    params = []
    for i, param in enumerate(inputs):
        if i not in filter_input_ids:
            params.append(param)

    grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph)
    grads = [grad.contiguous() for grad in grads]

    j = 0
    out_grads = []
    for i, param in enumerate(inputs):
        if i in filter_input_ids:
            # Filtered-out inputs contribute a zero gradient of matching size.
            out_grads.append(zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
        else:
            out_grads.append(grads[j].view(-1))
            j += 1
    grads = torch.cat(out_grads)

    # Clear any stale .grad fields on the differentiated parameters.
    for param in params:
        param.grad = None
    return grads
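
# A minimal sketch (hypothetical policy network and scalar objective): with
# create_graph=True the returned flat gradient is itself differentiable, which
# higher-order uses such as Fisher-vector products rely on.
#     params = list(policy_net.parameters())
#     flat_grad = compute_flat_grad(objective, params, create_graph=True)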