semantic_rl / train_procgen /single_graph_util.py
leonepson's picture
Upload 254 files
5960497 verified
import csv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from math import ceil
from constants import ENV_NAMES
import seaborn # sets some style parameters automatically
# Seed NumPy's global generator with a fixed value so the palette drawn below
# comes out the same on every import. Note that this reseeds the process-wide
# np.random state as a side effect of importing this module.
np.random.seed(1024)
# Twenty RGB triples; each channel is drawn from [0, 255) in red, green, blue
# order, one scalar draw at a time.
COLORS = [
    tuple(np.random.randint(0, 255) for _channel in range(3))
    for _slot in range(20)
]
def switch_to_outer_plot(fig):
    """Add a frameless, tick-free subplot covering the whole figure and return it.

    The new axes is created with ``fig.add_subplot(111, frame_on=False)`` and
    has both its x and y tick lists set to empty.
    """
    outer_axes = fig.add_subplot(111, frame_on=False)
    # Clear x ticks first, then y ticks, each with its own fresh empty list.
    for clear_ticks in (outer_axes.set_xticks, outer_axes.set_yticks):
        clear_ticks([])
    return outer_axes
def ema(data_in, smoothing=0):
    """Return the exponential moving average of a 1-D sequence.

    Each output element is ``(1 - smoothing) * x + smoothing * previous``,
    where ``previous`` is the preceding output. The first sample is passed
    through unchanged, and whenever the running value is NaN the next sample
    replaces it outright instead of being blended with it.

    Args:
        data_in: 1-D array-like of numbers.
        smoothing: Weight given to the previous output; 0 (the default)
            returns the input values unchanged.

    Returns:
        A numpy array with the same shape as ``data_in``. Floating-point
        input keeps its dtype; any other dtype (e.g. integers) is returned as
        float64 so fractional averages are not truncated.
    """
    samples = np.asarray(data_in)
    # An integer output buffer would silently floor every blended value, so
    # only inexact (float/complex) dtypes are reused for the result.
    if np.issubdtype(samples.dtype, np.inexact):
        out_dtype = samples.dtype
    else:
        out_dtype = np.float64
    data_out = np.zeros(samples.shape, dtype=out_dtype)
    curr = np.nan
    for i, x in enumerate(samples):
        if np.isnan(curr):
            # No valid running value yet (start of series or after a NaN).
            curr = x
        else:
            curr = (1 - smoothing) * x + smoothing * curr
        data_out[i] = curr
    return data_out
def plot_data_mean_std(ax, data_y, color_idx=0, data_x=None, x_scale=1, smoothing=0, first_valid=0, label=None):
    """Draw the mean of ``data_y`` over axis 0 with a +/- one standard deviation band.

    Args:
        ax: Object providing ``plot`` and ``fill_between`` (e.g. a matplotlib Axes).
        data_y: 2-D array-like; mean and std are taken over axis 0
            (presumably rows are independent runs -- confirm with callers).
            It is never modified.
        color_idx: Index into the module-level ``COLORS`` palette.
        data_x: Optional x coordinates, used exactly as given (they are not
            trimmed by ``first_valid``). When omitted, positions
            ``(column_index + first_valid) * x_scale`` are generated.
        x_scale: Multiplier for the generated x positions.
        smoothing: EMA weight passed to ``ema`` for every row; values <= 0
            disable smoothing.
        first_valid: Number of leading columns of ``data_y`` to drop.
        label: Legend label for the mean line.
    """
    color = COLORS[color_idx]
    hexcolor = '#%02x%02x%02x' % color
    # astype() always copies, so the smoothing below cannot write back into
    # the caller's array through a slice view, and the float dtype keeps
    # smoothed integer data from being truncated on assignment.
    data_y = np.asarray(data_y)[:, first_valid:].astype(np.float64)
    num_rows, num_datapoint = data_y.shape
    if smoothing > 0:
        for row in range(num_rows):
            data_y[row, ...] = ema(data_y[row, ...], smoothing)
    if data_x is None:
        data_x = (np.arange(num_datapoint) + first_valid) * x_scale
    data_mean = np.mean(data_y, axis=0)
    if num_rows > 1:
        data_std = np.std(data_y, axis=0, ddof=1)
    else:
        # The sample standard deviation (ddof=1) is undefined for a single
        # row and would produce NaNs plus a RuntimeWarning; draw no spread.
        data_std = np.zeros_like(data_mean)
    ax.plot(data_x, data_mean, color=hexcolor, label=label, linestyle='solid', alpha=1, rasterized=True)
    ax.fill_between(data_x, data_mean - data_std, data_mean + data_std, color=hexcolor, alpha=.25, linewidth=0.0,
                    rasterized=True)
def read_csv(filename, key_name):
    """Load one column of a comma-separated file as a float32 array.

    The first row is treated as the header. Column names are matched
    case-insensitively against ``key_name`` and the first match is used.
    Completely blank lines among the data rows are skipped.

    Args:
        filename: Path of the CSV file to read.
        key_name: Name of the column to extract.

    Returns:
        A 1-D ``np.float32`` array with one entry per non-blank data row. It
        is empty when the file has no rows at all or only a header.

    Raises:
        IndexError: If the header has no column named ``key_name``, or a data
            row has fewer columns than the header position of that column.
        ValueError: If a cell in the column cannot be converted to a float.
    """
    wanted = key_name.lower()
    # newline='' leaves line-ending handling to the csv module, as its
    # documentation recommends for file objects passed to csv.reader.
    with open(filename, newline='') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        header = next(csv_reader, None)
        if header is None:
            # Completely empty file: nothing to read.
            return np.array([], dtype=np.float32)
        matches = [i for i, name in enumerate(header) if name.lower() == wanted]
        if not matches:
            # IndexError is what the bare ``idxs[0]`` lookup used to raise;
            # keep the type for existing callers but make the message useful.
            raise IndexError(f'column {key_name!r} not found in header of {filename!r}')
        key_index = matches[0]
        # csv.reader yields [] for a blank line; skip those rows.
        values = [row[key_index] for row in csv_reader if row]
    return np.array(values, dtype=np.float32)
def plot_values(ax, all_values, title=None, max_x=0, label=None, **kwargs):
    """Optionally truncate ``all_values`` along its last axis, plot it, and return it.

    Args:
        ax: Axes to draw on, or ``None`` to skip drawing entirely.
        all_values: Array indexed as ``all_values[..., :max_x]`` when truncating.
        title: Title applied to ``ax`` (font size 20) when drawing.
        max_x: When positive, keep only the first ``max_x`` entries of the
            last axis; otherwise the data is used as-is.
        label: Legend label forwarded to ``plot_data_mean_std``.
        **kwargs: Extra keyword arguments forwarded to ``plot_data_mean_std``.

    Returns:
        The (possibly truncated) values, whether or not anything was drawn.
    """
    shown = all_values[..., :max_x] if max_x > 0 else all_values
    if ax is None:
        # Caller only wants the truncated data.
        return shown
    plot_data_mean_std(ax, shown, label=label, **kwargs)
    ax.set_title(title, fontsize=20)
    return shown
def plot_experiment(env_name, run_directory_prefix, titles=None, suffixes=('',), normalization_ranges=None,
                    key_name='eprewmean', **kwargs):
    """Plot SPPO and PPO curves for one environment on a single axes.

    For each algorithm, three runs are read from
    ``checkpoints/{sppo|ppo}-{run_directory_prefix}_0_{2020..2022}/progress[-{suffix}].csv``
    and drawn as one mean/std curve per (algorithm, suffix) pair, each with
    its own entry of the color palette.

    Args:
        env_name: Used as the plot title, in the progress message, and as the
            key into ``normalization_ranges``.
        run_directory_prefix: Common part of the run directory names.
        titles: Unused; kept so existing callers keep working.
        suffixes: Sequence of progress-file suffixes. ``''`` selects
            ``progress.csv``; any other value selects ``progress-{suffix}.csv``.
            With a single suffix the curves are labeled ``SPPO``/``PPO``; with
            several, ``''`` is labeled ``train`` and every other suffix ``test``.
        normalization_ranges: Optional mapping ``env_name -> (min, max)``.
            When given, values are rescaled to ``(v - min) / (max - min)``
            before plotting and no title is set.
        key_name: CSV column to plot.
        **kwargs: Forwarded to ``plot_values`` (and from there to
            ``plot_data_mean_std``).

    Returns:
        The ``(figure, axes)`` pair from ``plt.subplots()``.
    """
    seed_folders = [f'{run_directory_prefix}_{0}_{2020 + x}' for x in range(3)]
    run_groups = [
        ('SPPO', ['sppo-' + folder for folder in seed_folders]),
        ('PPO', ['ppo-' + folder for folder in seed_folders]),
    ]
    will_normalize_and_reduce = normalization_ranges is not None

    # Exactly one environment is plotted, so a single axes is always enough;
    # the former grid-of-subplots branches could never be reached.
    f, axarr = plt.subplots()
    ax = axarr

    color_idx = 0
    for run_name, folders in run_groups:
        for suffix in suffixes:
            if len(suffixes) == 1:
                label = run_name
            elif suffix == '':
                label = run_name + ' train'
            else:
                label = run_name + ' test'
            print(f'loading results from {env_name}...')

            file_tag = f'-{suffix}' if len(suffix) > 0 else ''
            csv_files = [f'checkpoints/{resid}/progress{file_tag}.csv' for resid in folders]
            raw_data = np.array([read_csv(file, key_name) for file in csv_files])

            if will_normalize_and_reduce:
                # ax=None: only apply plot_values' truncation, draw nothing yet.
                values = plot_values(None, raw_data, **kwargs)
                game_range = normalization_ranges[env_name]
                game_min = game_range[0]
                game_max = game_range[1]
                normalized_data = (np.array(values) - game_min) / (game_max - game_min)
                # Use the per-algorithm label (not the bare suffix) so the
                # SPPO and PPO curves can be told apart in the legend.
                plot_values(ax, normalized_data, title=None, color_idx=color_idx, label=label, **kwargs)
            else:
                plot_values(ax, raw_data, title=env_name, color_idx=color_idx, label=label, **kwargs)

            # Next (algorithm, suffix) pair gets the next palette entry.
            color_idx += 1

    ax.legend(loc='lower right')
    return f, axarr