Skip to content

Instantly share code, notes, and snippets.

@hadilou
Forked from ortegatron/LossEvalHook.py
Created December 8, 2020 08:40
Show Gist options
  • Save hadilou/186ef529520c8a72116d971cd5b6920e to your computer and use it in GitHub Desktop.
Save hadilou/186ef529520c8a72116d971cd5b6920e to your computer and use it in GitHub Desktop.
Trainer with Loss on Validation for Detectron2
import datetime
import logging
import os
import time

import numpy as np
import torch

import detectron2.utils.comm as comm
from detectron2.data import DatasetMapper, build_detection_test_loader
from detectron2.engine import DefaultTrainer
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import COCOEvaluator
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
class LossEvalHook(HookBase):
    """Training hook that periodically records the loss on a validation loader.

    Every ``eval_period`` iterations (and on the final iteration) the hook
    feeds every batch of ``data_loader`` through ``model``, sums the values
    the model returns for each batch, and stores the mean over all batches
    in the trainer's storage under the key ``'validation_loss'``.
    """

    def __init__(self, eval_period, model, data_loader):
        """Store the evaluation settings.

        Args:
            eval_period: Number of iterations between two evaluations. A
                value <= 0 disables the periodic runs; the run on the final
                iteration still happens.
            model: Callable invoked on each batch; expected to return a
                mapping of loss names to values (see ``_get_loss``).
            data_loader: Sized iterable yielding the validation batches.
        """
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        """Compute the loss of every validation batch and record the mean.

        The timing/ETA bookkeeping is adapted from ``inference_on_dataset``
        in detectron2's evaluator.py.

        Returns:
            The list of per-batch total losses (empty for an empty loader).
        """
        total = len(self._data_loader)
        # The first few batches are treated as warm-up: the timers are reset
        # once they are done so they do not skew the reported speed.
        num_warmup = min(5, total - 1)

        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0

            # The forward pass has to sit inside the timed section, otherwise
            # the reported "s / img" only measures the CUDA synchronisation.
            start_compute_time = time.perf_counter()
            loss_batch = self._get_loss(inputs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            losses.append(loss_batch)

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )

        # An empty loader yields no losses; skip the scalar instead of
        # recording the NaN that np.mean([]) would produce.
        if losses:
            mean_loss = np.mean(losses)
            self.trainer.storage.put_scalar('validation_loss', mean_loss)
        comm.synchronize()

        return losses

    def _get_loss(self, data):
        """Return the summed loss of one batch as a plain Python number.

        Mirrors how the loss is computed in the training loop, except that no
        gradients are tracked because the result is only used for reporting.

        Args:
            data: One batch as yielded by the data loader.

        Returns:
            The sum of all values in the mapping returned by the model.
        """
        # Validation never back-propagates, so building the autograd graph
        # would only cost memory.
        with torch.no_grad():
            metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        return sum(metrics_dict.values())

    def after_step(self):
        """Run the validation-loss evaluation when the period (or the end) is reached."""
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        # NOTE(review): this writes a constant 'timetest' scalar on every
        # step; it looks like leftover debugging output - confirm before removing.
        self.trainer.storage.put_scalars(timetest=12)
class MyTrainer(DefaultTrainer):
    """DefaultTrainer variant that also tracks the loss on the first test dataset.

    Adds a ``LossEvalHook`` to the hooks built by the parent class and
    evaluates with a ``COCOEvaluator``.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Create the evaluator used for ``dataset_name``.

        Args:
            cfg: Config object; ``cfg.OUTPUT_DIR`` is read when no output
                folder is given.
            dataset_name: Name of the dataset to evaluate.
            output_folder: Directory passed to the evaluator. Defaults to
                ``<cfg.OUTPUT_DIR>/inference``.

        Returns:
            A ``COCOEvaluator`` for the dataset.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    def build_hooks(self):
        """Return the parent's hooks with a ``LossEvalHook`` inserted.

        The hook is inserted just before the last hook of the parent's list
        and evaluates on ``cfg.DATASETS.TEST[0]`` every
        ``cfg.TEST.EVAL_PERIOD`` iterations.
        """
        hooks = super().build_hooks()
        hooks.insert(-1, LossEvalHook(
            # Read the period from the trainer's own config: there is no
            # local or guaranteed global named ``cfg`` inside this method.
            self.cfg.TEST.EVAL_PERIOD,
            self.model,
            build_detection_test_loader(
                self.cfg,
                self.cfg.DATASETS.TEST[0],
                # Second positional argument is True; presumably the mapper's
                # is_train flag, so annotations stay available for the loss
                # computation - confirm against the detectron2 version in use.
                DatasetMapper(self.cfg, True)
            )
        ))
        return hooks
import json
import matplotlib.pyplot as plt
# Output directory of the training run whose metrics.json is plotted below;
# change this to point at your own experiment.
experiment_folder = './output/model_iter4000_lr0005_wf1_date2020_03_20__05_16_45'
def load_json_arr(json_path):
    """Read a file holding one JSON document per line.

    Args:
        json_path: Path of the file to read.

    Returns:
        A list with the decoded object of every non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(json_path, 'r') as f:
        # Blank or whitespace-only lines (e.g. a trailing newline) carry no
        # record and would make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]
# Load every record written to metrics.json for this experiment.
experiment_metrics = load_json_arr(experiment_folder + '/metrics.json')

# A record is not guaranteed to carry every key, so each curve only uses the
# records that contain its key; indexing a record without 'total_loss' would
# raise KeyError.
train_records = [x for x in experiment_metrics if 'total_loss' in x]
val_records = [x for x in experiment_metrics if 'validation_loss' in x]

plt.plot(
    [x['iteration'] for x in train_records],
    [x['total_loss'] for x in train_records])
plt.plot(
    [x['iteration'] for x in val_records],
    [x['validation_loss'] for x in val_records])
plt.legend(['total_loss', 'validation_loss'], loc='upper left')
plt.show()
@DANISHFAYAZNAJAR
Copy link

The metrics.json does not contain validation_loss. Each x in experiment_metrics contains only these keys:
{
"data_time": 0.07281189849993552,
"eta_seconds": 2525.526994679776,
"fast_rcnn/cls_accuracy": 0.867919921875,
"fast_rcnn/false_negative": 1.0,
"fast_rcnn/fg_cls_accuracy": 0.0,
"iteration": 59,
"loss_box_reg": 0.3603529781103134,
"loss_cls": 1.3637239933013916,
"loss_rpn_cls": 0.3159864991903305,
"loss_rpn_loc": 1.056588351726532,
"lr": 1.4985249999999999e-05,
"roi_head/num_bg_samples": 444.875,
"roi_head/num_fg_samples": 67.125,
"rpn/num_neg_anchors": 128.0,
"rpn/num_pos_anchors": 128.0,
"time": 1.05515398,
"total_loss": 3.0616597086191177
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment