-
Notifications
You must be signed in to change notification settings - Fork 12
Replace average_on_batch by average_on_trip #299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Summary of ChangesHello @michaelteboul, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refines the metric calculation within the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the metric classes by renaming average_on_batch to average_on_trip and implementing a more efficient and correct averaging logic using tf.math.unsorted_segment_mean. This is a significant improvement over the previous implementation, which appeared to be buggy and inefficient.
My review includes a few minor suggestions to improve code clarity and consistency. Most importantly, I've noticed that there are no unit tests for the new average_on_trip functionality. I strongly recommend adding tests to verify the correctness of the new implementation for all three metric classes (NegativeLogLikeliHood, MRR, and HitRate).
| epsilon : float, optional | ||
| Lower bound for log(.), by default 1e-10 | ||
| average_on_batch: bool, optional | ||
| average_on_trip: bool, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if batch is not None and self.average_on_batch: | ||
| self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank)) | ||
| self.n_evals.assign(self.n_evals + 1) | ||
| # mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| self.n_evals.assign_add(tf.cast(tf.shape(unique_trips)[0], self.n_evals.dtype)) | ||
| else: | ||
| self.mrr.assign(self.mrr + tf.reduce_sum(mean_rank)) | ||
| self.mrr.assign(self.mrr + tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency with the if branch and for better readability, consider using assign_add here. It would also be good to apply the same change to the update of self.n_evals on the next line for consistency.
| self.mrr.assign(self.mrr + tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32))) | |
| self.mrr.assign_add(tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32))) |
Coverage Report for Python 3.10
|
|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Coverage Report for Python 3.11
|
|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Coverage Report for Python 3.12
|
|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Description of the goal of the PR
Description:
Changes this PR introduces (fill it before implementation)
Checklist before requesting a review