Skip to content

Conversation

@Atharva9621
Copy link

@Atharva9621 Atharva9621 commented Oct 20, 2025

Add support for custom loss and metrics in model_sweep

Fixes #544

  • Custom loss, metrics, and optimizers can now be passed to model_sweep in the same way as tabular_model.fit() through custom_fit_params.
  • custom_fit_params expects a dictionary specifying the custom loss, metrics, or optimizer.
  • Minimal code changes; fully backward compatible.
  • Updated corresponding tests.

Example usage

class CustomLoss(nn.Module):
      def __init__(self):
          super(CustomLoss, self).__init__()
  
      def forward(self, inputs, targets):
          loss = torch.mean((inputs - targets) ** 4)
          return 100*loss.mean()

def custom_metric(y_hat, y):
    return (y_hat - y).mean()

sweep_df, best_model = model_sweep(
    task="regression",
    train=train,
    test=val,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="lite",
    custom_fit_params = {
        "loss": CustomLoss(),
        "metrics": [custom_metric],
        "metrics_prob_inputs": [True],
        "optimizer": torch.optim.Adagrad,
    }
)

📚 Documentation preview 📚: https://pytorch-tabular--587.org.readthedocs.build/en/587/

@dosubot dosubot bot added size:M This PR changes 30-99 lines, ignoring generated files. enhancement New feature or request labels Oct 20, 2025
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for custom loss functions, metrics, and optimizers to the model_sweep function, making it consistent with the TabularModel.fit() API.

Key Changes

  • Added custom_fit_params parameter to model_sweep function that accepts custom loss, metrics, and optimizer specifications
  • Updated validation logic to ensure rank_metric is "loss" when custom metrics are provided
  • Enhanced test coverage with a new test case (test_model_compare_custom) demonstrating custom fit parameters

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
src/pytorch_tabular/tabular_model_sweep.py Added custom_fit_params parameter to model_sweep and _validate_args, with validation logic and documentation; unpacks params when calling prepare_model
tests/test_common.py Updated _run_model_compare to accept and forward custom_fit_params; added new test case with custom loss, metrics, and optimizer

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Atharva9621 and others added 3 commits December 21, 2025 19:03
fix typo

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
fix test assertion

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
fix test assertion

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request size:M This PR changes 30-99 lines, ignoring generated files.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Help: custom loss for model_sweep

1 participant