Source code for algorithms.RL_Algorithm.optimizers.utils.torch

import torch
import numpy as np

tensor = torch.tensor
DoubleTensor = torch.DoubleTensor
FloatTensor = torch.FloatTensor
LongTensor = torch.LongTensor
ByteTensor = torch.ByteTensor
ones = torch.ones
zeros = torch.zeros


def to_device(device, *args):
    '''Move every tensor/module passed in *args to the given device and return them as a list.'''
    return [x.to(device) for x in args]

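# Illustrative usage sketch (not part of the original module): moving a model and a
# batch of observations to one device in a single call. The small network and the
# `obs` tensor below are hypothetical stand-ins for an actual policy and its input.
def _example_to_device():
    import torch.nn as nn
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 2))
    obs = torch.randn(32, 4)
    net, obs = to_device(device, net, obs)   # returns [net, obs], both on `device`
    return net(obs)
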
def get_flat_params_from(model):
    '''
    :param model: model
    :return: the flattened parameters extracted from the model, as one 1-D tensor
    '''
    params = []
    for param in model.parameters():
        params.append(param.view(-1))

    flat_params = torch.cat(params)
    return flat_params

def set_flat_params_to(model, flat_params):
    '''
    :param model: model to load the parameters into
    :param flat_params: flat 1-D tensor of parameters to copy
    :return: None; the given parameters are copied into the model in place
    '''
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(
            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size

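# Illustrative sketch (not part of the original module): get_flat_params_from and
# set_flat_params_to round-trip, which is what flat-parameter optimizers rely on.
# The small network below is a hypothetical example.
def _example_flat_params_roundtrip():
    import torch.nn as nn
    net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
    flat = get_flat_params_from(net)      # one 1-D tensor holding every weight and bias
    set_flat_params_to(net, flat + 1.0)   # write a perturbed copy back into the model
    assert torch.equal(get_flat_params_from(net), flat + 1.0)
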
def get_flat_grad_from(inputs, grad_grad=False):
    '''
    Concatenate the gradients of the given parameters into one flat tensor.
    Parameters without a gradient contribute zeros; if grad_grad is True, the
    second-order gradient (param.grad.grad) is flattened instead.
    '''
    grads = []
    for param in inputs:
        if grad_grad:
            grads.append(param.grad.grad.view(-1))
        else:
            if param.grad is None:
                grads.append(zeros(param.view(-1).shape))
            else:
                grads.append(param.grad.view(-1))

    flat_grad = torch.cat(grads)
    return flat_grad

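# Illustrative sketch (not part of the original module): flatten a model's gradients
# after a backward pass. Because missing gradients are padded with zeros, the result
# always has the same length as the flat parameter vector. The linear layer and the
# random batch below are hypothetical.
def _example_flat_grad():
    import torch.nn as nn
    net = nn.Linear(3, 1)
    loss = net(torch.randn(5, 3)).sum()
    loss.backward()
    flat_grad = get_flat_grad_from(net.parameters())
    assert flat_grad.shape == get_flat_params_from(net).shape
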
def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False):
    '''
    Compute the gradient of output w.r.t. inputs and return it as one flat tensor.
    Inputs whose index is in filter_input_ids are excluded from the autograd call
    and padded with zeros in the result.
    '''
    if create_graph:
        retain_graph = True

    inputs = list(inputs)
    params = []
    for i, param in enumerate(inputs):
        if i not in filter_input_ids:
            params.append(param)

    grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph)
    grads = [grad.contiguous() for grad in grads]

    j = 0
    out_grads = []
    for i, param in enumerate(inputs):
        if i in filter_input_ids:
            out_grads.append(zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
        else:
            out_grads.append(grads[j].view(-1))
            j += 1
    grads = torch.cat(out_grads)

    for param in params:
        param.grad = None
    return grads

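# Illustrative sketch (not part of the original module): with create_graph=True the
# returned flat gradient stays in the autograd graph, so differentiating it again
# gives a Hessian-vector product, the building block of Fisher-vector products in
# natural-gradient style updates. The quadratic loss below is a hypothetical
# stand-in for a KL-divergence term.
def _example_hessian_vector_product():
    import torch.nn as nn
    net = nn.Linear(3, 2)
    params = list(net.parameters())
    loss = net(torch.randn(10, 3)).pow(2).mean()       # placeholder scalar objective
    flat_grad = compute_flat_grad(loss, params, create_graph=True)
    v = torch.randn_like(flat_grad)                     # arbitrary direction vector
    gvp = (flat_grad * v).sum()
    hvp = compute_flat_grad(gvp, params)                # flat Hessian-vector product
    return hvp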