-
Notifications
You must be signed in to change notification settings - Fork 41
Add per variable loss to Stepper and Log using TrainAggs #981
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Arcomano1234
wants to merge
23
commits into
main
Choose a base branch
from
feature/per-channel-loss-train-agg
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
62077b9
claude first attempt to add per variable loss to TrainAgg
Arcomano1234 1d1aded
make per channel loss configurable and clean up code
Arcomano1234 7079529
test era5 config
Arcomano1234 ef1b9f6
set per channel loss in TrainAgg to False
Arcomano1234 87af429
update tests
Arcomano1234 6223e0e
incorporate comments
Arcomano1234 35f2f1d
fix doc
Arcomano1234 84211b1
remove coupling of train agg and generic trainer
Arcomano1234 7def3ca
Merge branch 'main' into feature/per-channel-loss-train-agg
Arcomano1234 a67c993
revert changes to ERA5 configs
Arcomano1234 377fdc0
Merge branch 'feature/per-channel-loss-train-agg' of github.com:ai2cm…
Arcomano1234 61811a0
claude attempt for breaking things up to PRs
Arcomano1234 b310017
claude attempt for breaking things up to PRs
Arcomano1234 022dc70
claude losses return 1d vector over channel dim
Arcomano1234 c608a9b
claude refactor to no reduce losses
Arcomano1234 3ad2b01
Merge branch 'main' into feature/losses-return-1d-loss-vector
Arcomano1234 1f75eba
make loss reduction an argument
Arcomano1234 373d6ec
move reduction arg to StepLoss forward
Arcomano1234 f64dc68
address naming comment
Arcomano1234 d236f93
Merge branch 'main' into feature/per-channel-loss-train-agg
Arcomano1234 a78a4a8
Merge branch 'feature/losses-return-1d-loss-vector' into feature/per-…
Arcomano1234 72c0e19
Add newly created regression files
Arcomano1234 f82f599
clean up loss and tests to remove unused code after the loss-return-1…
Arcomano1234 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file modified
BIN
+1.59 KB
(100%)
fme/ace/stepper/testdata/stepper_train_on_batch_regression-False-crps.pt
Binary file not shown.
Binary file modified
BIN
-13.7 KB
(82%)
fme/ace/stepper/testdata/stepper_train_on_batch_regression-False.pt
Binary file not shown.
Binary file modified
BIN
+1.59 KB
(100%)
fme/ace/stepper/testdata/stepper_train_on_batch_regression-True-crps.pt
Binary file not shown.
Binary file modified
BIN
-13.7 KB
(82%)
fme/ace/stepper/testdata/stepper_train_on_batch_regression-True.pt
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the the other loss terms so I think it's okay.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I was not a fan of this at all either but Claude and I couldn't think of a good way. I guess the one thing I can do is make this an aggregator it self and decouple anything from the stepper. This would also help reduce the need to record it when we aren't using it during training.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem becomes then getting the loss function to the aggregator which has it's own complications. I defer to you or Jeremy on whether its worth decoupling this from the stepper and just pass a loss_fn to an "PerChannelLossAggregator".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, like I said it's a pre-existing issue so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd suggest using a new attribute on the TrainOutput instead of a string label.