I’m trying to add early stopping to a model I train with the TensorFlow Object Detection API. I found a piece of code online and used it as a reference. Below is my version of the EarlyStoppingHook class, adapted from that example:
class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests a training stop once a monitored value stops improving.

    The monitored graph element is fetched on every run call the hook is
    attached to, so ``patience`` is counted in run calls (steps), not epochs.

    Args:
        metric: Name of the graph element (tensor or operation) to monitor.
            It is resolved against the default graph in ``begin``.
        delta: Minimum change that counts as an improvement. Expected to be
            non-negative; it is negated internally when minimising.
        patience: Number of consecutive non-improving steps after which a
            stop is requested.
        mode: One of 'auto', 'min' or 'max'. In 'auto' mode the value is
            maximised when the metric name contains 'accuracy' and minimised
            otherwise. Unknown modes fall back to 'auto' with a warning.
    """

    def __init__(self, metric='validation_loss', delta=0, patience=0,
                 mode='auto'):
        # begin() replaces self.metric with the resolved tensor, so the name
        # is kept separately to allow the hook to be resolved again later.
        self._metric_name = metric
        self.metric = metric
        self.patience = patience
        self.wait_time = 0       # consecutive steps without improvement
        self.max_wait_time = 0   # longest non-improving streak seen so far
        self.current_step = 0    # number of after_run calls observed

        if mode not in ('auto', 'min', 'max'):
            logging.warning('Unknown mode %s, using auto mode.', mode)
            mode = 'auto'
        if mode == 'auto':
            mode = 'max' if 'accuracy' in metric else 'min'

        if mode == 'max':
            self.monitor_op = np.greater
            self.delta = delta
            self.best_value = -np.inf
        else:
            self.monitor_op = np.less
            # Negated so that `current - self.delta` becomes `current + delta`
            # and the value must drop below `best - delta` to count.
            self.delta = -delta
            # np.inf rather than np.Inf: the capitalised alias was removed in
            # NumPy 2.0.
            self.best_value = np.inf

    def begin(self):
        """Resolve the monitored name to a tensor in the default graph."""
        graph = tf.get_default_graph()
        # Always resolve from the stored name; resolving from self.metric
        # would fail if begin() runs again after self.metric became a tensor
        # belonging to a previous graph.
        element = graph.as_graph_element(self._metric_name)
        if isinstance(element, tf.Operation):
            element = element.outputs[0]
        self.metric = element

    def before_run(self, run_context):
        """Ask the session to fetch the monitored tensor with this run call."""
        return session_run_hook.SessionRunArgs(self.metric)

    def after_run(self, run_context, run_values):
        """Update the best value and request a stop when patience runs out."""
        self.current_step += 1
        current_value = run_values.results

        # Periodic progress report (every 200 observed steps).
        if self.current_step % 200 == 0:
            print(f"Current value: {current_value}, best: {self.best_value}, wait: {self.wait_time}, max wait: {self.max_wait_time}")

        if self.monitor_op(current_value - self.delta, self.best_value):
            self.best_value = current_value
            # Record the streak that just ended before resetting it.
            if self.max_wait_time < self.wait_time:
                self.max_wait_time = self.wait_time
            self.wait_time = 0
        else:
            self.wait_time += 1
            if self.wait_time >= self.patience:
                run_context.request_stop()
I’m using the class in this manner:
# Monitor the graph element named 'total_loss' and request a stop after
# 2000 consecutive steps without improvement.
early_stopping_hook = EarlyStoppingHook(metric='total_loss', patience=2000)

# The hook is attached to the TrainSpec, so it is run as part of training.
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=train_steps,
    hooks=[early_stopping_hook],
)
What confuses me is the name ‘total_loss’. Does it refer to the validation loss or the training loss? Also, where are the various losses such as ‘total_loss’, ‘loss_1’, and ‘loss_2’ defined within the TensorFlow Object Detection API? I’m struggling to find clear information on what these loss names mean.