# Source code for ml_logger.logger.tensorboard

"""Logger class that writes to tensorboard."""

from typing import Dict

from tensorboardX import SummaryWriter

from ml_logger.logger.base import Logger as BaseLogger
from ml_logger.types import ConfigType, LogType, MetricType, NumType
from ml_logger.utils import flatten_dict, make_dir


class Logger(BaseLogger):
    """Logger class that writes to tensorboardX."""

    def __init__(self, config: ConfigType):
        """Initialise the tensorboardX Logger.

        Args:
            config (ConfigType): config to initialise the tensorboardX logger.
                The config can have any parameters that
                tensorboardX.SummaryWriter() method accepts
                (https://tensorboardx.readthedocs.io/en/latest/tensorboard.html#tensorboardX.SummaryWriter).
                Note that the config is passed as keyword arguments to the
                tensorboardX.SummaryWriter() method. This provides a lot of
                flexibility to the users to configure tensorboard. This also
                means that config should not have any parameters that
                tensorboardX.SummaryWriter() would not accept.
        """
        super().__init__(config=config)
        logdir = config.get("logdir")
        if logdir is not None:
            # Make sure the target directory exists before the writer is
            # created on top of it.
            make_dir(logdir)
        self.summary_writer = SummaryWriter(**config)
        # Bookkeeping keys that should not be written to tensorboard;
        # presumably stripped by the base class's `_prepare_*` helpers —
        # NOTE(review): confirm against BaseLogger.
        self.keys_to_skip = ["logbook_id", "logbook_type", "logbook_timestamp"]

    def write(self, log: LogType) -> None:
        """Write the log to tensorboard.

        Only logs whose ``logbook_type`` is ``"metric"`` or ``"config"`` are
        written; logs of any other type are ignored.

        Args:
            log (LogType): Log to write. Must contain a ``logbook_type`` key.

        Raises:
            KeyError: if ``log`` has no ``logbook_type`` key.
        """
        logbook_type = log["logbook_type"]
        if logbook_type == "metric":
            log = self._prepare_metric_log_to_write(log=log)
            self.write_metric(metric=log)
        elif logbook_type == "config":
            self.write_config(config=log)
        # Any other logbook_type is not supported by this logger and is
        # deliberately dropped.

    def write_metric(self, metric: MetricType) -> None:
        """Write metric to tensorboard.

        Every remaining key/value pair is written as a scalar with
        ``SummaryWriter.add_scalar``. The following keys are removed first
        and handled specially:

        * ``global_step``: forwarded as ``global_step`` to ``add_scalar``.
        * ``walltime``: forwarded as ``walltime`` to ``add_scalar``.
        * ``tag`` (or, if absent, ``main_tag``): prepended to every scalar
          tag as ``"<tag>/"``.
        * the key named by ``self.key_prefix`` (when that attribute is
          truthy): its value is prepended to every key as ``"<value>_"``.

        Args:
            metric (MetricType): Metric to write. The caller's dict is not
                modified; a shallow copy is consumed instead.

        Raises:
            KeyError: if ``self.key_prefix`` is set but that key is missing
                from ``metric``.
        """
        # Work on a copy so that popping the special keys does not strip them
        # from a log dict the caller (or another logger) may still use.
        metric = dict(metric)
        global_step = metric.pop("global_step", None)
        walltime = metric.pop("walltime", None)

        main_tag = ""
        if "tag" in metric:
            main_tag = str(metric.pop("tag")) + "/"
        elif "main_tag" in metric:
            main_tag = str(metric.pop("main_tag")) + "/"

        # `key_prefix` is presumably provided by the base class
        # (NOTE(review): confirm). Use the popped value itself as the
        # prefix; wrapping it in `{...}` would build a set and render as
        # "{'value'}_key".
        if self.key_prefix:
            prefix = metric.pop(self.key_prefix)
            metric = {f"{prefix}_{key}": value for key, value in metric.items()}

        for key, value in metric.items():
            self.summary_writer.add_scalar(
                tag=f"{main_tag}{key}",
                scalar_value=value,
                global_step=global_step,
                walltime=walltime,
            )

    def write_config(self, config: ConfigType) -> None:
        """Write the config to tensorboard as hyper-parameters.

        The optional keys ``name``, ``metric_dict`` and ``global_step`` are
        removed and forwarded to ``SummaryWriter.add_hparams``. The rest of
        the config is flattened (via ``flatten_dict``) and written as the
        hyper-parameter dict, with ``None`` values replaced by the string
        ``"None"``.

        Args:
            config (ConfigType): Config to write. The caller's dict is not
                modified; a shallow copy is consumed instead.
        """
        # Copy so that popping keys / rewriting values does not leak back
        # into the caller's log dict.
        config = dict(config)
        name = config.pop("name", None)
        metric_dict: Dict[str, NumType] = {}
        if "metric_dict" in config:
            metric_dict = self._prepare_metric_log_to_write(
                log=dict(config.pop("metric_dict"))
            )
        global_step = config.pop("global_step", None)
        config = self._prepare_log_to_write(log=config)
        # `add_hparams` does not accept None as a hyper-parameter value, so
        # stringify it. Doing this after flattening also covers None values
        # that were nested inside sub-dicts.
        hparam_dict = {
            key: "None" if value is None else value
            for key, value in flatten_dict(config).items()
        }
        self.summary_writer.add_hparams(
            hparam_dict=hparam_dict,
            metric_dict=metric_dict,
            name=name,
            global_step=global_step,
        )