Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, **kwargs):
self.log_every: int = kwargs.get('log_every', 100)
self.verbose: bool = kwargs.get('verbose', False)
self.use_wandb: bool = kwargs.get('use_wandb', False)
self.use_comet: bool = kwargs.get("use_comet", False)
self.use_ui_logger: bool = kwargs.get('use_ui_logger', False)
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
self.run_name: str = kwargs.get('run_name', None)
Expand Down
58 changes: 58 additions & 0 deletions toolkit/logging_aitk.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,60 @@ def finish(self):
self.run.finish()


# Comet logger class
# This class logs the data to comet_ml
class CometLogger(EmptyLogger):
def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
self.project = project
self.run_name = run_name
self.config = config

def start(self):
try:
import comet_ml
except ImportError:
raise ImportError(
"Failed to import comet_ml. Please install comet_ml by running `pip install comet_ml`"
)

# Configure comet_ml by running `comet login`
# This will create a .comet.config file in your home directory
# Add `workspace=your_workspace` to the .comet.config file if you want to specify a default workspace
experiment = comet_ml.Experiment(
project_name=self.project,
log_code=False, # Disables automatic code logging
)
experiment.set_name(self.run_name) # Set the experiment name

if self.run_name is not None:
experiment.set_name(self.run_name)
experiment.log_parameters(self.config) # Report hyperparameters
self.experiment = experiment
self._metrics = {}

def log(self, metric_dict: Dict[str, Any]):
self._metrics.update(metric_dict)

def log_image(
self,
image: Image,
id, # sample index
caption: str = "", # positive prompt
*args,
**kwargs,
):
# log image to comet_ml
metadata = {"caption": caption}
self.experiment.log_image(image_data=image, name=f"sample_{id}", metadata=metadata)

def commit(self, step: Optional[int] = None):
self.experiment.log_metrics(self._metrics, step=step)
self._metrics = {}

def finish(self):
self.experiment.end()


class UILogger:
def __init__(
self,
Expand Down Expand Up @@ -307,6 +361,10 @@ def create_logger(
project_name = logging_config.project_name
run_name = logging_config.run_name
return WandbLogger(project=project_name, run_name=run_name, config=all_config)
elif logging_config.use_comet:
project_name = logging_config.project_name
run_name = logging_config.run_name
return CometLogger(project=project_name, run_name=run_name, config=all_config)
elif logging_config.use_ui_logger:
if save_root is None:
raise ValueError("save_root must be provided when using UILogger")
Expand Down