Source code for algorithms.RL_Algorithm.GAIL.gail

import os
import time
import random

import julia
import torch

import algorithms.RL_Algorithm.utils
from algorithms.utils import save_params, extract_normalizing_env, load_params
from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler
from algorithms.RL_Algorithm.optimizers.trpo import trpo_step
from algorithms.RL_Algorithm.optimizers.utils import *
from algorithms import utils
from envs.utils import load_data
from preprocessing.clean_holo import clean_data, csv2txt, create_lane
from src.trajdata import convert_raw_ngsim_to_trajdatas
from preprocessing.extract_feature import extract_ngsim_features


class GAIL(object):

    def __init__(self, env, policy, baseline, critic=None, recognition=None,
                 step_size=0.01,
                 reward_handler=algorithms.RL_Algorithm.utils.RewardHandler(),
                 saver=None, saver_filepath=None, validator=None,
                 snapshot_env=True, scope=None, n_itr=500, start_itr=0,
                 batch_size=5000, max_path_length=500, discount=0.99,
                 gae_lambda=1, plot=False, pause_for_plot=False,
                 center_adv=True, positive_adv=False, store_paths=False,
                 whole_paths=True, fixed_horizon=False, sampler_cls=None,
                 sampler_args=None, force_batch_sampler=False, max_kl=None,
                 damping=None, l2_reg=None, policy_filepath=None,
                 critic_filepath=None, env_filepath=None, cuda_enable=True,
                 args=None):
        """
        :param env: Environment.
        :param policy: Policy.
        :type policy: Policy
        :param baseline: Baseline.
        :param scope: Scope for identifying the algorithm. Must be specified if
            running multiple algorithms simultaneously, each using different
            environments and policies.
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount factor.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have
            mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are
            always positive. When used in conjunction with center_adv the
            advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = VectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler_cls = sampler_cls
        self.sampler = sampler_cls(self, **sampler_args)
        self.sampler_args = sampler_args
        self.step_size = step_size
        self.critic = critic
        self.recognition = recognition
        self.reward_handler = reward_handler
        self.saver = saver
        self.saver_filepath = saver_filepath
        self.validator = validator
        self.snapshot_env = snapshot_env
        self.max_kl = max_kl
        self.damping = damping
        self.l2_reg = l2_reg
        self.critic_filepath = critic_filepath
        self.policy_filepath = policy_filepath
        self.env_filepath = env_filepath
        self.cuda_enable = cuda_enable and torch.cuda.is_available()
        if self.cuda_enable:
            self.policy = self.policy.cuda()
            self.baseline.set_cuda()
            if self.critic:
                self.critic.network = self.critic.network.cuda()
        self.file_set = set()
        self.j = julia.Julia()
        self.j.using("NGSIM")
        self.args = args
        # iteration offset of the previously trained run that is resumed in train()
        self.prev_itr = 650

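    # Construction sketch (illustrative, not part of the original module): this
    # mirrors how init_env() wires things up. utils.build_ngsim_env,
    # utils.build_critic and load_data are used later in this file; build_policy,
    # build_baseline and the concrete argument values are hypothetical.
    #
    #   env, trajinfos, act_low, act_high = utils.build_ngsim_env(args)
    #   data, veh_2_index = load_data(
    #       args.expert_filepath, act_low=act_low, act_high=act_high,
    #       min_length=args.env_H + args.env_primesteps,
    #       clip_std_multiple=args.normalize_clip_std_multiple,
    #       ngsim_filename=args.ngsim_filename)
    #   critic = utils.build_critic(args, data, env)
    #   policy = build_policy(args, env)        # hypothetical builder
    #   baseline = build_baseline(args, env)    # hypothetical builder
    #   algo = GAIL(env=env, policy=policy, baseline=baseline, critic=critic,
    #               saver_filepath="./data/experiments/NGSIM-gail/imitate/model",
    #               max_kl=0.01, damping=0.1, args=args)
    #   algo.train()
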
    def start_worker(self):
        self.sampler.start_worker()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        """
        Augment path rewards with critic and recognition model rewards.

        Args:
            itr: iteration counter
            paths: list of dictionaries, each containing info for a single
                trajectory, with keys 'observations', 'actions', 'agent_infos',
                'env_infos', 'rewards'
        """
        # compute critic and recognition rewards and combine them with the path rewards
        critic_rewards = self.critic.critique(itr, paths) if self.critic else None
        recognition_rewards = self.recognition.recognize(itr, paths) if self.recognition else None
        paths = self.reward_handler.merge(paths, critic_rewards, recognition_rewards)
        return self.sampler.process_samples(itr, paths)

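    # Reward-merging sketch (an assumption about RewardHandler.merge, whose real
    # implementation lives in algorithms.RL_Algorithm.utils): the critic reward
    # acts as the GAIL surrogate reward and is typically combined per timestep
    # with the environment and recognition rewards as a weighted sum, roughly:
    #
    #   for path, critic_r, recog_r in zip(paths, critic_rewards, recognition_rewards):
    #       path['rewards'] = (env_w * path['rewards']
    #                          + critic_w * critic_r
    #                          + recog_w * recog_r)
    #
    # with the weights env_w, critic_w, recog_w held by the handler.
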
    def _save(self, itr):
        """
        Save PyTorch checkpoints of the policy and critic, together with the
        normalization statistics of the (possibly wrapped) environment.
        """
        # checkpoints are written every 50 iterations; ids continue from the
        # iteration count of the run being resumed (self.prev_itr)
        save_id = itr + 1 + self.prev_itr
        if (itr + 1) % 50 == 0:
            # collect params (or stuff to keep in general)
            params = dict()
            if self.critic:
                critic_save_path = os.path.join(self.saver_filepath, "critic_{}.pkl".format(save_id))
                torch.save(self.critic.network.state_dict(), critic_save_path)
                print("critic params saved to: {}".format(critic_save_path))
            policy_save_path = os.path.join(self.saver_filepath, "policy_{}.pkl".format(save_id))
            torch.save(self.policy.state_dict(), policy_save_path)
            print("policy params saved to: {}".format(policy_save_path))
            # if the environment is wrapped in a normalizing env, save those stats
            normalized_env = extract_normalizing_env(self.env)
            if normalized_env is not None:
                # note: key intentionally spelled 'normalzing' to match load()
                params['normalzing'] = dict(
                    obs_mean=normalized_env._obs_mean,
                    obs_var=normalized_env._obs_var
                )
            # save params
            save_dir = os.path.split(self.saver_filepath)[0]
            save_params(save_dir, params, save_id, max_to_keep=50)

    def load(self):
        """
        Load parameters from a filepath. Symmetric to _save. This is not ideal,
        but it's easier than keeping track of everything separately.
        """
        params = load_params(self.env_filepath)
        print("loading env normalization params from: {}".format(self.env_filepath))
        if self.policy is not None:
            self.load_policy(self.policy_filepath)
        if self.critic is not None:
            self.load_critic(self.critic_filepath)
        # self.policy.set_param_values(params['policy'])
        normalized_env = extract_normalizing_env(self.env)
        if normalized_env is not None:
            normalized_env._obs_mean = params['normalzing']['obs_mean']
            normalized_env._obs_var = params['normalzing']['obs_var']

    def load_critic(self, critic_param_path):
        print("loading critic params from: {}".format(critic_param_path))
        self.critic.network.load_state_dict(torch.load(critic_param_path))

    def load_policy(self, policy_param_path):
        print("loading policy params from: {}".format(policy_param_path))
        self.policy.load_state_dict(torch.load(policy_param_path))
        self.policy = self.policy.float()

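    # Loading sketch (illustrative only): restoring a saved policy outside of
    # training needs just the torch state dict; build_policy and the checkpoint
    # path below are hypothetical, not part of this module.
    #
    #   policy = build_policy(args, env)                      # hypothetical builder
    #   policy.load_state_dict(torch.load("policy_700.pkl"))  # example path
    #   policy = policy.float().eval()
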
    def _validate(self, itr, samples_data):
        """
        Run validation functions.
        """
        if self.validator:
            objs = dict(
                policy=self.policy,
                critic=self.critic,
                samples_data=samples_data,
                env=self.env)
            self.validator.validate(itr, objs)

    def log_diagnostics(self, paths):
        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def get_itr_snapshot(self, itr, samples_data):
        """
        Snapshot the critic and recognition model as well.
        """
        self._save(itr)
        self._validate(itr, samples_data)
        snapshot = dict(
            itr=itr,
            policy=self.policy,
            baseline=self.baseline,
        )
        if self.snapshot_env:
            snapshot['env'] = self.env
        if samples_data is not None:
            snapshot['samples_data'] = dict()
            if 'actions' in samples_data.keys():
                snapshot['samples_data']['actions'] = samples_data['actions'][:10]
            if 'mean' in samples_data.keys():
                snapshot['samples_data']['mean'] = samples_data['mean'][:10]
        return snapshot

    def optimize_policy(self, itr, samples_data):
        """
        Update the critic and recognition model in addition to the policy.

        Args:
            itr: iteration counter
            samples_data: dictionary resulting from process_samples with keys
                'rewards', 'observations', 'agent_infos', 'env_infos',
                'returns', 'actions', 'advantages', 'paths'.
                The values in the infos dicts can be accessed, for example, as
                samples_data['agent_infos']['prob'], and the returned value will
                be an array of shape (batch_size, prob_dim).
        """
        obs = samples_data['observations']
        actions = samples_data['actions']
        advantages = samples_data['advantages']
        print("obs shape: {}, action shape: {}, return shape: {}, advantages shape: {}".format(
            samples_data['observations'].shape,
            samples_data['actions'].shape,
            samples_data['returns'].shape,
            samples_data['advantages'].shape
        ))
        # take a single trust-region policy update step
        trpo_step(
            self.policy,
            obs,
            actions,
            advantages,
            self.max_kl,
            self.damping,
        )
        # train critic
        if self.critic is not None:
            self.critic.train(itr, samples_data)
        # train recognition model
        if self.recognition is not None:
            self.recognition.train(itr, samples_data)
        return dict()

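    # TRPO-step note (an assumption about optimizers.trpo.trpo_step, stated only
    # for orientation): a standard TRPO update maximizes the surrogate objective
    #
    #     E[ pi_theta(a|s) / pi_old(a|s) * A(s, a) ]
    #
    # subject to E[ KL(pi_old || pi_theta) ] <= max_kl, typically via conjugate
    # gradient on Fisher-vector products (regularized by `damping`) followed by
    # a backtracking line search.
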
    def init_env(self, itr):
        # refill the pool of candidate trajectory files when it is empty
        if len(self.file_set) == 0:
            data_base_dir = "./preprocessing/data"
            file_list = os.listdir(data_base_dir)
            dir_name = random.choice(file_list)
            while not os.path.isdir(os.path.join(data_base_dir, dir_name, "processed")):
                dir_name = random.choice(file_list)
            print("Sampling from directory: {}".format(dir_name))
            paths = []
            for file_name in os.listdir(os.path.join(data_base_dir, dir_name, "processed")):
                if "section" in file_name:
                    orig_traj_file = os.path.join(dir_name, "processed", file_name)
                    paths.append(orig_traj_file)
            # generate the roadway geometry for the sampled directory
            lane_file = os.path.join(dir_name, "processed", '{}_lane'.format(dir_name[:19]))
            create_lane(lane_file)
            base_dir = os.path.expanduser('~/Autoenv/data/')
            self.j.write_roadways_to_dxf(base_dir)
            self.j.write_roadways_from_dxf(base_dir)
            print("Finished generating roadway")
            self.file_set.update(paths)
        if len(self.file_set) == 0:
            return False
        # sample a trajectory file and run the preprocessing / feature extraction pipeline
        trajectory_file = random.choice(list(self.file_set))
        processed_data_path = 'holo_{}_perfect_cleaned.csv'.format(trajectory_file[5:19])
        self.file_set.remove(trajectory_file)
        df_len = clean_data(trajectory_file)
        if df_len == 0:
            print("Invalid file, skipping")
            return False
        csv2txt(processed_data_path)
        convert_raw_ngsim_to_trajdatas()
        extract_ngsim_features(output_filename="ngsim_holo_new.h5", n_expert_files=1)
        print("Finished conversion and feature extraction")
        # rebuild the environment and expert data from the freshly extracted features
        args = self.args
        env, trajinfos, act_low, act_high = utils.build_ngsim_env(args)
        data, veh_2_index = load_data(
            args.expert_filepath,
            act_low=act_low,
            act_high=act_high,
            min_length=args.env_H + args.env_primesteps,
            clip_std_multiple=args.normalize_clip_std_multiple,
            ngsim_filename=args.ngsim_filename
        )
        if data is None:
            return False
        # cache the current critic / policy weights, rebuild the critic for the
        # new expert data, then restore the cached weights
        critic_param_cache = './data/experiments/NGSIM-gail/imitate/model/critic_cache.pkl'
        policy_param_cache = './data/experiments/NGSIM-gail/imitate/model/policy_cache.pkl'
        if self.critic:
            torch.save(self.critic.network.state_dict(), critic_param_cache)
        torch.save(self.policy.state_dict(), policy_param_cache)
        critic = utils.build_critic(args, data, env)
        self.env = env
        self.critic = critic
        if os.path.isfile(critic_param_cache):
            self.load_critic(critic_param_cache)
        else:
            self.load_critic(self.critic_filepath)
        # self.shutdown_worker()
        self.sampler = self.sampler_cls(self, **self.sampler_args)
        self.start_worker()
        if self.cuda_enable and self.critic:
            self.critic.network = self.critic.network.cuda()
        return True

    def train(self):
        self.start_worker()
        start_time = time.time()
        # resume from a previously trained run
        self.env_filepath = "./data/experiments/NGSIM-gail/imitate/itr_600.npz"
        self.critic_filepath = "./data/experiments/NGSIM-gail/imitate/model/critic_{}.pkl".format(self.prev_itr)
        self.policy_filepath = "./data/experiments/NGSIM-gail/imitate/model/policy_{}.pkl".format(self.prev_itr)
        print("loading critic and policy params from file")
        self.load()
        for itr in range(self.start_itr, self.n_itr):
            try:
                itr_start_time = time.time()
                print("Initializing AutoEnv...")
                while not self.init_env(itr):
                    print("Invalid data, initializing again!")
                print("Obtaining samples...")
                paths = self.obtain_samples(itr)
                print("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                print("Logging diagnostics...")
                # self.log_diagnostics(paths)
                print("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                print("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                # logger.save_itr_params(itr, params)
                print("Saved")
                print('Time', time.time() - start_time)
                print('ItrTime', time.time() - itr_start_time)
            except BaseException as e:
                print("***************************************\n" * 10)
                print("An error occurred: {}".format(e))
                print("skipping to next iteration")
                continue
        self.shutdown_worker()
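
# Per-iteration data flow (summary of init_env above, for orientation only):
#   1. sample a "processed" directory under ./preprocessing/data and build the
#      roadway via create_lane and the Julia NGSIM dxf routines;
#   2. pick one trajectory file and run clean_data -> csv2txt ->
#      convert_raw_ngsim_to_trajdatas -> extract_ngsim_features to produce
#      ngsim_holo_new.h5;
#   3. rebuild the environment and expert data with utils.build_ngsim_env and
#      load_data, rebuild the critic with utils.build_critic, and restore the
#      cached critic weights before sampling resumes.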