import numpy as np
import tensorflow as tf

from baselines.ppo2.model import Model


class MicrobatchedModel(Model):
    """
    Model that does training one microbatch at a time, for when gradient
    computation on the entire minibatch would overflow available memory.
    """
    def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
                 nsteps, ent_coef, vf_coef, max_grad_norm, mpi_rank_weight,
                 comm, microbatch_size):
        self.nmicrobatches = nbatch_train // microbatch_size
        self.microbatch_size = microbatch_size
        assert nbatch_train % microbatch_size == 0, \
            'microbatch_size ({}) should divide nbatch_train ({}) evenly'.format(
                microbatch_size, nbatch_train)

        super().__init__(
            policy=policy,
            ob_space=ob_space,
            ac_space=ac_space,
            nbatch_act=nbatch_act,
            nbatch_train=microbatch_size,
            nsteps=nsteps,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            mpi_rank_weight=mpi_rank_weight,
            comm=comm)

        # Placeholders through which externally accumulated gradients are fed
        # back into the graph and applied in a single optimizer step.
        self.grads_ph = [tf.compat.v1.placeholder(dtype=g.dtype, shape=g.shape)
                         for g in self.grads]
        grads_ph_and_vars = list(zip(self.grads_ph, self.var))
        self._apply_gradients_op = self.trainer.apply_gradients(grads_ph_and_vars)

    def train(self, lr, cliprange, obs, returns, masks, actions, values,
              neglogpacs, states=None):
        assert states is None, "microbatches with recurrent models are not supported yet"

        # Here we calculate the advantage A(s,a) = R + gamma * V(s') - V(s),
        # where Returns = R + gamma * V(s')
        advs = returns - values

        # Normalize the advantages
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)

        # Empty list for per-microbatch stats like pg_loss, vf_loss, entropy,
        # approxkl (whatever is in self.stats_list)
        stats_vs = []

        for microbatch_idx in range(self.nmicrobatches):
            _sli = range(microbatch_idx * self.microbatch_size,
                         (microbatch_idx + 1) * self.microbatch_size)
            td_map = {
                self.train_model.X: obs[_sli],
                self.A: actions[_sli],
                self.ADV: advs[_sli],
                self.R: returns[_sli],
                self.CLIPRANGE: cliprange,
                self.OLDNEGLOGPAC: neglogpacs[_sli],
                self.OLDVPRED: values[_sli]
            }

            # Compute gradients on a microbatch (note that variables do not change here) ...
            grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)
            if microbatch_idx == 0:
                sum_grad_v = grad_v
            else:
                # ... and add them to the running total of the gradients
                for i, g in enumerate(grad_v):
                    sum_grad_v[i] += g
            stats_vs.append(stats_v)

        feed_dict = {ph: sum_g / self.nmicrobatches
                     for ph, sum_g in zip(self.grads_ph, sum_grad_v)}
        feed_dict[self.LR] = lr
        # Update variables using the average of the gradients
        self.sess.run(self._apply_gradients_op, feed_dict)

        # Return the average of the stats
        return np.mean(np.array(stats_vs), axis=0).tolist()
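

# ---------------------------------------------------------------------------
# A minimal standalone sketch of the same accumulate-then-apply pattern,
# demonstrated on a toy least-squares model so it runs without baselines.
# All names below (x_ph, toy_x, etc.) are illustrative and not part of the
# class above; this assumes a TF2 install driving the compat.v1 graph API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    tf.compat.v1.disable_eager_execution()

    # Toy model: fit y = w * x with a squared loss.
    x_ph = tf.compat.v1.placeholder(tf.float32, [None])
    y_ph = tf.compat.v1.placeholder(tf.float32, [None])
    w = tf.compat.v1.get_variable('w', [], initializer=tf.zeros_initializer())
    loss = tf.reduce_mean(tf.square(w * x_ph - y_ph))

    trainer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.1)
    grads_and_vars = trainer.compute_gradients(loss)
    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]

    # Same trick as MicrobatchedModel: feed averaged gradients back in
    # through placeholders and apply them in a single optimizer step.
    grads_ph = [tf.compat.v1.placeholder(g.dtype, g.shape) for g in grads]
    apply_op = trainer.apply_gradients(list(zip(grads_ph, variables)))

    microbatch_size, nmicrobatches = 4, 2
    toy_x = np.arange(8, dtype=np.float32)
    toy_y = 3.0 * toy_x

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sum_grads = None
        for i in range(nmicrobatches):
            sli = slice(i * microbatch_size, (i + 1) * microbatch_size)
            g_v = sess.run(grads, {x_ph: toy_x[sli], y_ph: toy_y[sli]})
            # Accumulate gradients across microbatches; variables stay fixed.
            if sum_grads is None:
                sum_grads = g_v
            else:
                sum_grads = [s + g for s, g in zip(sum_grads, g_v)]
        # Apply the average of the accumulated gradients in one step.
        feed = {ph: s / nmicrobatches for ph, s in zip(grads_ph, sum_grads)}
        sess.run(apply_op, feed)
        print('w after one accumulated step:', sess.run(w))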