|
|
import tensorflow as tf |
|
|
import numpy as np |
|
|
import re |
|
|
|
|
|
|
|
|
from baselines.acktr.kfac_utils import gmatmul, detectMinVal, factorReshape
|
|
from functools import reduce |
|
|
|
|
|
KFAC_OPS = ['MatMul', 'Conv2D', 'BiasAdd']  # op types for which Kronecker-factored Fisher blocks are built
|
|
KFAC_DEBUG = False |
|
|
|
|
|
|
|
|
class KfacOptimizer(): |
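    """Kronecker-Factored Approximate Curvature (K-FAC) optimizer as used by ACKTR.

    The Fisher information matrix is approximated block-diagonally per layer,
    with each block factored as a Kronecker product of an activation covariance
    and a backpropagated-gradient covariance. Gradients are preconditioned in
    the eigenbasis of these factors and rescaled to respect a KL trust region.

    Typical usage (a sketch; `loss`, `loss_sampled` and `params` are
    placeholders for the caller's graph tensors and variable list):

        optim = KfacOptimizer(learning_rate=0.03, clip_kl=0.001, momentum=0.9)
        update_op, q_runner = optim.minimize(loss, loss_sampled, var_list=params)
    """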
|
|
|
|
|
    def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01, kfac_update=2,
                 stats_accum_iter=60, full_stats_init=False, cold_iter=100, cold_lr=None,
                 is_async=False, async_stats=False, epsilon=1e-2, stats_decay=0.95,
                 blockdiag_bias=False, channel_fac=False, factored_damping=False,
                 approxT2=False, use_float64=False, weight_decay_dict=None, max_grad_norm=0.5):
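        """Build a K-FAC optimizer.

        Key arguments:
            learning_rate: base step size for the preconditioned update.
            momentum: coefficient for the underlying MomentumOptimizer.
            clip_kl: trust-region bound; updates are rescaled so the predicted
                quadratic change (vFv) stays below this value.
            kfac_update: recompute the factor eigendecompositions every
                `kfac_update` stats steps.
            stats_accum_iter: number of steps used to accumulate initial statistics.
            full_stats_init: if False, the accumulation period is tied to `cold_iter`.
            cold_iter / cold_lr: number of plain momentum-SGD warm-up steps and the
                learning rate used for them (defaults to `learning_rate`).
            is_async / async_stats: run eigendecomposition / stats updates
                asynchronously via queue runners.
            epsilon: damping added to the factored curvature.
            stats_decay: exponential moving-average decay for the statistics.
            weight_decay_dict: optional per-variable weight-decay coefficients,
                added to the damping.
            max_grad_norm: global-norm clipping applied during the cold SGD phase.
        """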
|
|
self.max_grad_norm = max_grad_norm |
|
|
self._lr = learning_rate |
|
|
self._momentum = momentum |
|
|
self._clip_kl = clip_kl |
|
|
self._channel_fac = channel_fac |
|
|
self._kfac_update = kfac_update |
|
|
self._async = is_async |
|
|
self._async_stats = async_stats |
|
|
self._epsilon = epsilon |
|
|
self._stats_decay = stats_decay |
|
|
self._blockdiag_bias = blockdiag_bias |
|
|
self._approxT2 = approxT2 |
|
|
self._use_float64 = use_float64 |
|
|
self._factored_damping = factored_damping |
|
|
self._cold_iter = cold_iter |
|
|
        if cold_lr is None:
|
|
|
|
|
self._cold_lr = self._lr |
|
|
else: |
|
|
self._cold_lr = cold_lr |
|
|
self._stats_accum_iter = stats_accum_iter |
|
|
        self._weight_decay_dict = weight_decay_dict if weight_decay_dict is not None else {}
|
|
self._diag_init_coeff = 0. |
|
|
self._full_stats_init = full_stats_init |
|
|
if not self._full_stats_init: |
|
|
self._stats_accum_iter = self._cold_iter |
|
|
|
|
|
self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False) |
|
|
self.global_step = tf.Variable( |
|
|
0, name='KFAC/global_step', trainable=False) |
|
|
self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False) |
|
|
self.factor_step = tf.Variable( |
|
|
0, name='KFAC/factor_step', trainable=False) |
|
|
self.stats_step = tf.Variable( |
|
|
0, name='KFAC/stats_step', trainable=False) |
|
|
self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False) |
|
|
|
|
|
self.factors = {} |
|
|
self.param_vars = [] |
|
|
self.stats = {} |
|
|
self.stats_eigen = {} |
|
|
|
|
|
def getFactors(self, g, varlist): |
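        """Trace the sampled-gradient graph to find, for every variable in
        `varlist`, the forward (activation) and backward (pre-activation
        gradient) tensors that factor its Fisher block.

        Returns a dict mapping each variable to
        {'opName', 'op', 'fpropFactors', 'bpropFactors', ...}; factors shared
        between variables (e.g. a weight and its bias) are de-duplicated and
        concatenated versions are stored under the '*_concat' keys.
        """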
|
|
graph = tf.compat.v1.get_default_graph() |
|
|
factorTensors = {} |
|
|
fpropTensors = [] |
|
|
bpropTensors = [] |
|
|
opTypes = [] |
|
|
fops = [] |
|
|
|
|
|
def searchFactors(gradient, graph): |
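            """Walk from a sampled gradient back to the op (MatMul / Conv2D /
            BiasAdd) that produced it, collecting the op's input activations
            (fprop factors) and the incoming gradient tensors (bprop factors).
            'AddN' gradients (parameter sharing) are recursed into; all branches
            must originate from the same op type.
            """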
|
|
|
|
|
bpropOp = gradient.op |
|
|
bpropOp_name = bpropOp.name |
|
|
|
|
|
bTensors = [] |
|
|
fTensors = [] |
|
|
|
|
|
|
|
|
|
|
|
if 'AddN' in bpropOp_name: |
|
|
factors = [] |
|
|
for g in gradient.op.inputs: |
|
|
factors.append(searchFactors(g, graph)) |
|
|
op_names = [item['opName'] for item in factors] |
|
|
|
|
|
                print(gradient.name)
|
|
                print(op_names)
|
|
                print(len(np.unique(op_names)))
|
|
assert len(np.unique(op_names)) == 1, gradient.name + \ |
|
|
' is shared among different computation OPs' |
|
|
|
|
|
bTensors = reduce(lambda x, y: x + y, |
|
|
[item['bpropFactors'] for item in factors]) |
|
|
if len(factors[0]['fpropFactors']) > 0: |
|
|
fTensors = reduce( |
|
|
lambda x, y: x + y, [item['fpropFactors'] for item in factors]) |
|
|
fpropOp_name = op_names[0] |
|
|
fpropOp = factors[0]['op'] |
|
|
else: |
|
|
fpropOp_name = re.search( |
|
|
'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2) |
|
|
fpropOp = graph.get_operation_by_name(fpropOp_name) |
|
|
if fpropOp.op_def.name in KFAC_OPS: |
|
|
|
|
|
|
|
|
bTensor = [ |
|
|
i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1] |
|
|
bTensorShape = fpropOp.outputs[0].get_shape() |
|
|
                if bTensor.get_shape()[0].value is None:
|
|
bTensor.set_shape(bTensorShape) |
|
|
bTensors.append(bTensor) |
|
|
|
|
|
if fpropOp.op_def.name == 'BiasAdd': |
|
|
fTensors = [] |
|
|
else: |
|
|
fTensors.append( |
|
|
[i for i in fpropOp.inputs if param.op.name not in i.name][0]) |
|
|
fpropOp_name = fpropOp.op_def.name |
|
|
else: |
|
|
|
|
|
bInputsList = [i for i in bpropOp.inputs[ |
|
|
0].op.inputs if 'gradientsSampled' in i.name if 'Shape' not in i.name] |
|
|
if len(bInputsList) > 0: |
|
|
bTensor = bInputsList[0] |
|
|
bTensorShape = fpropOp.outputs[0].get_shape() |
|
|
                    if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value is None:
|
|
bTensor.set_shape(bTensorShape) |
|
|
bTensors.append(bTensor) |
|
|
                # record the unknown op type; list.append returns None, so keep the name separately
                fpropOp_name = 'UNK-' + fpropOp.op_def.name
                opTypes.append(fpropOp_name)
|
|
|
|
|
return {'opName': fpropOp_name, 'op': fpropOp, 'fpropFactors': fTensors, 'bpropFactors': bTensors} |
|
|
|
|
|
for t, param in zip(g, varlist): |
|
|
if KFAC_DEBUG: |
|
|
print(('get factor for '+param.name)) |
|
|
factors = searchFactors(t, graph) |
|
|
factorTensors[param] = factors |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for param in varlist: |
|
|
factorTensors[param]['assnWeights'] = None |
|
|
factorTensors[param]['assnBias'] = None |
|
|
for param in varlist: |
|
|
if factorTensors[param]['opName'] == 'BiasAdd': |
|
|
factorTensors[param]['assnWeights'] = None |
|
|
for item in varlist: |
|
|
if len(factorTensors[item]['bpropFactors']) > 0: |
|
|
if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0): |
|
|
factorTensors[param]['assnWeights'] = item |
|
|
factorTensors[item]['assnBias'] = param |
|
|
factorTensors[param]['bpropFactors'] = factorTensors[ |
|
|
item]['bpropFactors'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for key in ['fpropFactors', 'bpropFactors']: |
|
|
for i, param in enumerate(varlist): |
|
|
if len(factorTensors[param][key]) > 0: |
|
|
if (key + '_concat') not in factorTensors[param]: |
|
|
name_scope = factorTensors[param][key][0].name.split(':')[ |
|
|
0] |
|
|
with tf.compat.v1.name_scope(name_scope): |
|
|
factorTensors[param][ |
|
|
key + '_concat'] = tf.concat(factorTensors[param][key], 0) |
|
|
else: |
|
|
factorTensors[param][key + '_concat'] = None |
|
|
for j, param2 in enumerate(varlist[(i + 1):]): |
|
|
if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])): |
|
|
factorTensors[param2][key] = factorTensors[param][key] |
|
|
factorTensors[param2][ |
|
|
key + '_concat'] = factorTensors[param][key + '_concat'] |
|
|
|
|
|
|
|
|
if KFAC_DEBUG: |
|
|
for items in zip(varlist, fpropTensors, bpropTensors, opTypes): |
|
|
                print((items[0].name, factorTensors[items[0]]))
|
|
self.factors = factorTensors |
|
|
return factorTensors |
|
|
|
|
|
def getStats(self, factors, varlist): |
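        """Create the running-average covariance variables ('KFAC_STATS/*') for
        every factor found by getFactors. Conv2D activation factors may be
        split into channel and spatial-support parts when `channel_fac` is set,
        and bias statistics are folded into the associated weight factor via
        homogeneous coordinates unless `blockdiag_bias` is set.
        """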
|
|
if len(self.stats) == 0: |
|
|
|
|
|
|
|
|
with tf.device('/cpu'): |
|
|
tmpStatsCache = {} |
|
|
|
|
|
|
|
|
|
|
|
for var in varlist: |
|
|
fpropFactor = factors[var]['fpropFactors_concat'] |
|
|
bpropFactor = factors[var]['bpropFactors_concat'] |
|
|
opType = factors[var]['opName'] |
|
|
if opType == 'Conv2D': |
|
|
Kh = var.get_shape()[0] |
|
|
Kw = var.get_shape()[1] |
|
|
C = fpropFactor.get_shape()[-1] |
|
|
|
|
|
Oh = bpropFactor.get_shape()[1] |
|
|
Ow = bpropFactor.get_shape()[2] |
|
|
if Oh == 1 and Ow == 1 and self._channel_fac: |
|
|
|
|
|
|
|
|
var_assnBias = factors[var]['assnBias'] |
|
|
if var_assnBias: |
|
|
factors[var]['assnBias'] = None |
|
|
factors[var_assnBias]['assnWeights'] = None |
|
|
|
|
|
|
|
|
for var in varlist: |
|
|
fpropFactor = factors[var]['fpropFactors_concat'] |
|
|
bpropFactor = factors[var]['bpropFactors_concat'] |
|
|
opType = factors[var]['opName'] |
|
|
self.stats[var] = {'opName': opType, |
|
|
'fprop_concat_stats': [], |
|
|
'bprop_concat_stats': [], |
|
|
'assnWeights': factors[var]['assnWeights'], |
|
|
'assnBias': factors[var]['assnBias'], |
|
|
} |
|
|
if fpropFactor is not None: |
|
|
if fpropFactor not in tmpStatsCache: |
|
|
if opType == 'Conv2D': |
|
|
Kh = var.get_shape()[0] |
|
|
Kw = var.get_shape()[1] |
|
|
C = fpropFactor.get_shape()[-1] |
|
|
|
|
|
Oh = bpropFactor.get_shape()[1] |
|
|
Ow = bpropFactor.get_shape()[2] |
|
|
if Oh == 1 and Ow == 1 and self._channel_fac: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fpropFactor2_size = Kh * Kw |
|
|
slot_fpropFactor_stats2 = tf.Variable(tf.linalg.tensor_diag(tf.ones( |
|
|
[fpropFactor2_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False) |
|
|
self.stats[var]['fprop_concat_stats'].append( |
|
|
slot_fpropFactor_stats2) |
|
|
|
|
|
fpropFactor_size = C |
|
|
else: |
|
|
|
|
|
|
|
|
fpropFactor_size = Kh * Kw * C |
|
|
else: |
|
|
|
|
|
fpropFactor_size = fpropFactor.get_shape()[-1] |
|
|
|
|
|
|
|
|
if not self._blockdiag_bias and self.stats[var]['assnBias']: |
|
|
fpropFactor_size += 1 |
|
|
|
|
|
slot_fpropFactor_stats = tf.Variable(tf.linalg.tensor_diag(tf.ones( |
|
|
[fpropFactor_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False) |
|
|
self.stats[var]['fprop_concat_stats'].append( |
|
|
slot_fpropFactor_stats) |
|
|
if opType != 'Conv2D': |
|
|
tmpStatsCache[fpropFactor] = self.stats[ |
|
|
var]['fprop_concat_stats'] |
|
|
else: |
|
|
self.stats[var][ |
|
|
'fprop_concat_stats'] = tmpStatsCache[fpropFactor] |
|
|
|
|
|
if bpropFactor is not None: |
|
|
|
|
|
|
|
|
                        # skip bias vars whose bprop stats are folded into the
                        # associated weight factor (homogeneous-coordinate trick)
                        if not((not self._blockdiag_bias) and self.stats[var]['assnWeights']):
|
|
if bpropFactor not in tmpStatsCache: |
|
|
slot_bpropFactor_stats = tf.Variable(tf.linalg.tensor_diag(tf.ones([bpropFactor.get_shape( |
|
|
)[-1]])) * self._diag_init_coeff, name='KFAC_STATS/' + bpropFactor.op.name, trainable=False) |
|
|
self.stats[var]['bprop_concat_stats'].append( |
|
|
slot_bpropFactor_stats) |
|
|
tmpStatsCache[bpropFactor] = self.stats[ |
|
|
var]['bprop_concat_stats'] |
|
|
else: |
|
|
self.stats[var][ |
|
|
'bprop_concat_stats'] = tmpStatsCache[bpropFactor] |
|
|
|
|
|
return self.stats |
|
|
|
|
|
def compute_and_apply_stats(self, loss_sampled, var_list=None): |
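        """Build the covariance-update ops for `loss_sampled` and wire them into
        the running averages (see compute_stats and apply_stats)."""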
|
|
varlist = var_list |
|
|
if varlist is None: |
|
|
varlist = tf.compat.v1.trainable_variables() |
|
|
|
|
|
stats = self.compute_stats(loss_sampled, var_list=varlist) |
|
|
return self.apply_stats(stats) |
|
|
|
|
|
def compute_stats(self, loss_sampled, var_list=None): |
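        """Compute the per-layer covariance updates from gradients of the
        sampled loss: second moments of the activations on the input side and
        of the backpropagated gradients on the output side. Returns a dict
        mapping each stats variable to its new covariance tensor.
        """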
|
|
varlist = var_list |
|
|
if varlist is None: |
|
|
varlist = tf.compat.v1.trainable_variables() |
|
|
|
|
|
gs = tf.gradients(ys=loss_sampled, xs=varlist, name='gradientsSampled') |
|
|
self.gs = gs |
|
|
factors = self.getFactors(gs, varlist) |
|
|
stats = self.getStats(factors, varlist) |
|
|
|
|
|
updateOps = [] |
|
|
statsUpdates = {} |
|
|
statsUpdates_cache = {} |
|
|
for var in varlist: |
|
|
opType = factors[var]['opName'] |
|
|
fops = factors[var]['op'] |
|
|
fpropFactor = factors[var]['fpropFactors_concat'] |
|
|
fpropStats_vars = stats[var]['fprop_concat_stats'] |
|
|
bpropFactor = factors[var]['bpropFactors_concat'] |
|
|
bpropStats_vars = stats[var]['bprop_concat_stats'] |
|
|
SVD_factors = {} |
|
|
for stats_var in fpropStats_vars: |
|
|
stats_var_dim = int(stats_var.get_shape()[0]) |
|
|
if stats_var not in statsUpdates_cache: |
|
|
old_fpropFactor = fpropFactor |
|
|
B = (tf.shape(input=fpropFactor)[0]) |
|
|
if opType == 'Conv2D': |
|
|
strides = fops.get_attr("strides") |
|
|
padding = fops.get_attr("padding") |
|
|
convkernel_size = var.get_shape()[0:3] |
|
|
|
|
|
KH = int(convkernel_size[0]) |
|
|
KW = int(convkernel_size[1]) |
|
|
C = int(convkernel_size[2]) |
|
|
flatten_size = int(KH * KW * C) |
|
|
|
|
|
Oh = int(bpropFactor.get_shape()[1]) |
|
|
Ow = int(bpropFactor.get_shape()[2]) |
|
|
|
|
|
if Oh == 1 and Ow == 1 and self._channel_fac: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(SVD_factors) == 0: |
|
|
if KFAC_DEBUG: |
|
|
print(('approx %s act factor with rank-1 SVD factors' % (var.name))) |
|
|
|
|
|
                                S, U, V = tf.linalg.svd(tf.reshape(
                                    fpropFactor, [-1, KH * KW, C]))
|
|
|
|
|
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0]), 1)  # top singular value per example (rank-1 approx)
|
|
patches_k = U[:, :, 0] * sqrtS1 |
|
|
full_factor_shape = fpropFactor.get_shape() |
|
|
patches_k.set_shape( |
|
|
[full_factor_shape[0], KH * KW]) |
|
|
patches_c = V[:, :, 0] * sqrtS1 |
|
|
patches_c.set_shape([full_factor_shape[0], C]) |
|
|
SVD_factors[C] = patches_c |
|
|
SVD_factors[KH * KW] = patches_k |
|
|
fpropFactor = SVD_factors[stats_var_dim] |
|
|
|
|
|
else: |
|
|
|
|
|
patches = tf.image.extract_patches(fpropFactor, sizes=[1, convkernel_size[ |
|
|
0], convkernel_size[1], 1], strides=strides, rates=[1, 1, 1, 1], padding=padding) |
|
|
|
|
|
if self._approxT2: |
|
|
if KFAC_DEBUG: |
|
|
print(('approxT2 act fisher for %s' % (var.name))) |
|
|
|
|
|
fpropFactor = tf.reduce_mean(input_tensor=patches, axis=[1, 2]) |
|
|
else: |
|
|
|
|
|
fpropFactor = tf.reshape( |
|
|
patches, [-1, flatten_size]) / Oh / Ow |
|
|
fpropFactor_size = int(fpropFactor.get_shape()[-1]) |
|
|
if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias: |
|
|
if opType == 'Conv2D' and not self._approxT2: |
|
|
|
|
|
|
|
|
fpropFactor = tf.concat([fpropFactor, tf.ones( |
|
|
[tf.shape(input=fpropFactor)[0], 1]) / Oh / Ow], 1) |
|
|
else: |
|
|
|
|
|
fpropFactor = tf.concat( |
|
|
[fpropFactor, tf.ones([tf.shape(input=fpropFactor)[0], 1])], 1) |
|
|
|
|
|
|
|
|
|
|
|
cov = tf.matmul(fpropFactor, fpropFactor, |
|
|
transpose_a=True) / tf.cast(B, tf.float32) |
|
|
updateOps.append(cov) |
|
|
statsUpdates[stats_var] = cov |
|
|
if opType != 'Conv2D': |
|
|
|
|
|
|
|
|
statsUpdates_cache[stats_var] = cov |
|
|
|
|
|
for stats_var in bpropStats_vars: |
|
|
stats_var_dim = int(stats_var.get_shape()[0]) |
|
|
if stats_var not in statsUpdates_cache: |
|
|
old_bpropFactor = bpropFactor |
|
|
bpropFactor_shape = bpropFactor.get_shape() |
|
|
B = tf.shape(input=bpropFactor)[0] |
|
|
C = int(bpropFactor_shape[-1]) |
|
|
if opType == 'Conv2D' or len(bpropFactor_shape) == 4: |
|
|
if fpropFactor is not None: |
|
|
if self._approxT2: |
|
|
if KFAC_DEBUG: |
|
|
print(('approxT2 grad fisher for %s' % (var.name))) |
|
|
bpropFactor = tf.reduce_sum( |
|
|
input_tensor=bpropFactor, axis=[1, 2]) |
|
|
else: |
|
|
bpropFactor = tf.reshape( |
|
|
bpropFactor, [-1, C]) * Oh * Ow |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
if KFAC_DEBUG: |
|
|
print(('block diag approx fisher for %s' % (var.name))) |
|
|
bpropFactor = tf.reduce_sum(input_tensor=bpropFactor, axis=[1, 2]) |
|
|
|
|
|
|
|
|
|
|
|
bpropFactor *= tf.cast(B, dtype=tf.float32) |
|
|
|
|
|
|
|
|
cov_b = tf.matmul( |
|
|
bpropFactor, bpropFactor, transpose_a=True) / tf.cast(tf.shape(input=bpropFactor)[0], dtype=tf.float32) |
|
|
|
|
|
updateOps.append(cov_b) |
|
|
statsUpdates[stats_var] = cov_b |
|
|
statsUpdates_cache[stats_var] = cov_b |
|
|
|
|
|
if KFAC_DEBUG: |
|
|
aKey = list(statsUpdates.keys())[0] |
|
|
statsUpdates[aKey] = tf.compat.v1.Print(statsUpdates[aKey], |
|
|
[tf.convert_to_tensor(value='step:'), |
|
|
self.global_step, |
|
|
tf.convert_to_tensor( |
|
|
value='computing stats'), |
|
|
]) |
|
|
self.statsUpdates = statsUpdates |
|
|
return statsUpdates |
|
|
|
|
|
def apply_stats(self, statsUpdates): |
|
|
""" compute stats and update/apply the new stats to the running average |
|
|
""" |
|
|
|
|
|
def updateAccumStats(): |
|
|
if self._full_stats_init: |
|
|
return tf.cond(pred=tf.greater(self.sgd_step, self._cold_iter), true_fn=lambda: tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter)), false_fn=tf.no_op) |
|
|
else: |
|
|
return tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter)) |
|
|
|
|
|
def updateRunningAvgStats(statsUpdates, fac_iter=1): |
|
|
|
|
|
|
|
|
|
|
|
return tf.group(*self._apply_stats(statsUpdates)) |
|
|
|
|
|
if self._async_stats: |
|
|
|
|
|
update_stats = self._apply_stats(statsUpdates) |
|
|
|
|
|
queue = tf.queue.FIFOQueue(1, [item.dtype for item in update_stats], shapes=[ |
|
|
item.get_shape() for item in update_stats]) |
|
|
enqueue_op = queue.enqueue(update_stats) |
|
|
|
|
|
def dequeue_stats_op(): |
|
|
return queue.dequeue() |
|
|
self.qr_stats = tf.compat.v1.train.QueueRunner(queue, [enqueue_op]) |
|
|
update_stats_op = tf.cond(pred=tf.equal(queue.size(), tf.convert_to_tensor( |
|
|
value=0)), true_fn=tf.no_op, false_fn=lambda: tf.group(*[dequeue_stats_op(), ])) |
|
|
else: |
|
|
|
|
|
update_stats_op = tf.cond(pred=tf.greater_equal( |
|
|
self.stats_step, self._stats_accum_iter), true_fn=lambda: updateRunningAvgStats(statsUpdates), false_fn=updateAccumStats) |
|
|
self._update_stats_op = update_stats_op |
|
|
return update_stats_op |
|
|
|
|
|
def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.): |
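        """Build assign ops that either accumulate `statsUpdates` into the stats
        variables (scaled by `accumulateCoeff`) or fold them into the
        exponential moving average with decay `self._stats_decay`. Returns a
        single-element list with the op that also increments `stats_step`."""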
|
|
updateOps = [] |
|
|
|
|
|
for stats_var in statsUpdates: |
|
|
stats_new = statsUpdates[stats_var] |
|
|
if accumulate: |
|
|
|
|
|
update_op = tf.compat.v1.assign_add( |
|
|
stats_var, accumulateCoeff * stats_new, use_locking=True) |
|
|
else: |
|
|
|
|
|
update_op = tf.compat.v1.assign( |
|
|
stats_var, stats_var * self._stats_decay, use_locking=True) |
|
|
update_op = tf.compat.v1.assign_add( |
|
|
update_op, (1. - self._stats_decay) * stats_new, use_locking=True) |
|
|
updateOps.append(update_op) |
|
|
|
|
|
with tf.control_dependencies(updateOps): |
|
|
stats_step_op = tf.compat.v1.assign_add(self.stats_step, 1) |
|
|
|
|
|
if KFAC_DEBUG: |
|
|
stats_step_op = (tf.compat.v1.Print(stats_step_op, |
|
|
[tf.convert_to_tensor(value='step:'), |
|
|
self.global_step, |
|
|
tf.convert_to_tensor(value='fac step:'), |
|
|
self.factor_step, |
|
|
tf.convert_to_tensor(value='sgd step:'), |
|
|
self.sgd_step, |
|
|
tf.convert_to_tensor(value='Accum:'), |
|
|
tf.convert_to_tensor(value=accumulate), |
|
|
tf.convert_to_tensor(value='Accum coeff:'), |
|
|
tf.convert_to_tensor(value=accumulateCoeff), |
|
|
tf.convert_to_tensor(value='stat step:'), |
|
|
self.stats_step, updateOps[0], updateOps[1]])) |
|
|
return [stats_step_op, ] |
|
|
|
|
|
def getStatsEigen(self, stats=None): |
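        """Create (once) the variables holding the eigenvalues 'e' and
        eigenvectors 'Q' of every covariance statistic ('KFAC_FAC/*'),
        initialized to ones and the identity matrix respectively."""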
|
|
if len(self.stats_eigen) == 0: |
|
|
stats_eigen = {} |
|
|
if stats is None: |
|
|
stats = self.stats |
|
|
|
|
|
tmpEigenCache = {} |
|
|
with tf.device('/cpu:0'): |
|
|
for var in stats: |
|
|
for key in ['fprop_concat_stats', 'bprop_concat_stats']: |
|
|
for stats_var in stats[var][key]: |
|
|
if stats_var not in tmpEigenCache: |
|
|
stats_dim = stats_var.get_shape()[1].value |
|
|
e = tf.Variable(tf.ones( |
|
|
[stats_dim]), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e', trainable=False) |
|
|
Q = tf.Variable(tf.linalg.tensor_diag(tf.ones( |
|
|
[stats_dim])), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q', trainable=False) |
|
|
stats_eigen[stats_var] = {'e': e, 'Q': Q} |
|
|
tmpEigenCache[ |
|
|
stats_var] = stats_eigen[stats_var] |
|
|
else: |
|
|
stats_eigen[stats_var] = tmpEigenCache[ |
|
|
stats_var] |
|
|
self.stats_eigen = stats_eigen |
|
|
return self.stats_eigen |
|
|
|
|
|
def computeStatsEigen(self): |
|
|
""" compute the eigen decomp using copied var stats to avoid concurrent read/write from other queue """ |
|
|
|
|
|
|
|
|
with tf.device('/cpu:0'): |
|
|
def removeNone(tensor_list): |
|
|
local_list = [] |
|
|
for item in tensor_list: |
|
|
if item is not None: |
|
|
local_list.append(item) |
|
|
return local_list |
|
|
|
|
|
def copyStats(var_list): |
|
|
print("copying stats to buffer tensors before eigen decomp") |
|
|
redundant_stats = {} |
|
|
copied_list = [] |
|
|
for item in var_list: |
|
|
if item is not None: |
|
|
if item not in redundant_stats: |
|
|
if self._use_float64: |
|
|
redundant_stats[item] = tf.cast( |
|
|
tf.identity(item), tf.float64) |
|
|
else: |
|
|
redundant_stats[item] = tf.identity(item) |
|
|
copied_list.append(redundant_stats[item]) |
|
|
else: |
|
|
copied_list.append(None) |
|
|
return copied_list |
|
|
|
|
|
|
|
|
|
|
|
stats_eigen = self.stats_eigen |
|
|
computedEigen = {} |
|
|
eigen_reverse_lookup = {} |
|
|
updateOps = [] |
|
|
|
|
|
|
|
|
|
|
|
with tf.control_dependencies([]): |
|
|
for stats_var in stats_eigen: |
|
|
if stats_var not in computedEigen: |
|
|
eigens = tf.linalg.eigh(stats_var) |
|
|
e = eigens[0] |
|
|
Q = eigens[1] |
|
|
if self._use_float64: |
|
|
e = tf.cast(e, tf.float32) |
|
|
Q = tf.cast(Q, tf.float32) |
|
|
updateOps.append(e) |
|
|
updateOps.append(Q) |
|
|
computedEigen[stats_var] = {'e': e, 'Q': Q} |
|
|
eigen_reverse_lookup[e] = stats_eigen[stats_var]['e'] |
|
|
eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q'] |
|
|
|
|
|
self.eigen_reverse_lookup = eigen_reverse_lookup |
|
|
self.eigen_update_list = updateOps |
|
|
|
|
|
if KFAC_DEBUG: |
|
|
self.eigen_update_list = [item for item in updateOps] |
|
|
with tf.control_dependencies(updateOps): |
|
|
updateOps.append(tf.compat.v1.Print(tf.constant( |
|
|
0.), [tf.convert_to_tensor(value='computed factor eigen')])) |
|
|
|
|
|
return updateOps |
|
|
|
|
|
def applyStatsEigen(self, eigen_list): |
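        """Assign freshly computed eigenvalues/eigenvectors (`eigen_list`, in
        the order of self.eigen_update_list) back into the stats-eigen
        variables and increment `factor_step`."""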
|
|
updateOps = [] |
|
|
print(('updating %d eigenvalue/vectors' % len(eigen_list))) |
|
|
for i, (tensor, mark) in enumerate(zip(eigen_list, self.eigen_update_list)): |
|
|
stats_eigen_var = self.eigen_reverse_lookup[mark] |
|
|
updateOps.append( |
|
|
tf.compat.v1.assign(stats_eigen_var, tensor, use_locking=True)) |
|
|
|
|
|
with tf.control_dependencies(updateOps): |
|
|
factor_step_op = tf.compat.v1.assign_add(self.factor_step, 1) |
|
|
updateOps.append(factor_step_op) |
|
|
if KFAC_DEBUG: |
|
|
updateOps.append(tf.compat.v1.Print(tf.constant( |
|
|
0.), [tf.convert_to_tensor(value='updated kfac factors')])) |
|
|
return updateOps |
|
|
|
|
|
def getKfacPrecondUpdates(self, gradlist, varlist): |
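        """Precondition `gradlist` with the factored Fisher approximation.

        Each gradient is rotated into the eigenbasis of its Kronecker factors,
        divided elementwise by the product of eigenvalues plus (optionally
        factored) damping, rotated back, and finally rescaled by
        min(1, sqrt(clip_kl / vFv)), where vFv is the predicted quadratic
        change accumulated in self.vFv. Bias gradients that share a factor with
        their weights are handled via the homogeneous-coordinate trick.
        """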
|
|
updatelist = [] |
|
|
vg = 0. |
|
|
|
|
|
assert len(self.stats) > 0 |
|
|
assert len(self.stats_eigen) > 0 |
|
|
assert len(self.factors) > 0 |
|
|
counter = 0 |
|
|
|
|
|
grad_dict = {var: grad for grad, var in zip(gradlist, varlist)} |
|
|
|
|
|
for grad, var in zip(gradlist, varlist): |
|
|
GRAD_RESHAPE = False |
|
|
GRAD_TRANSPOSE = False |
|
|
|
|
|
fpropFactoredFishers = self.stats[var]['fprop_concat_stats'] |
|
|
bpropFactoredFishers = self.stats[var]['bprop_concat_stats'] |
|
|
|
|
|
if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0: |
|
|
counter += 1 |
|
|
GRAD_SHAPE = grad.get_shape() |
|
|
if len(grad.get_shape()) > 2: |
|
|
|
|
|
KW = int(grad.get_shape()[0]) |
|
|
KH = int(grad.get_shape()[1]) |
|
|
C = int(grad.get_shape()[2]) |
|
|
D = int(grad.get_shape()[3]) |
|
|
|
|
|
if len(fpropFactoredFishers) > 1 and self._channel_fac: |
|
|
|
|
|
grad = tf.reshape(grad, [KW * KH, C, D]) |
|
|
else: |
|
|
|
|
|
grad = tf.reshape(grad, [-1, D]) |
|
|
GRAD_RESHAPE = True |
|
|
elif len(grad.get_shape()) == 1: |
|
|
|
|
|
D = int(grad.get_shape()[0]) |
|
|
|
|
|
grad = tf.expand_dims(grad, 0) |
|
|
GRAD_RESHAPE = True |
|
|
else: |
|
|
|
|
|
C = int(grad.get_shape()[0]) |
|
|
D = int(grad.get_shape()[1]) |
|
|
|
|
|
if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias: |
|
|
|
|
|
|
|
|
|
|
|
var_assnBias = self.stats[var]['assnBias'] |
|
|
grad = tf.concat( |
|
|
[grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0) |
|
|
|
|
|
|
|
|
|
|
|
eigVals = [] |
|
|
|
|
|
for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']): |
|
|
Q = self.stats_eigen[stats]['Q'] |
|
|
e = detectMinVal(self.stats_eigen[stats][ |
|
|
'e'], var, name='act', debug=KFAC_DEBUG) |
|
|
|
|
|
Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act') |
|
|
eigVals.append(e) |
|
|
grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx) |
|
|
|
|
|
for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']): |
|
|
Q = self.stats_eigen[stats]['Q'] |
|
|
e = detectMinVal(self.stats_eigen[stats][ |
|
|
'e'], var, name='grad', debug=KFAC_DEBUG) |
|
|
|
|
|
Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad') |
|
|
eigVals.append(e) |
|
|
grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
weightDecayCoeff = 0. |
|
|
if var in self._weight_decay_dict: |
|
|
weightDecayCoeff = self._weight_decay_dict[var] |
|
|
if KFAC_DEBUG: |
|
|
print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff))) |
|
|
|
|
|
if self._factored_damping: |
|
|
if KFAC_DEBUG: |
|
|
print(('use factored damping for %s' % (var.name))) |
|
|
coeffs = 1. |
|
|
num_factors = len(eigVals) |
|
|
|
|
|
|
|
|
if len(eigVals) == 1: |
|
|
damping = self._epsilon + weightDecayCoeff |
|
|
else: |
|
|
damping = tf.pow( |
|
|
self._epsilon + weightDecayCoeff, 1. / num_factors) |
|
|
eigVals_tnorm_avg = [tf.reduce_mean( |
|
|
input_tensor=tf.abs(e)) for e in eigVals] |
|
|
for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg): |
|
|
eig_tnorm_negList = [ |
|
|
item for item in eigVals_tnorm_avg if item != e_tnorm] |
|
|
if len(eigVals) == 1: |
|
|
adjustment = 1. |
|
|
elif len(eigVals) == 2: |
|
|
adjustment = tf.sqrt( |
|
|
e_tnorm / eig_tnorm_negList[0]) |
|
|
else: |
|
|
eig_tnorm_negList_prod = reduce( |
|
|
lambda x, y: x * y, eig_tnorm_negList) |
|
|
adjustment = tf.pow( |
|
|
tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod, 1. / num_factors) |
|
|
coeffs *= (e + adjustment * damping) |
|
|
else: |
|
|
coeffs = 1. |
|
|
damping = (self._epsilon + weightDecayCoeff) |
|
|
for e in eigVals: |
|
|
coeffs *= e |
|
|
coeffs += damping |
|
|
|
|
|
|
|
|
|
|
|
grad /= coeffs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']): |
|
|
Q = self.stats_eigen[stats]['Q'] |
|
|
grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx) |
|
|
|
|
|
for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']): |
|
|
Q = self.stats_eigen[stats]['Q'] |
|
|
grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx) |
|
|
|
|
|
|
|
|
|
|
|
if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias: |
|
|
|
|
|
|
|
|
|
|
|
var_assnBias = self.stats[var]['assnBias'] |
|
|
C_plus_one = int(grad.get_shape()[0]) |
|
|
grad_assnBias = tf.reshape(tf.slice(grad, |
|
|
begin=[ |
|
|
C_plus_one - 1, 0], |
|
|
size=[1, -1]), var_assnBias.get_shape()) |
|
|
grad_assnWeights = tf.slice(grad, |
|
|
begin=[0, 0], |
|
|
size=[C_plus_one - 1, -1]) |
|
|
grad_dict[var_assnBias] = grad_assnBias |
|
|
grad = grad_assnWeights |
|
|
|
|
|
|
|
|
if GRAD_RESHAPE: |
|
|
grad = tf.reshape(grad, GRAD_SHAPE) |
|
|
|
|
|
grad_dict[var] = grad |
|
|
|
|
|
print(('projecting %d gradient matrices' % counter)) |
|
|
|
|
|
for g, var in zip(gradlist, varlist): |
|
|
grad = grad_dict[var] |
|
|
|
|
|
            if KFAC_DEBUG:
                print(('apply clipping to %s' % (var.name)))
                grad = tf.compat.v1.Print(grad, [tf.sqrt(tf.reduce_sum(input_tensor=tf.pow(grad, 2)))], "Euclidean norm of new grad")
|
|
local_vg = tf.reduce_sum(input_tensor=grad * g * (self._lr * self._lr)) |
|
|
vg += local_vg |
|
|
|
|
|
|
|
|
if KFAC_DEBUG: |
|
|
print('apply vFv clipping') |
|
|
|
|
|
scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg)) |
|
|
if KFAC_DEBUG: |
|
|
scaling = tf.compat.v1.Print(scaling, [tf.convert_to_tensor( |
|
|
value='clip: '), scaling, tf.convert_to_tensor(value=' vFv: '), vg]) |
|
|
with tf.control_dependencies([tf.compat.v1.assign(self.vFv, vg)]): |
|
|
updatelist = [grad_dict[var] for var in varlist] |
|
|
for i, item in enumerate(updatelist): |
|
|
updatelist[i] = scaling * item |
|
|
|
|
|
return updatelist |
|
|
|
|
|
def compute_gradients(self, loss, var_list=None): |
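        """Plain (unpreconditioned) gradients of `loss`, returned as a list of
        (gradient, variable) pairs."""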
|
|
varlist = var_list |
|
|
if varlist is None: |
|
|
varlist = tf.compat.v1.trainable_variables() |
|
|
g = tf.gradients(ys=loss, xs=varlist) |
|
|
|
|
|
return [(a, b) for a, b in zip(g, varlist)] |
|
|
|
|
|
def apply_gradients_kfac(self, grads): |
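        """Build the full K-FAC update: refresh the statistics, periodically
        recompute the factor eigendecompositions (optionally on an async
        queue), precondition the gradients and apply them with a
        MomentumOptimizer. Returns (update_op, queue_runner_or_None)."""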
|
|
g, varlist = list(zip(*grads)) |
|
|
|
|
|
if len(self.stats_eigen) == 0: |
|
|
self.getStatsEigen() |
|
|
|
|
|
qr = None |
|
|
|
|
|
if self._async: |
|
|
print('Use async eigen decomp') |
|
|
|
|
|
factorOps_dummy = self.computeStatsEigen() |
|
|
|
|
|
|
|
|
queue = tf.queue.FIFOQueue(1, [item.dtype for item in factorOps_dummy], shapes=[ |
|
|
item.get_shape() for item in factorOps_dummy]) |
|
|
enqueue_op = tf.cond(pred=tf.logical_and(tf.equal(tf.math.floormod(self.stats_step, self._kfac_update), tf.convert_to_tensor( |
|
|
value=0)), tf.greater_equal(self.stats_step, self._stats_accum_iter)), true_fn=lambda: queue.enqueue(self.computeStatsEigen()), false_fn=tf.no_op) |
|
|
|
|
|
def dequeue_op(): |
|
|
return queue.dequeue() |
|
|
|
|
|
qr = tf.compat.v1.train.QueueRunner(queue, [enqueue_op]) |
|
|
|
|
|
updateOps = [] |
|
|
global_step_op = tf.compat.v1.assign_add(self.global_step, 1) |
|
|
updateOps.append(global_step_op) |
|
|
|
|
|
with tf.control_dependencies([global_step_op]): |
|
|
|
|
|
|
|
|
            assert self._update_stats_op is not None
|
|
updateOps.append(self._update_stats_op) |
|
|
dependency_list = [] |
|
|
if not self._async: |
|
|
dependency_list.append(self._update_stats_op) |
|
|
|
|
|
with tf.control_dependencies(dependency_list): |
|
|
def no_op_wrapper(): |
|
|
return tf.group(*[tf.compat.v1.assign_add(self.cold_step, 1)]) |
|
|
|
|
|
if not self._async: |
|
|
|
|
|
updateFactorOps = tf.cond(pred=tf.logical_and(tf.equal(tf.math.floormod(self.stats_step, self._kfac_update), |
|
|
tf.convert_to_tensor(value=0)), |
|
|
tf.greater_equal(self.stats_step, self._stats_accum_iter)), true_fn=lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())), false_fn=no_op_wrapper) |
|
|
else: |
|
|
|
|
|
updateFactorOps = tf.cond(pred=tf.greater_equal(self.stats_step, self._stats_accum_iter), |
|
|
true_fn=lambda: tf.cond(pred=tf.equal(queue.size(), tf.convert_to_tensor(value=0)), |
|
|
true_fn=tf.no_op, |
|
|
|
|
|
false_fn=lambda: tf.group( |
|
|
*self.applyStatsEigen(dequeue_op())), |
|
|
), |
|
|
false_fn=no_op_wrapper) |
|
|
|
|
|
updateOps.append(updateFactorOps) |
|
|
|
|
|
with tf.control_dependencies([updateFactorOps]): |
|
|
def gradOp(): |
|
|
return list(g) |
|
|
|
|
|
def getKfacGradOp(): |
|
|
return self.getKfacPrecondUpdates(g, varlist) |
|
|
u = tf.cond(pred=tf.greater(self.factor_step, |
|
|
tf.convert_to_tensor(value=0)), true_fn=getKfacGradOp, false_fn=gradOp) |
|
|
|
|
|
optim = tf.compat.v1.train.MomentumOptimizer( |
|
|
self._lr * (1. - self._momentum), self._momentum) |
|
|
|
|
|
|
|
|
def optimOp(): |
|
|
def updateOptimOp(): |
|
|
if self._full_stats_init: |
|
|
return tf.cond(pred=tf.greater(self.factor_step, tf.convert_to_tensor(value=0)), true_fn=lambda: optim.apply_gradients(list(zip(u, varlist))), false_fn=tf.no_op) |
|
|
else: |
|
|
return optim.apply_gradients(list(zip(u, varlist))) |
|
|
if self._full_stats_init: |
|
|
return tf.cond(pred=tf.greater_equal(self.stats_step, self._stats_accum_iter), true_fn=updateOptimOp, false_fn=tf.no_op) |
|
|
else: |
|
|
return tf.cond(pred=tf.greater_equal(self.sgd_step, self._cold_iter), true_fn=updateOptimOp, false_fn=tf.no_op) |
|
|
updateOps.append(optimOp()) |
|
|
|
|
|
return tf.group(*updateOps), qr |
|
|
|
|
|
def apply_gradients(self, grads): |
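        """Run plain momentum SGD (with global-norm gradient clipping) for the
        first `cold_iter` steps, then switch to the K-FAC update built by
        apply_gradients_kfac. Returns (update_op, queue_runner)."""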
|
|
coldOptim = tf.compat.v1.train.MomentumOptimizer( |
|
|
self._cold_lr, self._momentum) |
|
|
|
|
|
def coldSGDstart(): |
|
|
sgd_grads, sgd_var = zip(*grads) |
|
|
|
|
|
            if self.max_grad_norm is not None:
|
|
                sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(sgd_grads, self.max_grad_norm)
|
|
|
|
|
            sgd_grads = list(zip(sgd_grads, sgd_var))
|
|
|
|
|
sgd_step_op = tf.compat.v1.assign_add(self.sgd_step, 1) |
|
|
coldOptim_op = coldOptim.apply_gradients(sgd_grads) |
|
|
if KFAC_DEBUG: |
|
|
with tf.control_dependencies([sgd_step_op, coldOptim_op]): |
|
|
sgd_step_op = tf.compat.v1.Print( |
|
|
sgd_step_op, [self.sgd_step, tf.convert_to_tensor(value='doing cold sgd step')]) |
|
|
return tf.group(*[sgd_step_op, coldOptim_op]) |
|
|
|
|
|
kfacOptim_op, qr = self.apply_gradients_kfac(grads) |
|
|
|
|
|
def warmKFACstart(): |
|
|
return kfacOptim_op |
|
|
|
|
|
return tf.cond(pred=tf.greater(self.sgd_step, self._cold_iter), true_fn=warmKFACstart, false_fn=coldSGDstart), qr |
|
|
|
|
|
def minimize(self, loss, loss_sampled, var_list=None): |
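        """Convenience wrapper: build plain gradients of `loss`, register the
        statistics updates from `loss_sampled` (stored in self._update_stats_op
        and consumed inside apply_gradients_kfac), and return the
        (update_op, queue_runner) pair from apply_gradients."""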
|
|
grads = self.compute_gradients(loss, var_list=var_list) |
|
|
update_stats_op = self.compute_and_apply_stats( |
|
|
loss_sampled, var_list=var_list) |
|
|
return self.apply_gradients(grads) |
|
|
|