Add final_log_prob return value to AdamOptimization.optimize method#229
Conversation
|
""" WalkthroughThe changes add a Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant AdamOptimization
participant ObjectiveFunction
User->>AdamOptimization: call optimize(rng_key, objective, initial_position, data)
AdamOptimization->>ObjectiveFunction: compute gradients (params, data)
AdamOptimization->>AdamOptimization: project params within bounds
AdamOptimization-->>User: return (rng_key, optimized_positions, final_log_prob)
User->>AdamOptimization: call __call__()
AdamOptimization->>User: unpack (rng_key, optimized_positions, _)
Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
✨ Finishing Touches
🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Actionable comments posted: 1
🔭 Outside diff range comments (1)
src/flowMC/strategy/optimization.py (1)
79-88: 🛠️ Refactor suggestionUpdate docstring to reflect new return value.
The docstring for the
optimizemethod should be updated to describe the new return value (final_log_prob).def optimize( self, rng_key: PRNGKeyArray, objective: Callable, initial_position: Float[Array, " n_chain n_dim"], data: dict, ): """Optimization kernel. This can be used independently of the __call__ method. Args: rng_key: PRNGKeyArray Random key for the optimization. objective: Callable Objective function to optimize. initial_position: Float[Array, " n_chain n_dim"] Initial positions for the optimization. + data: dict + Data dictionary to be passed to objective function. + + Returns: + rng_key: PRNGKeyArray + Updated random key. + optimized_positions: Float[Array, " n_chain n_dim"] + Optimized positions after optimization. + final_log_prob: Float[Array, " n_chains"] + Final log probabilities of the optimized positions. """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/flowMC/strategy/optimization.py(5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/flowMC/strategy/optimization.py (2)
test/unit/test_strategies.py (1)
loss_fn(64-65)test/unit/test_bundle.py (1)
logpdf(6-7)
🪛 GitHub Actions: pre-commit
src/flowMC/strategy/optimization.py
[error] 92-94: Prettier formatting check failed. File was reformatted by black. Run 'black src/flowMC/strategy/optimization.py' to fix code style issues.
🔇 Additional comments (6)
src/flowMC/strategy/optimization.py (6)
29-29: Type annotation formatting updated.The type annotation for bounds has been properly updated to include a leading space inside the dimension specifier, enhancing consistency with other annotations in the codebase.
40-40: Type annotation formatting updated.The type annotation for the bounds parameter in the constructor has been properly updated to match the class member annotation.
61-61: Type annotation formatting updated.The return type annotation has been consistently updated with the space inside dimension specifier.
66-66: Updated return value handling to match signature.The code now correctly unpacks the third return value (final log probabilities) from
optimizeusing an underscore to indicate it's being ignored in this context.
131-131: Added final log probability to optimize method return value.The method now returns a tuple that includes the final log probabilities of the optimized positions, which is the main objective of this PR. This additional return value will be useful for external callers that need to evaluate the quality of the optimization.
92-94:Details
❓ Verification inconclusive
Fix formatting according to pipeline failure.
There's a pre-commit check failure related to formatting. Please run
black src/flowMC/strategy/optimization.pyto fix the code style issues.
🏁 Script executed:
#!/bin/bash # Show the exact formatting issues black --diff src/flowMC/strategy/optimization.pyLength of output: 146
Formatting error in CI – please format with Black
It looks like the pipeline failed on code style forsrc/flowMC/strategy/optimization.py. Ensure Black is installed locally and run:black src/flowMC/strategy/optimization.pyor, if you’re using pre-commit:
pre-commit run --all-filesThen commit the updated file and re-run the CI checks.
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 92-94: Prettier formatting check failed. File was reformatted by black. Run 'black src/flowMC/strategy/optimization.py' to fix code style issues.
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
test/unit/test_strategies.py (1)
67-69: Consider adding test assertions for the new return valueWhile correctly handling the return values, this test doesn't verify the validity of the newly added
final_log_probreturn value. Consider adding assertions to validate that the log probabilities are calculated correctly.- rng_key, optimized_position, _ = self.strategy.optimize( + rng_key, optimized_position, final_log_prob = self.strategy.optimize( key, loss_fn, initial_position, {"data": jnp.arange(self.n_dim)} ) + + # Assert final_log_prob shape matches expectation + assert final_log_prob.shape == (self.n_chains,) + # Assert log probabilities are finite (not NaN or Inf) + assert jnp.all(jnp.isfinite(final_log_prob))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
test/unit/test_strategies.py(1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
test/unit/test_strategies.py (1)
src/flowMC/strategy/optimization.py (1)
optimize(72-133)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: build (3.12)
- GitHub Check: pre-commit (3.12)
- GitHub Check: build (3.11)
🔇 Additional comments (1)
test/unit/test_strategies.py (1)
67-67:Details
✅ Verification successful
Properly handles the new return signature
The change correctly unpacks the three values now returned by
optimize:rng_key,optimized_position, and the newfinal_log_prob(ignored with_). This matches the update to theoptimizemethod insrc/flowMC/strategy/optimization.py.
🏁 Script executed:
#!/bin/bash # Check if the optimize method is now consistently used with three return values throughout the codebase grep -n "optimize(" --include="*.py" -r . | grep -v "def optimize"Length of output: 411
Optimize return signature is handled consistently
All occurrences of
optimize(now unpack three return values (rng_key,optimized_position(s), and the ignoredfinal_log_prob), both in
src/flowMC/strategy/optimization.py:66test/unit/test_strategies.py:67No further changes needed.
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
src/flowMC/strategy/optimization.py (2)
60-64: Inconsistent dimension labels in jaxtyping annotationsYou use
n_chainelsewhere but returnFloat[Array, " n_chains n_dim"]here (extra “s”).
Although jaxtyping ignores the string content at runtime, keeping labels consistent greatly improves readability and static checking.Consider renaming to
n_chain(singular) everywhere or adoptingn_chainsconsistently.
23-24: Defaultboundsshape may surprise usersThe default
jnp.array([[-jnp.inf, jnp.inf]])has shape(1, 2)which broadcasts, but it silently applies the same bound to every dimension.
If a user passesn_dim > 1without realising this broadcast they may get unintended behaviour.Suggestions:
- Document this broadcasting explicitly in the docstring.
- Validate the provided
boundsagainstinitial_position.shape[-1]and raise if the shapes are incompatible.Also applies to: 31-32, 42-42
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/flowMC/strategy/optimization.py(5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/flowMC/strategy/optimization.py (3)
test/unit/test_bundle.py (1)
logpdf(6-7)test/unit/test_strategies.py (1)
loss_fn(64-65)src/flowMC/resource/nf_model/base.py (1)
loss_fn(99-100)
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (3)
src/flowMC/strategy/optimization.py (3)
74-75: Fixed loss_fn signatureThe loss function now correctly accepts the data parameter, addressing a previous issue noted in the past review comments.
120-120: Fixed parameter naming to prevent shadowingChanging the scan body parameter from
datato_stepremoves the potential for shadowing the outerdatavariable, addressing a bug noted in previous reviews.
124-126: Correct gradient function usage with data parameterThe gradient function now correctly passes the data parameter, fixing the issue identified in previous reviews.
🧹 Nitpick comments (2)
src/flowMC/strategy/optimization.py (2)
77-79: Consider exposing final log probabilitiesThe
__call__method discards the final log probabilities returned byoptimize. Consider whether this information should be exposed to callers of__call__for consistency with the updatedoptimizemethod.- rng_key, optimized_positions, _ = self.optimize( + rng_key, optimized_positions, final_log_prob = self.optimize( rng_key, loss_fn, initial_position, data ) - return rng_key, resources, optimized_positions + return rng_key, resources, optimized_positions, final_log_probThis would require updating the return type annotation as well.
148-148: Consider using a logging system instead of printUsing a print statement for indicating the optimization method could be disruptive in some contexts (e.g., notebooks, UIs). Consider using a proper logging system that can be configured by the user.
- print("Using Adam optimization") + import logging + logging.info("Using Adam optimization")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/flowMC/strategy/optimization.py(6 hunks)
🧰 Additional context used
🪛 GitHub Actions: pre-commit
src/flowMC/strategy/optimization.py
[error] 52-54: Black formatting check failed. The file was reformatted by the black hook. Run 'black src/flowMC/strategy/optimization.py' to apply formatting locally.
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: build (3.12)
- GitHub Check: build (3.11)
🔇 Additional comments (10)
src/flowMC/strategy/optimization.py (10)
23-27: Well-documented bounds parameterThe docstring clearly explains the purpose and expected format of the bounds parameter, including how different shapes are handled for broadcasting or per-dimension constraints. This is helpful for users who need to enforce box constraints during optimization.
34-34: Good default for backward compatibilitySetting the default bounds to
[-jnp.inf, jnp.inf]maintains backward compatibility with code that doesn't specify bounds, ensuring unbounded optimization remains the default behavior.
45-45: Proper parameter addition in constructor signatureThe bounds parameter is correctly added to the constructor with the same default value as the class attribute.
53-58: Early validation prevents subtle errorsGood practice to validate the bounds shape early in the constructor. This prevents cryptic errors that might occur later during optimization.
90-96: Robust validation at runtimeThe additional validation in
optimizeis important as it ensures that the bounds are compatible with the actual dimension of the initial positions. This serves as a safeguard against potential shape mismatches.
107-117: Clear and comprehensive docstring updateThe docstring for
optimizehas been appropriately updated to document the newdataparameter and thefinal_log_probreturn value. This ensures developers understand the updated function signature.
129-131: Good implementation of bounds projectionUsing
optax.projections.projection_boxis an appropriate way to enforce the box constraints during optimization. This ensures parameters stay within bounds after each update step.
155-161: Good validation of optimization resultsComputing and checking the final log probabilities for NaN or infinity values is a helpful diagnostic feature that can alert users to potential issues with their optimization.
162-162: Updated return signature to include final log probabilitiesThe method now correctly returns the final log probabilities, which provides valuable information to users about the quality of the optimization results.
1-163: Formatting issue flagged by CIThe Black formatter has detected formatting issues. Please run the suggested command to fix them:
black src/flowMC/strategy/optimization.py🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 52-54: Black formatting check failed. The file was reformatted by the black hook. Run 'black src/flowMC/strategy/optimization.py' to apply formatting locally.
| Float[Array, " n_chain n_dim"], | ||
| ]: | ||
| def loss_fn(params: Float[Array, " n_dim"]) -> Float: | ||
| def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float: |
There was a problem hiding this comment.
@kazewong This may affect performance, but removing it will cause an issue when calling below.
|
@kazewong Ready for review. The tests fails with something related to the GH token, probably just because I am running it on my own fork. |
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
Summary by CodeRabbit