Skip to content

Conversation

@VincentAuriau
Copy link
Collaborator

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @VincentAuriau, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the choice_learn library by introducing the capability to use sample_weight in both model evaluation and validation processes. The core purpose is to provide more granular control over how individual data points influence metric calculations, which is particularly useful for complex scenarios like 'tripwise metrics'. These changes improve the flexibility and accuracy of model training and assessment, allowing users to define the importance of each sample and manage validation frequency more effectively.

Highlights

  • Weighted Metrics Evaluation: The evaluate method in BaseBasketModel now supports a sample_weight argument, allowing for weighted calculation of metrics during model evaluation. This ensures that individual samples can contribute differently to the overall metric.
  • Flexible Validation in Training: The fit method in ChoiceModel has been enhanced to include a validation_freq parameter, controlling how often validation runs during training. Additionally, the val_dataset argument can now accept a tuple containing a ChoiceDataset and corresponding sample weights, enabling weighted validation.
  • Data Batching with Sample Weights: The iter_batch_evaluate function in BasketDataset has been updated to yield weights_batch instead of identifier_batch. These weights are calculated to normalize the contribution of each item within a trip, ensuring fair representation in evaluation.
  • New Unit Test for Weighted Validation: A new unit test, test_weighted_val_dataset, has been added to test_simplemnl.py to specifically verify the correct functionality of weighted validation datasets during model fitting.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for sample_weight in trip-wise metrics and for validation datasets, which is a valuable enhancement. The implementation is generally well-done, with corresponding updates to the evaluate and fit methods, and a new test case to validate the changes. However, I've identified a critical bug in choice_learn/basket_models/data/basket_dataset.py that could lead to a ZeroDivisionError when calculating weights. Please address this issue.

trip_identifier.extend(
[np.max(trip_identifier, initial=-1) + 1] * len(additional_trip_data[0])
)
weights.extend([1 / additional_trip_data[0]] * len(additional_trip_data[0]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a potential division-by-zero error here. additional_trip_data[0] is a NumPy array of item IDs. The expression 1 / additional_trip_data[0] performs element-wise division. If any item ID in additional_trip_data[0] is 0 (which is common, especially for the checkout item), this will raise a ZeroDivisionError.

Based on the logic in the if trip_batch_size == -1: block (line 799), it seems the intention is to use the number of items to calculate the weight. The line should likely be weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0])).

Suggested change
weights.extend([1 / additional_trip_data[0]] * len(additional_trip_data[0]))
weights.extend([1 / len(additional_trip_data[0])] * len(additional_trip_data[0]))

@github-actions
Copy link
Contributor

github-actions bot commented Dec 15, 2025

Coverage

Coverage Report for Python 3.9
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py62198%283
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py1482285%86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py2372789%111–112, 123, 141, 185, 255, 377, 485, 585–587, 676, 762, 772, 822–830, 890–893, 941–942
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1923084%74–77, 295–297, 407, 540–576, 636, 658–661, 700–705, 791–802, 852
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   bakery.py38392%47, 51, 61
   synthetic_dataset.py81693%62, 194–199, 247
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3252792%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py124298%186, 374
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2362360%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py784147%70, 116–119, 136–157, 167, 180–188, 200–224, 234
TOTAL563184185% 

Tests Skipped Failures Errors Time
222 0 💤 0 ❌ 0 🔥 6m 41s ⏱️

@github-actions
Copy link
Contributor

github-actions bot commented Dec 15, 2025

Coverage

Coverage Report for Python 3.10
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py62198%283
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py1482285%86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py2372789%111–112, 123, 141, 185, 255, 377, 485, 585–587, 676, 762, 772, 822–830, 890–893, 941–942
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1923084%74–77, 295–297, 407, 540–576, 636, 658–661, 700–705, 791–802, 852
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   bakery.py38392%47, 51, 61
   synthetic_dataset.py81693%62, 194–199, 247
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3252792%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py1241885%186, 341, 360, 364–380
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py784147%70, 116–119, 136–157, 167, 180–188, 200–224, 234
TOTAL563385985% 

Tests Skipped Failures Errors Time
222 0 💤 1 ❌ 0 🔥 6m 58s ⏱️

@github-actions
Copy link
Contributor

github-actions bot commented Dec 15, 2025

Coverage

Coverage Report for Python 3.11
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py62198%283
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py1482285%86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py2372789%111–112, 123, 141, 185, 255, 377, 485, 585–587, 676, 762, 772, 822–830, 890–893, 941–942
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1923084%74–77, 295–297, 407, 540–576, 636, 658–661, 700–705, 791–802, 852
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   bakery.py38392%47, 51, 61
   synthetic_dataset.py81693%62, 194–199, 247
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3252891%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 908
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py124298%186, 374
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py784147%70, 116–119, 136–157, 167, 180–188, 200–224, 234
TOTAL563384485% 

Tests Skipped Failures Errors Time
222 0 💤 0 ❌ 0 🔥 6m 5s ⏱️

@github-actions
Copy link
Contributor

github-actions bot commented Dec 15, 2025

Coverage

Coverage Report for Python 3.12
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py62198%283
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py1482285%86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py2372789%111–112, 123, 141, 185, 255, 377, 485, 585–587, 676, 762, 772, 822–830, 890–893, 941–942
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1923084%74–77, 295–297, 407, 540–576, 636, 658–661, 700–705, 791–802, 852
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   bakery.py38392%47, 53, 61
   synthetic_dataset.py81693%62, 194–199, 247
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3252792%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py124298%186, 374
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py784147%70, 116–119, 136–157, 167, 180–188, 200–224, 234
TOTAL563384385% 

Tests Skipped Failures Errors Time
222 0 💤 0 ❌ 0 🔥 7m 35s ⏱️

@VincentAuriau VincentAuriau linked an issue Dec 18, 2025 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Compute basket-wise metrics using sample_weight

2 participants