-
-
Save hadilou/186ef529520c8a72116d971cd5b6920e to your computer and use it in GitHub Desktop.
Trainer with Loss on Validation for Detectron2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
import logging
import time

import numpy as np
import torch

import detectron2.utils.comm as comm
from detectron2.data import DatasetMapper, build_detection_test_loader
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
class LossEvalHook(HookBase):
    """Trainer hook that periodically computes the loss on a separate data loader.

    Every ``eval_period`` iterations, and on the final iteration, the hook runs
    the model over ``data_loader``, sums the values the model returns for each
    batch, and stores the mean over all batches in the trainer's storage under
    the key ``validation_loss``.
    """

    def __init__(self, eval_period, model, data_loader):
        """Remember the evaluation settings.

        Args:
            eval_period: Evaluate every this many iterations. A value <= 0
                disables periodic runs; the final iteration still triggers one.
            model: Callable invoked on each batch; it must return a mapping
                whose values are tensors or numbers.
            data_loader: Sized iterable yielding the batches to evaluate.
        """
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        """Run the model over the whole data loader and record the mean loss.

        Returns:
            The per-batch summed losses, in data-loader order. The list is
            empty (and nothing is recorded) when the loader yields no batches.
        """
        # Progress/ETA bookkeeping follows inference_on_dataset from evaluator.py.
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)

        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                # Restart the clocks so warm-up batches do not skew the timings.
                start_time = time.perf_counter()
                total_compute_time = 0

            # The loss computation has to sit inside the timed region,
            # otherwise the reported "s / img" measures nothing.
            start_compute_time = time.perf_counter()
            losses.append(self._get_loss(inputs))
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )

        # Skip the scalar for an empty loader: the mean of nothing is NaN and
        # would only pollute the metrics.
        if losses:
            mean_loss = np.mean(losses)
            self.trainer.storage.put_scalar('validation_loss', mean_loss)
        comm.synchronize()

        return losses

    def _get_loss(self, data):
        """Return the sum of all values the model reports for one batch.

        Args:
            data: One batch as yielded by the data loader.

        Returns:
            The sum, as a Python number, of every value in the model's output
            mapping.
        """
        # Validation never back-propagates, so skip autograd bookkeeping.
        with torch.no_grad():
            metrics_dict = self._model(data)
        return sum(
            v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for v in metrics_dict.values()
        )

    def after_step(self):
        """Trigger the loss evaluation on period boundaries and the last iteration."""
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        # NOTE(review): constant placeholder scalar written on every step; its
        # purpose is not evident from this file -- confirm before removing.
        self.trainer.storage.put_scalars(timetest=12)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MyTrainer(DefaultTrainer):
    """DefaultTrainer that additionally records a loss on the first TEST dataset."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Create a COCOEvaluator for ``dataset_name``.

        Args:
            cfg: Config object; ``cfg.OUTPUT_DIR`` is read when no output
                folder is given.
            dataset_name: Name of the dataset to evaluate.
            output_folder: Where the evaluator writes its results. Defaults to
                ``<cfg.OUTPUT_DIR>/inference``.

        Returns:
            The constructed COCOEvaluator.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    def build_hooks(self):
        """Return the default hooks with a LossEvalHook added.

        Returns:
            The hook list from the parent class with the LossEvalHook inserted
            immediately before its last element.
        """
        hooks = super().build_hooks()
        # Use the trainer's own config throughout; a bare module-level ``cfg``
        # may not exist or may differ from the one this trainer was built with.
        loss_eval_hook = LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD,
            self.model,
            build_detection_test_loader(
                self.cfg,
                self.cfg.DATASETS.TEST[0],
                # NOTE(review): the ``True`` is presumably ``is_train`` so that
                # annotations are kept for the loss -- confirm against DatasetMapper.
                DatasetMapper(self.cfg, True),
            ),
        )
        # Keep the parent's last hook last, so it runs after this one.
        hooks.insert(-1, loss_eval_hook)
        return hooks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import matplotlib.pyplot as plt | |
# Folder of the training run whose metrics are plotted below.
experiment_folder = './output/model_iter4000_lr0005_wf1_date2020_03_20__05_16_45'


def load_json_arr(json_path):
    """Load a JSON-lines file into a list of decoded objects.

    Args:
        json_path: Path to a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        # Blank or whitespace-only lines (e.g. a trailing newline) carry no
        # record and would otherwise make json.loads fail.
        return [json.loads(line) for line in f if line.strip()]
# Plot training and validation loss against the iteration number.
experiment_metrics = load_json_arr(experiment_folder + '/metrics.json')

# A record is only plotted in a series if it carries that series' key, so both
# series are filtered the same way and a record lacking 'total_loss' no longer
# raises KeyError.
train_records = [x for x in experiment_metrics if 'total_loss' in x]
val_records = [x for x in experiment_metrics if 'validation_loss' in x]

plt.plot(
    [x['iteration'] for x in train_records],
    [x['total_loss'] for x in train_records])
plt.plot(
    [x['iteration'] for x in val_records],
    [x['validation_loss'] for x in val_records])
plt.legend(['total_loss', 'validation_loss'], loc='upper left')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The metrics.json file does not contain validation_loss. Each entry x in experiment_metrics contains these keys:
{
"data_time": 0.07281189849993552,
"eta_seconds": 2525.526994679776,
"fast_rcnn/cls_accuracy": 0.867919921875,
"fast_rcnn/false_negative": 1.0,
"fast_rcnn/fg_cls_accuracy": 0.0,
"iteration": 59,
"loss_box_reg": 0.3603529781103134,
"loss_cls": 1.3637239933013916,
"loss_rpn_cls": 0.3159864991903305,
"loss_rpn_loc": 1.056588351726532,
"lr": 1.4985249999999999e-05,
"roi_head/num_bg_samples": 444.875,
"roi_head/num_fg_samples": 67.125,
"rpn/num_neg_anchors": 128.0,
"rpn/num_pos_anchors": 128.0,
"time": 1.05515398,
"total_loss": 3.0616597086191177
}