From b08e14bc5dff153323fe7895c9741b0394d2c99e Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 09:39:36 -0400 Subject: [PATCH 01/16] Clean up formatting and remove unused import Add placeholder unit test for likelihood module --- src/jimgw/core/single_event/detector.py | 7 ++----- src/jimgw/core/single_event/utils.py | 1 - test/unit/test_likelhood.py | 1 + 3 files changed, 3 insertions(+), 6 deletions(-) create mode 100644 test/unit/test_likelhood.py diff --git a/src/jimgw/core/single_event/detector.py b/src/jimgw/core/single_event/detector.py index 92d11958c..3eb9d93cb 100644 --- a/src/jimgw/core/single_event/detector.py +++ b/src/jimgw/core/single_event/detector.py @@ -631,16 +631,13 @@ def inject_signal( self.set_frequency_bounds() masked_signal = projected_strain[self.frequency_mask] - df = self.sliced_frequencies[1] - self.sliced_frequencies[0] + df = self.sliced_frequencies[1] - self.sliced_frequencies[0] _optimal_snr_sq = inner_product( masked_signal, masked_signal, self.sliced_psd, df ) optimal_snr = _optimal_snr_sq**0.5 match_filtered_snr = complex_inner_product( - masked_signal, - self.sliced_fd_data, - self.sliced_psd, - df + masked_signal, self.sliced_fd_data, self.sliced_psd, df ) match_filtered_snr /= optimal_snr diff --git a/src/jimgw/core/single_event/utils.py b/src/jimgw/core/single_event/utils.py index 2b7be90de..6975afbc7 100644 --- a/src/jimgw/core/single_event/utils.py +++ b/src/jimgw/core/single_event/utils.py @@ -1,6 +1,5 @@ import jax.numpy as jnp from jaxtyping import Array, Float, Complex -from typing import Optional from jimgw.core.constants import MTSUN from jimgw.core.utils import safe_arctan2, carte_to_spherical_angles diff --git a/test/unit/test_likelhood.py b/test/unit/test_likelhood.py new file mode 100644 index 000000000..5871ed8ee --- /dev/null +++ b/test/unit/test_likelhood.py @@ -0,0 +1 @@ +import pytest From 0793099dfd9688a361faeb4d1e2b5aa585cd8b48 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 12:35:30 -0400 Subject: [PATCH 02/16] Refactor code for readability and consistency - Reformat long function calls and asset specs for clarity - Use consistent string quoting and indentation - Move assert error messages to multi-line format for readability - Add properties to SingleEventLikelihood for duration and detector names - Refactor likelihood classes for modularity and clarity --- jim_dagster/InjectionRecovery/assets.py | 176 ++++++++++++++++++---- jim_dagster/RealDataCatalog/assets.py | 142 ++++++++++++++--- src/jimgw/core/prior.py | 12 +- src/jimgw/core/single_event/data.py | 24 +-- src/jimgw/core/single_event/detector.py | 6 +- src/jimgw/core/single_event/likelihood.py | 158 ++++++++++++++++++- src/jimgw/run/cli/execute_single_run.py | 6 +- src/jimgw/run/run_manager.py | 6 +- 8 files changed, 451 insertions(+), 79 deletions(-) diff --git a/jim_dagster/InjectionRecovery/assets.py b/jim_dagster/InjectionRecovery/assets.py index becfe7107..60e8d825c 100644 --- a/jim_dagster/InjectionRecovery/assets.py +++ b/jim_dagster/InjectionRecovery/assets.py @@ -19,6 +19,7 @@ # Sample a fiducial population + @dg.asset( group_name="prerun", key_prefix="InjectionRecovery", @@ -32,10 +33,12 @@ def sample_population(): path_prefix="./data/", ) + # TODO: Add diagnostics regarding the sampled population. # Create asset group for run and configuration + @dg.asset( group_name="prerun", description="Configuration file for the run.", @@ -76,10 +79,17 @@ def config_file(): run.local_data_prefix = f"./data/runs/{idx}/strains/" run.serialize(f"./data/runs/{idx}/config.yaml") + @dg.multi_asset( specs=[ - dg.AssetSpec(key=["InjectionRecovery", "strain"], deps=[["InjectionRecovery", "config_file"]]), - dg.AssetSpec(key=["InjectionRecovery", "psd"], deps=[["InjectionRecovery", "config_file"]]), + dg.AssetSpec( + key=["InjectionRecovery", "strain"], + deps=[["InjectionRecovery", "config_file"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "psd"], + deps=[["InjectionRecovery", "config_file"]], + ), ], group_name="prerun", ) @@ -122,19 +132,53 @@ def raw_data(): detector.data.to_file(f"./data/runs/{idx}/strains/{ifo}_data") detector.psd.to_file(f"./data/runs/{idx}/strains/{ifo}_psd") + @dg.multi_asset( specs=[ - dg.AssetSpec(key=["InjectionRecovery", "training_chains"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_log_prob"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_local_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_global_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_loss"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_chains"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_log_prob"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_local_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_global_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "auxiliary_nf_samples"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "auxiliary_prior_samples"], deps=[["InjectionRecovery", "raw_data"]]), + dg.AssetSpec( + key=["InjectionRecovery", "training_chains"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_log_prob"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_local_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_global_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_loss"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_chains"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_log_prob"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_local_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_global_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "auxiliary_nf_samples"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "auxiliary_prior_samples"], + deps=[["InjectionRecovery", "raw_data"]], + ), ], group_name="run", ) @@ -145,64 +189,140 @@ def run(): """ pass + # Create asset group for diagnostics -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_loss"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_loss"]], + key_prefix="InjectionRecovery", +) def loss_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_chains"]], + key_prefix="InjectionRecovery", +) def training_chains_corner_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_chains"]], + key_prefix="InjectionRecovery", +) def training_chains_trace_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_chains"]], + key_prefix="InjectionRecovery", +) def training_chains_rhat_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_log_prob"]], + key_prefix="InjectionRecovery", +) def training_log_prob_distribution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_log_prob"]], + key_prefix="InjectionRecovery", +) def training_log_prob_evolution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_local_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_local_acceptance"]], + key_prefix="InjectionRecovery", +) def training_local_acceptance_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_global_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_global_acceptance"]], + key_prefix="InjectionRecovery", +) def training_global_acceptance_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_chains"]], + key_prefix="InjectionRecovery", +) def production_chains_corner_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_chains"]], + key_prefix="InjectionRecovery", +) def production_chains_trace_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_chains"]], + key_prefix="InjectionRecovery", +) def production_chains_rhat_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_log_prob"]], + key_prefix="InjectionRecovery", +) def production_log_prob_distribution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_log_prob"]], + key_prefix="InjectionRecovery", +) def production_log_prob_evolution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_local_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_local_acceptance"]], + key_prefix="InjectionRecovery", +) def production_local_acceptance_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_global_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_global_acceptance"]], + key_prefix="InjectionRecovery", +) def production_global_acceptance_plot(): pass diff --git a/jim_dagster/RealDataCatalog/assets.py b/jim_dagster/RealDataCatalog/assets.py index 7e2bc40a4..d35c1029c 100644 --- a/jim_dagster/RealDataCatalog/assets.py +++ b/jim_dagster/RealDataCatalog/assets.py @@ -12,20 +12,21 @@ event_partitions_def = DynamicPartitionsDefinition(name="event_name") + @dg.asset( key_prefix="RealDataCatalog", group_name="prerun", description="Fetch all confident events and their gps time", ) def event_list(context: AssetExecutionContext): - catalogs = ['GWTC-1-confident', 'GWTC-2.1-confident', 'GWTC-3-confident'] + catalogs = ["GWTC-1-confident", "GWTC-2.1-confident", "GWTC-3-confident"] result = [] event_names = [] for catalog in catalogs: - event_list = gwosc.api.fetch_catalog_json(catalog)['events'] + event_list = gwosc.api.fetch_catalog_json(catalog)["events"] for event in event_list.values(): - name = event['commonName'] - gps_time = event['GPS'] + name = event["commonName"] + gps_time = event["GPS"] result.append((name, gps_time)) event_names.append(name) os.makedirs("data", exist_ok=True) @@ -39,7 +40,7 @@ def event_list(context: AssetExecutionContext): # We should be able to partition this asset and run it in parallel for each event. @dg.multi_asset( specs=[ - dg.AssetSpec(["RealDataCatalog","strain"], deps=[event_list]), + dg.AssetSpec(["RealDataCatalog", "strain"], deps=[event_list]), dg.AssetSpec(["RealDataCatalog", "psd"], deps=[event_list]), ], group_name="prerun", @@ -61,7 +62,9 @@ def raw_data(context: AssetExecutionContext): data = Data.from_gwosc(ifo, start, end) data.to_file(os.path.join(event_dir, f"{ifo}_data")) # TODO: Perhaps we should make sure the PSD estimation window are the same accross all IFOs? - psd_data = Data.from_gwosc(ifo, start - 4098, end -2) # This needs to be changed at some point + psd_data = Data.from_gwosc( + ifo, start - 4098, end - 2 + ) # This needs to be changed at some point if np.isnan(psd_data.td).any(): psd_data = Data.from_gwosc(ifo, start + 2, end + 4098) if np.isnan(psd_data.td).any(): @@ -91,6 +94,7 @@ def raw_data_plot(context: AssetExecutionContext): Plot the raw strain data for each IFO for the event. """ import matplotlib.pyplot as plt + event_name = context.partition_key event_dir = os.path.join("data", event_name, "raw") plots_dir = os.path.join("data", event_name, "plots") @@ -101,7 +105,7 @@ def raw_data_plot(context: AssetExecutionContext): data_file = os.path.join(event_dir, f"{ifo}_data.npz") if os.path.exists(data_file): data = np.load(data_file) - t = data["epoch"] + np.arange(data["td"].shape[0]) * data['dt'] + t = data["epoch"] + np.arange(data["td"].shape[0]) * data["dt"] td = data["td"] if t is not None and td is not None: plt.figure() @@ -115,6 +119,7 @@ def raw_data_plot(context: AssetExecutionContext): plot_paths.append(plot_path) return plot_paths + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "psd"]], @@ -126,6 +131,7 @@ def psd_plot(context: AssetExecutionContext): Plot the PSD for each IFO for the event. """ import matplotlib.pyplot as plt + event_name = context.partition_key event_dir = os.path.join("data", event_name, "raw") plots_dir = os.path.join("data", event_name, "plots") @@ -172,7 +178,9 @@ def config_file(context: AssetExecutionContext): if os.path.exists(data_file) and os.path.exists(psd_file): available_ifos.append(ifo) if available_ifos == []: - raise RuntimeError(f"No IFOs with both data and PSD found for event {event_name}") + raise RuntimeError( + f"No IFOs with both data and PSD found for event {event_name}" + ) run = IMRPhenomPv2StandardCBCRunDefinition( n_chains=500, n_local_steps=100, @@ -219,15 +227,34 @@ def config_file(context: AssetExecutionContext): run.local_data_prefix = os.path.join(run_dir, "raw/") run.serialize(os.path.join(run_dir, "config.yaml")) + @dg.multi_asset( specs=[ - dg.AssetSpec(key=["RealDataCatalog", "training_loss"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_chains"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_log_prob"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_local_acceptance"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_global_acceptance"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "auxiliary_nf_samples"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "auxiliary_prior_samples"], deps=[raw_data, config_file]), + dg.AssetSpec( + key=["RealDataCatalog", "training_loss"], deps=[raw_data, config_file] + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_chains"], deps=[raw_data, config_file] + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_log_prob"], deps=[raw_data, config_file] + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_local_acceptance"], + deps=[raw_data, config_file], + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_global_acceptance"], + deps=[raw_data, config_file], + ), + dg.AssetSpec( + key=["RealDataCatalog", "auxiliary_nf_samples"], + deps=[raw_data, config_file], + ), + dg.AssetSpec( + key=["RealDataCatalog", "auxiliary_prior_samples"], + deps=[raw_data, config_file], + ), ], group_name="run", partitions_def=event_partitions_def, @@ -251,6 +278,7 @@ def loss_plot(context: AssetExecutionContext): Generate and save a loss plot from the training_loss asset. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -272,6 +300,7 @@ def loss_plot(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_chains"]], @@ -284,6 +313,7 @@ def production_chains_corner_plot(context: AssetExecutionContext): """ import matplotlib.pyplot as plt import corner + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -294,7 +324,22 @@ def production_chains_corner_plot(context: AssetExecutionContext): results = np.load(results_path, allow_pickle=True) chains = results["chains"].item() # keys = np.sort(list(chains.keys())) - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] samples = np.array([chains[key] for key in keys]).T fig = corner.corner(samples[::10], labels=keys) plot_path = os.path.join(plots_dir, "production_chains_corner.png") @@ -302,6 +347,7 @@ def production_chains_corner_plot(context: AssetExecutionContext): plt.close(fig) return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "auxiliary_nf_samples"]], @@ -314,6 +360,7 @@ def nf_samples_corner_plot(context: AssetExecutionContext): """ import matplotlib.pyplot as plt import corner + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -324,7 +371,22 @@ def nf_samples_corner_plot(context: AssetExecutionContext): results = np.load(results_path, allow_pickle=True) nf_samples = results["nf_samples"].item() # keys = np.sort(list(nf_samples.keys())) - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] nf_samples = np.array([nf_samples[key] for key in keys]).T fig = corner.corner(nf_samples, labels=keys) # Thinning for better visualization plot_path = os.path.join(plots_dir, "nf_samples_corner.png") @@ -332,6 +394,7 @@ def nf_samples_corner_plot(context: AssetExecutionContext): plt.close(fig) return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "auxiliary_prior_samples"]], @@ -344,6 +407,7 @@ def prior_samples_corner_plot(context: AssetExecutionContext): """ import matplotlib.pyplot as plt import corner + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -354,7 +418,22 @@ def prior_samples_corner_plot(context: AssetExecutionContext): results = np.load(results_path, allow_pickle=True) prior_samples = results["prior_samples"].item() # keys = np.sort(list(prior_samples.keys())) - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] prior_samples = np.array([prior_samples[key] for key in keys]).T fig = corner.corner(prior_samples, labels=keys) # Thinning for better visualization plot_path = os.path.join(plots_dir, "prior_samples_corner.png") @@ -362,6 +441,7 @@ def prior_samples_corner_plot(context: AssetExecutionContext): plt.close(fig) return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_chains"]], @@ -373,6 +453,7 @@ def production_chains_trace_plot(context: AssetExecutionContext): Generate and save a trace plot for the production chains. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -382,7 +463,22 @@ def production_chains_trace_plot(context: AssetExecutionContext): raise FileNotFoundError(f"Results file not found: {results_path}") results = np.load(results_path, allow_pickle=True) chains = results["chains"].item() - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] n_params = len(keys) samples = [chains[key] for key in keys] @@ -400,6 +496,7 @@ def production_chains_trace_plot(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_log_prob"]], @@ -411,6 +508,7 @@ def production_log_prob_distribution(context: AssetExecutionContext): Generate and save a histogram of the production log probability. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -432,6 +530,7 @@ def production_log_prob_distribution(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_log_prob"]], @@ -443,6 +542,7 @@ def production_log_prob_evolution(context: AssetExecutionContext): Generate and save a plot of the evolution of the production log probability. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -464,6 +564,7 @@ def production_log_prob_evolution(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_local_acceptance"]], @@ -475,6 +576,7 @@ def production_local_acceptance_plot(context: AssetExecutionContext): Generate and save a plot of the local acceptance rate. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -496,6 +598,7 @@ def production_local_acceptance_plot(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_global_acceptance"]], @@ -507,6 +610,7 @@ def production_global_acceptance_plot(context: AssetExecutionContext): Generate and save a plot of the global acceptance rate. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") diff --git a/src/jimgw/core/prior.py b/src/jimgw/core/prior.py index 3f44eedbe..6feb1bcb9 100644 --- a/src/jimgw/core/prior.py +++ b/src/jimgw/core/prior.py @@ -145,9 +145,9 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) - assert ( - self.n_dims == 1 - ), "StandardNormalDistribution needs to be 1D distributions" + assert self.n_dims == 1, ( + "StandardNormalDistribution needs to be 1D distributions" + ) def sample( self, rng_key: PRNGKeyArray, n_samples: int @@ -193,9 +193,9 @@ def __init__( base_prior: list[Prior], transforms: list[BijectiveTransform], ): - assert ( - len(base_prior) == 1 - ), "SequentialTransformPrior only takes one base prior" + assert len(base_prior) == 1, ( + "SequentialTransformPrior only takes one base prior" + ) self.base_prior = base_prior self.transforms = transforms self.parameter_names = base_prior[0].parameter_names diff --git a/src/jimgw/core/single_event/data.py b/src/jimgw/core/single_event/data.py index 998fa35f3..a93f402d8 100644 --- a/src/jimgw/core/single_event/data.py +++ b/src/jimgw/core/single_event/data.py @@ -295,9 +295,9 @@ def from_fd( Returns: Data: Data object with the Fourier and time domain data. """ - assert len(fd) == len( - frequencies - ), "Frequency and data arrays must have the same length" + assert len(fd) == len(frequencies), ( + "Frequency and data arrays must have the same length" + ) # form full frequency array delta_f = frequencies[1] - frequencies[0] fnyq = frequencies[-1] @@ -315,9 +315,9 @@ def from_fd( delta_t = 1 / (2 * fnyq) data_td_full = jnp.fft.irfft(data_fd_full) / delta_t # check frequencies - assert jnp.allclose( - f, jnp.fft.rfftfreq(len(data_td_full), delta_t) - ), "Generated frequencies do not match the input frequencies" + assert jnp.allclose(f, jnp.fft.rfftfreq(len(data_td_full), delta_t)), ( + "Generated frequencies do not match the input frequencies" + ) # create a Data object data = cls(data_td_full, delta_t, epoch=epoch, name=name) data.fd = data_fd_full @@ -326,9 +326,9 @@ def from_fd( # represents the input FD data. d_new, f_new = data.frequency_slice(frequencies[0], frequencies[-1]) assert all(jnp.equal(d_new, fd)), "Data do not match after slicing" - assert all( - jnp.equal(f_new, frequencies) - ), "Frequencies do not match after slicing" + assert all(jnp.equal(f_new, frequencies)), ( + "Frequencies do not match after slicing" + ) return data @classmethod @@ -448,9 +448,9 @@ def __init__( # NOTE: Are we sure the values and frequencies start from 0? self.values = values self.frequencies = frequencies - assert self.n_freq == len( - self.frequencies - ), "Values and frequencies must have the same length" + assert self.n_freq == len(self.frequencies), ( + "Values and frequencies must have the same length" + ) self.name = name or "" def __repr__(self) -> str: diff --git a/src/jimgw/core/single_event/detector.py b/src/jimgw/core/single_event/detector.py index 3eb9d93cb..de3a78412 100644 --- a/src/jimgw/core/single_event/detector.py +++ b/src/jimgw/core/single_event/detector.py @@ -142,9 +142,9 @@ def set_frequency_bounds( data, freqs_1 = self.data.frequency_slice(*self.frequency_bounds) psd, freqs_2 = self.psd.frequency_slice(*self.frequency_bounds) - assert all( - freqs_1 == freqs_2 - ), f"The {self.name} data and PSD must have same frequencies" + assert all(freqs_1 == freqs_2), ( + f"The {self.name} data and PSD must have same frequencies" + ) self._sliced_frequencies = freqs_1 self._sliced_fd_data = data diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index d97466175..548c86422 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -23,6 +23,15 @@ class SingleEventLikelihood(LikelihoodBase): detectors: Sequence[Detector] waveform: Waveform + @property + def duration(self) -> Float: + return self.detectors[0].data.duration + + @property + def detector_names(self): + """The interferometers for the likelihood.""" + return [detector.name for detector in self.detectors] + def __init__(self, detectors: Sequence[Detector], waveform: Waveform) -> None: self.detectors = detectors self.waveform = waveform @@ -36,6 +45,144 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: return 0.0 +class BaseTransientLikelihoodFD(SingleEventLikelihood): + def __init__( + self, + detectors: Sequence[Detector], + waveform: Waveform, + f_min: Float = 0, + f_max: Float = float("inf"), + trigger_time: Float = 0, + ) -> None: + super().__init__(detectors, waveform) + # Set the frequency bounds for the detectors + _frequencies = [] + for detector in detectors: + detector.set_frequency_bounds(f_min, f_max) + _frequencies.append(detector.sliced_frequencies) + _frequencies = jnp.array(_frequencies) + assert jnp.all(jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:]), ( + "The frequency arrays are not all the same." + ) + + self.frequencies = _frequencies[0] + self.trigger_time = trigger_time + self.gmst = compute_gmst(self.trigger_time) + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + waveform_sky = self.waveform(self.frequencies, params) + log_likelihood = 0.0 + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] + ) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + match_filter_SNR = inner_product(h_dec, ifo_data, psd, df) + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += match_filter_SNR - optimal_SNR / 2 + return log_likelihood + + +class TimeMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): + tc_range: tuple[Float, Float] = ( + -0.12, + 0.12, + ) # Default range for time marginalization + tc_array: Float[ + Array, " duration*f_sample/2" + ] # Array of time values for marginalization + pad_low: Float[Array, " n_pad_low"] # Padding for low frequencies + pad_high: Float[Array, " n_pad_high"] # Padding for high frequencies + + def __init__( + self, + detectors: Sequence[Detector], + waveform: Waveform, + f_min: Float = 0, + f_max: Float = float("inf"), + trigger_time: Float = 0, + ) -> None: + super().__init__(detectors, waveform, f_min, f_max, trigger_time) + fs = self.detectors[0].data.sampling_frequency + duration = self.detectors[0].data.duration + # Refactored: use instance attributes instead of self.kwargs + self.tc_array = jnp.fft.fftfreq(int(duration * fs / 2), 1.0 / duration) + self.pad_low = jnp.zeros(int(self.frequencies[0] * duration)) + if jnp.isclose(self.frequencies[-1], fs / 2.0 - 1.0 / duration): + self.pad_high = jnp.array([]) + else: + self.pad_high = jnp.zeros( + int((fs / 2.0 - 1.0 / duration - self.frequencies[-1]) * duration) + ) + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + log_likelihood = 0.0 + complex_h_inner_d = jnp.zeros_like(self.detectors[0].sliced_frequencies) + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] + ) + waveform_sky = self.waveform(self.frequencies, params) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + # using instead of + complex_h_inner_d += 4 * h_dec * jnp.conj(ifo_data) / psd * df + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += -optimal_SNR / 2 + + # padding the complex_h_inner_d + # this array is the hd*/S for f in [0, fs / 2 - df] + complex_h_inner_d_positive_f = jnp.concatenate( + (self.pad_low, complex_h_inner_d, self.pad_high) + ) + + # make use of the fft + # which then return the exp(-i2pift_c) + # w.r.t. the tc_array + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") + + # set the values to -inf when it is outside the tc range + # so that they will disappear after the logsumexp + fft_h_inner_d = jnp.where( + (self.tc_array > self.tc_range[0]) & (self.tc_array < self.tc_range[1]), + fft_h_inner_d.real, + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, + ) + + # using the logsumexp to marginalize over the tc prior range + log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(self.tc_array)) + return log_likelihood + + +class PhaseMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): + pass + + +# class HeterodynedTransientLikelihoodFD(SingleEventLikelihood): +# pass + + +class HeterodynedTimeMarginalizedLikelihoodFD(SingleEventLikelihood): + pass + + +class HeterodynedPhaseMarginalizedLikelihoodFD(SingleEventLikelihood): + pass + + class TransientLikelihoodFD(SingleEventLikelihood): def __init__( self, @@ -56,9 +203,9 @@ def __init__( detector.set_frequency_bounds(f_min, f_max) _frequencies.append(detector.sliced_frequencies) _frequencies = jnp.array(_frequencies) - assert jnp.all( - jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:] - ), "The frequency arrays are not all the same." + assert jnp.all(jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:]), ( + "The frequency arrays are not all the same." + ) self.detectors = detectors self.frequencies = _frequencies[0] @@ -103,7 +250,6 @@ def __init__( * duration ) ) - print() else: self.param_func = lambda x: x self.likelihood_function = original_likelihood @@ -121,7 +267,9 @@ def __init__( ), "Cannot have t_c fixed while having the marginalization of t_c turned on" assert not ( "phase_c" in fixing_parameters and "phase" in self.marginalization - ), "Cannot have phase_c fixed while having the marginalization of phase_c turned on" + ), ( + "Cannot have phase_c fixed while having the marginalization of phase_c turned on" + ) # if the same key exists in both dictionary, # the later one will overwrite the former one self.fixing_func = lambda x: {**x, **fixing_parameters} diff --git a/src/jimgw/run/cli/execute_single_run.py b/src/jimgw/run/cli/execute_single_run.py index e9d8c3f79..ebe06a9fa 100644 --- a/src/jimgw/run/cli/execute_single_run.py +++ b/src/jimgw/run/cli/execute_single_run.py @@ -17,9 +17,9 @@ ) args = parser.parse_args() - assert args.run_definition.endswith( - ".yaml" - ), "Run definition file must be a YAML file." + assert args.run_definition.endswith(".yaml"), ( + "Run definition file must be a YAML file." + ) definitions_name = yaml.safe_load(open(args.run_definition))["definition_name"] diff --git a/src/jimgw/run/run_manager.py b/src/jimgw/run/run_manager.py index a569c7bdb..03fd063ce 100644 --- a/src/jimgw/run/run_manager.py +++ b/src/jimgw/run/run_manager.py @@ -24,9 +24,9 @@ def __init__(self, run: RunDefinition | str): else: logging.ERROR("Run object or path not given.") - assert isinstance( - run, RunDefinition - ), "Run object or path not given. Please provide a Run object or a path to a serialized Run object." + assert isinstance(run, RunDefinition), ( + "Run object or path not given. Please provide a Run object or a path to a serialized Run object." + ) # Initialize the jim objects needed for the run run.initialize_jim_objects() From 95d07adec29e0663f86073721d711f3281c1047c Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 12:45:55 -0400 Subject: [PATCH 03/16] Remove unused likelihood functions and refactor time marginalization --- src/jimgw/core/single_event/likelihood.py | 101 ++-------------------- 1 file changed, 6 insertions(+), 95 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 548c86422..15a86d504 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -64,7 +64,6 @@ def __init__( assert jnp.all(jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:]), ( "The frequency arrays are not all the same." ) - self.frequencies = _frequencies[0] self.trigger_time = trigger_time self.gmst = compute_gmst(self.trigger_time) @@ -92,15 +91,10 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: class TimeMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): - tc_range: tuple[Float, Float] = ( - -0.12, - 0.12, - ) # Default range for time marginalization - tc_array: Float[ - Array, " duration*f_sample/2" - ] # Array of time values for marginalization - pad_low: Float[Array, " n_pad_low"] # Padding for low frequencies - pad_high: Float[Array, " n_pad_high"] # Padding for high frequencies + tc_range: tuple[Float, Float] + tc_array: Float[Array, " duration*f_sample/2"] + pad_low: Float[Array, " n_pad_low"] + pad_high: Float[Array, " n_pad_high"] def __init__( self, @@ -109,11 +103,12 @@ def __init__( f_min: Float = 0, f_max: Float = float("inf"), trigger_time: Float = 0, + tc_range: tuple[Float, Float] = (-0.12, 0.12), ) -> None: super().__init__(detectors, waveform, f_min, f_max, trigger_time) + self.tc_range = tc_range fs = self.detectors[0].data.sampling_frequency duration = self.detectors[0].data.duration - # Refactored: use instance attributes instead of self.kwargs self.tc_array = jnp.fft.fftfreq(int(duration * fs / 2), 1.0 / duration) self.pad_low = jnp.zeros(int(self.frequencies[0] * duration)) if jnp.isclose(self.frequencies[-1], fs / 2.0 - 1.0 / duration): @@ -706,24 +701,6 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: } -def original_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - match_filter_SNR = inner_product(h_dec, data, psd, df) - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += match_filter_SNR - optimal_SNR / 2 - - return log_likelihood - - def phase_marginalized_likelihood( params: dict[str, Float], h_sky: dict[str, Complex[Array, " n_dim"]], @@ -744,72 +721,6 @@ def phase_marginalized_likelihood( return log_likelihood -def _get_tc_array(duration: Float, sampling_rate: Float): - return jnp.fft.fftfreq(int(duration * sampling_rate / 2), 1 / duration) - - -def _get_frequencies_pads(detector: Detector, fs: Float) -> tuple[Float, Float]: - f_low, f_high = detector.frequency_bounds - duration = detector.data.duration - delta_f = 1 / duration - - pad_low = jnp.zeros(int(f_low * duration)) - - f_Nyquist_diff = fs / 2.0 - delta_f - f_high - if jnp.isclose(f_Nyquist_diff, 0): - pad_high = jnp.array([]) - else: - pad_high = jnp.zeros(int(f_Nyquist_diff * duration)) - return pad_low, pad_high - - -def time_marginalized_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - complex_h_inner_d = jnp.zeros_like(detectors[0].sliced_frequencies) - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - # using instead of - complex_h_inner_d += 4 * h_dec * jnp.conj(data) / psd * df - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += -optimal_SNR / 2 - - # fetch the tc range tc_array, lower padding and higher padding - tc_range = [-0.12, 0.12] # TODO: This is hard coded right now, need to update. - tc_array = kwargs["tc_array"] - pad_low = kwargs["pad_low"] - pad_high = kwargs["pad_high"] - - # padding the complex_h_inner_d - # this array is the hd*/S for f in [0, fs / 2 - df] - complex_h_inner_d_positive_f = jnp.concatenate( - (pad_low, complex_h_inner_d, pad_high) - ) - - # make use of the fft - # which then return the exp(-i2pift_c) - # w.r.t. the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - - # set the values to -inf when it is outside the tc range - # so that they will disappear after the logsumexp - fft_h_inner_d = jnp.where( - (tc_array > tc_range[0]) & (tc_array < tc_range[1]), - fft_h_inner_d.real, - jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, - ) - - # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array)) - return log_likelihood - - def phase_time_marginalized_likelihood( params: dict[str, Float], h_sky: dict[str, Complex[Array, " n_dim"]], From e5153dea2036ef5f68604ed1646df30b29280d35 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 12:48:03 -0400 Subject: [PATCH 04/16] Add docstrings to TimeMarginalizedLikelihoodFD class and methods --- src/jimgw/core/single_event/likelihood.py | 67 ++++++++++++++++++++--- 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 15a86d504..6241ec97a 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -91,6 +91,32 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: class TimeMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): + """Frequency-domain likelihood class with analytic marginalization over coalescence time. + + This class implements a likelihood function for gravitational wave transient events, + marginalized over the coalescence time parameter (`t_c`). The marginalization is performed + using a fast Fourier transform (FFT) over the frequency domain inner product between the + model and the data. The likelihood is computed for a set of detectors and a waveform model. + + Attributes: + tc_range (tuple[Float, Float]): The range of coalescence times to marginalize over. + tc_array (Float[Array, "duration*f_sample/2"]): Array of time shifts corresponding to FFT bins. + pad_low (Float[Array, "n_pad_low"]): Zero-padding array for frequencies below the minimum frequency. + pad_high (Float[Array, "n_pad_high"]): Zero-padding array for frequencies above the maximum frequency. + + Args: + detectors (Sequence[Detector]): List of detector objects containing data and metadata. + waveform (Waveform): Waveform model to evaluate. + f_min (Float, optional): Minimum frequency for likelihood evaluation. Defaults to 0. + f_max (Float, optional): Maximum frequency for likelihood evaluation. Defaults to infinity. + trigger_time (Float, optional): GPS time of the event trigger. Defaults to 0. + tc_range (tuple[Float, Float], optional): Range of coalescence times to marginalize over. Defaults to (-0.12, 0.12). + + Example: + >>> likelihood = TimeMarginalizedLikelihoodFD(detectors, waveform, f_min=20, f_max=1024, trigger_time=1234567890) + >>> logL = likelihood.evaluate(params, data) + """ + tc_range: tuple[Float, Float] tc_array: Float[Array, " duration*f_sample/2"] pad_low: Float[Array, " n_pad_low"] @@ -105,6 +131,19 @@ def __init__( trigger_time: Float = 0, tc_range: tuple[Float, Float] = (-0.12, 0.12), ) -> None: + """Initializes the TimeMarginalizedLikelihoodFD class. + + Sets up the frequency bounds, coalescence time range, FFT time array, and zero-padding + arrays for the likelihood calculation. + + Args: + detectors (Sequence[Detector]): List of detector objects. + waveform (Waveform): Waveform model. + f_min (Float, optional): Minimum frequency. Defaults to 0. + f_max (Float, optional): Maximum frequency. Defaults to infinity. + trigger_time (Float, optional): Event trigger time. Defaults to 0. + tc_range (tuple[Float, Float], optional): Marginalization range for coalescence time. Defaults to (-0.12, 0.12). + """ super().__init__(detectors, waveform, f_min, f_max, trigger_time) self.tc_range = tc_range fs = self.detectors[0].data.sampling_frequency @@ -119,6 +158,22 @@ def __init__( ) def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the time-marginalized likelihood for a given set of parameters. + + Computes the log-likelihood marginalized over coalescence time by: + - Calculating the frequency-domain inner product between the model and data for each detector. + - Padding the inner product array to cover the full frequency range. + - Applying FFT to obtain the likelihood as a function of coalescence time. + - Restricting the FFT output to the specified `tc_range`. + - Marginalizing using logsumexp over the allowed coalescence times. + + Args: + params (dict[str, Float]): Dictionary of model parameters. + data (dict): Dictionary containing data (not used in this implementation). + + Returns: + Float: The marginalized log-likelihood value. + """ log_likelihood = 0.0 complex_h_inner_d = jnp.zeros_like(self.detectors[0].sliced_frequencies) df = ( @@ -138,26 +193,22 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: optimal_SNR = inner_product(h_dec, h_dec, psd, df) log_likelihood += -optimal_SNR / 2 - # padding the complex_h_inner_d - # this array is the hd*/S for f in [0, fs / 2 - df] + # Padding the complex_h_inner_d to cover the full frequency range complex_h_inner_d_positive_f = jnp.concatenate( (self.pad_low, complex_h_inner_d, self.pad_high) ) - # make use of the fft - # which then return the exp(-i2pift_c) - # w.r.t. the tc_array + # FFT to obtain exp(-i2πf t_c) as a function of t_c fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - # set the values to -inf when it is outside the tc range - # so that they will disappear after the logsumexp + # Restrict FFT output to the allowed tc_range, set others to -inf fft_h_inner_d = jnp.where( (self.tc_array > self.tc_range[0]) & (self.tc_array < self.tc_range[1]), fft_h_inner_d.real, jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, ) - # using the logsumexp to marginalize over the tc prior range + # Marginalize over t_c using logsumexp log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(self.tc_array)) return log_likelihood From 92b79646842baf11aa5ee1572452031c1345221d Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 12:49:02 -0400 Subject: [PATCH 05/16] Update likelihood.py --- src/jimgw/core/single_event/likelihood.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 6241ec97a..b02979bbb 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -174,6 +174,9 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: Returns: Float: The marginalized log-likelihood value. """ + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + params['t_c'] = 0.0 # Fixing t_c to 0 for time marginalization log_likelihood = 0.0 complex_h_inner_d = jnp.zeros_like(self.detectors[0].sliced_frequencies) df = ( From 918fe9ee45582243e90b377ec137b921a12f6561 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 12:49:43 -0400 Subject: [PATCH 06/16] Add docstrings to BaseTransientLikelihoodFD class and methods --- src/jimgw/core/single_event/likelihood.py | 45 +++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index b02979bbb..7bb3c7047 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -46,6 +46,28 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: class BaseTransientLikelihoodFD(SingleEventLikelihood): + """Base class for frequency-domain transient gravitational wave likelihood. + + This class provides the basic likelihood evaluation for gravitational wave transient events + in the frequency domain, using matched filtering across multiple detectors. + + Attributes: + frequencies (Float[Array]): The frequency array used for likelihood evaluation. + trigger_time (Float): The GPS time of the event trigger. + gmst (Float): Greenwich Mean Sidereal Time computed from the trigger time. + + Args: + detectors (Sequence[Detector]): List of detector objects containing data and metadata. + waveform (Waveform): Waveform model to evaluate. + f_min (Float, optional): Minimum frequency for likelihood evaluation. Defaults to 0. + f_max (Float, optional): Maximum frequency for likelihood evaluation. Defaults to infinity. + trigger_time (Float, optional): GPS time of the event trigger. Defaults to 0. + + Example: + >>> likelihood = BaseTransientLikelihoodFD(detectors, waveform, f_min=20, f_max=1024, trigger_time=1234567890) + >>> logL = likelihood.evaluate(params, data) + """ + def __init__( self, detectors: Sequence[Detector], @@ -54,6 +76,17 @@ def __init__( f_max: Float = float("inf"), trigger_time: Float = 0, ) -> None: + """Initializes the BaseTransientLikelihoodFD class. + + Sets up the frequency bounds for the detectors and computes the Greenwich Mean Sidereal Time. + + Args: + detectors (Sequence[Detector]): List of detector objects. + waveform (Waveform): Waveform model. + f_min (Float, optional): Minimum frequency. Defaults to 0. + f_max (Float, optional): Maximum frequency. Defaults to infinity. + trigger_time (Float, optional): Event trigger time. Defaults to 0. + """ super().__init__(detectors, waveform) # Set the frequency bounds for the detectors _frequencies = [] @@ -69,6 +102,18 @@ def __init__( self.gmst = compute_gmst(self.trigger_time) def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the log-likelihood for a given set of parameters. + + Computes the log-likelihood by matched filtering the model waveform against the data + for each detector, using the frequency-domain inner product. + + Args: + params (dict[str, Float]): Dictionary of model parameters. + data (dict): Dictionary containing data (not used in this implementation). + + Returns: + Float: The log-likelihood value. + """ params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst waveform_sky = self.waveform(self.frequencies, params) From ab782cacccf3d8a52247d89bcbcda23dc63b5455 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 13:14:13 -0400 Subject: [PATCH 07/16] Add unit tests for BaseTransientLikelihoodFD class --- test/unit/test_likelhood.py | 75 +++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/test/unit/test_likelhood.py b/test/unit/test_likelhood.py index 5871ed8ee..042514358 100644 --- a/test/unit/test_likelhood.py +++ b/test/unit/test_likelhood.py @@ -1 +1,76 @@ import pytest +import numpy as np +from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD +from jimgw.core.single_event.detector import get_H1, get_L1 +from jimgw.core.single_event.waveform import RippleIMRPhenomD +from jimgw.core.single_event.data import Data + +class TestBaseTransientLikelihoodFD: + """ + Organized tests for BaseTransientLikelihoodFD using real detector and waveform implementations. + """ + + @pytest.fixture + def GW150912_likelihood(self) -> BaseTransientLikelihoodFD: + """ + Fixture to set up a realistic BaseTransientLikelihoodFD instance using GWOSC data and power spectral density. + """ + gps = 1126259462.4 + start = gps - 2 + end = gps + 2 + psd_start = gps - 2048 + psd_end = gps + 2048 + fmin = 20.0 + fmax = 1024.0 + + # Initialize detectors and set data/PSD + ifos = [get_H1(), get_L1()] + for ifo in ifos: + data = Data.from_gwosc(ifo.name, start, end) + ifo.set_data(data) + psd_data = Data.from_gwosc(ifo.name, psd_start, psd_end) + psd_fftlength = data.duration * data.sampling_frequency + ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength)) + + waveform = RippleIMRPhenomD(f_ref=20.0) + likelihood = BaseTransientLikelihoodFD( + detectors=ifos, + waveform=waveform, + f_min=fmin, + f_max=fmax, + trigger_time=gps, + ) + return likelihood + + def test_likelihood_initialization(self, GW150912_likelihood: BaseTransientLikelihoodFD): + """ + Test initialization and attributes of BaseTransientLikelihoodFD with realistic setup. + """ + likelihood = GW150912_likelihood + assert isinstance(likelihood, BaseTransientLikelihoodFD) + assert np.allclose(likelihood.frequencies, [20.0, (20.0 + 1024.0) / 2, 1024.0]) + assert likelihood.trigger_time == 1126259462.4 + assert hasattr(likelihood, "gmst") + + def test_likelihood_evaluation(self, GW150912_likelihood: BaseTransientLikelihoodFD): + """ + Test the evaluation of the likelihood with realistic parameters. + """ + likelihood = GW150912_likelihood + # Example parameters for testing + params = { + "M_c": 30.0, + "eta": 0.249, + "s1_z": 0.0, + "s2_z": 0.0, + "d_L": 400.0, + "phase_c": 0.0, + "t_c": 0.0, + "iota": 0.0, + "ra": 1.375, + "dec": -1.2108, + "gmst": likelihood.gmst, + "psi": 0.0, + } + log_likelihood = likelihood.evaluate(params, {}) + assert np.isfinite(log_likelihood), "Log likelihood should be finite" \ No newline at end of file From e9b03c89d9b51187873dbf16a28447f2c278e52b Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 13:52:31 -0400 Subject: [PATCH 08/16] Initiali refactoring of the likelihood class --- src/jimgw/core/single_event/likelihood.py | 289 ++++++++-------------- test/unit/test_likelhood.py | 11 +- 2 files changed, 107 insertions(+), 193 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 7bb3c7047..56e942371 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -22,6 +22,7 @@ class SingleEventLikelihood(LikelihoodBase): detectors: Sequence[Detector] waveform: Waveform + fixed_parameters: dict[str, Float] = {} @property def duration(self) -> Float: @@ -32,9 +33,15 @@ def detector_names(self): """The interferometers for the likelihood.""" return [detector.name for detector in self.detectors] - def __init__(self, detectors: Sequence[Detector], waveform: Waveform) -> None: + def __init__( + self, + detectors: Sequence[Detector], + waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, + ) -> None: self.detectors = detectors self.waveform = waveform + self.fixed_parameters = fixed_parameters if fixed_parameters is not None else {} class ZeroLikelihood(LikelihoodBase): @@ -72,6 +79,7 @@ def __init__( self, detectors: Sequence[Detector], waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, f_min: Float = 0, f_max: Float = float("inf"), trigger_time: Float = 0, @@ -87,7 +95,7 @@ def __init__( f_max (Float, optional): Maximum frequency. Defaults to infinity. trigger_time (Float, optional): Event trigger time. Defaults to 0. """ - super().__init__(detectors, waveform) + super().__init__(detectors, waveform, fixed_parameters) # Set the frequency bounds for the detectors _frequencies = [] for detector in detectors: @@ -114,6 +122,7 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: Returns: Float: The log-likelihood value. """ + params.update(self.fixed_parameters) params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst waveform_sky = self.waveform(self.frequencies, params) @@ -171,6 +180,7 @@ def __init__( self, detectors: Sequence[Detector], waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, f_min: Float = 0, f_max: Float = float("inf"), trigger_time: Float = 0, @@ -189,7 +199,12 @@ def __init__( trigger_time (Float, optional): Event trigger time. Defaults to 0. tc_range (tuple[Float, Float], optional): Marginalization range for coalescence time. Defaults to (-0.12, 0.12). """ - super().__init__(detectors, waveform, f_min, f_max, trigger_time) + super().__init__( + detectors, waveform, fixed_parameters, f_min, f_max, trigger_time + ) + assert "t_c" not in self.fixed_parameters, ( + "Cannot have t_c fixed while marginalizing over t_c" + ) self.tc_range = tc_range fs = self.detectors[0].data.sampling_frequency duration = self.detectors[0].data.duration @@ -219,9 +234,10 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: Returns: Float: The marginalized log-likelihood value. """ + params.update(self.fixed_parameters) params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst - params['t_c'] = 0.0 # Fixing t_c to 0 for time marginalization + params["t_c"] = 0.0 # Fixing t_c to 0 for time marginalization log_likelihood = 0.0 complex_h_inner_d = jnp.zeros_like(self.detectors[0].sliced_frequencies) df = ( @@ -262,139 +278,97 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: class PhaseMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): - pass + """This has not been tested by a human yet.""" + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + log_likelihood = 0.0 + complex_d_inner_h = 0.0 + 0.0j + params.update(self.fixed_parameters) + params["phase_c"] = 0.0 # Fixing phase_c to 0 for phase marginalization + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + waveform_sky = self.waveform(self.frequencies, params) + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] + ) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + complex_d_inner_h += complex_inner_product(h_dec, ifo_data, psd, df) + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += -optimal_SNR / 2 -# class HeterodynedTransientLikelihoodFD(SingleEventLikelihood): -# pass + log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) + return log_likelihood -class HeterodynedTimeMarginalizedLikelihoodFD(SingleEventLikelihood): - pass +class PhaseTimeMarginalizedLikelihoodFD(TimeMarginalizedLikelihoodFD): + """This has not been tested by a human yet.""" + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + # Refactored: use self.detectors, self.frequencies, self.tc_array, self.pad_low, self.pad_high, self.tc_range + log_likelihood = 0.0 + complex_h_inner_d = 0.0 + 0.0j + params.update(self.fixed_parameters) -class HeterodynedPhaseMarginalizedLikelihoodFD(SingleEventLikelihood): - pass + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] + ) + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + params["t_c"] = 0.0 # Fix t_c for marginalization + params["phase_c"] = 0.0 + waveform_sky = self.waveform(self.frequencies, params) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + complex_h_inner_d += complex_inner_product(h_dec, ifo_data, psd, df) + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += -optimal_SNR / 2 + # Pad the complex_h_inner_d to cover the full frequency range + complex_h_inner_d_positive_f = jnp.concatenate( + (self.pad_low, complex_h_inner_d, self.pad_high) + ) -class TransientLikelihoodFD(SingleEventLikelihood): - def __init__( - self, - detectors: Sequence[Detector], - waveform: Waveform, - f_min: Float = 0, - f_max: Float = float("inf"), - trigger_time: Float = 0, - **kwargs, - ) -> None: - # NOTE: having 'kwargs' here makes it very difficult to diagnose - # errors and keep track of what's going on, would be better to list - # explicitly what the arguments are accepted + # FFT to obtain exp(-i2πf t_c) as a function of t_c + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - # Set the frequency bounds for the detectors - _frequencies = [] - for detector in detectors: - detector.set_frequency_bounds(f_min, f_max) - _frequencies.append(detector.sliced_frequencies) - _frequencies = jnp.array(_frequencies) - assert jnp.all(jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:]), ( - "The frequency arrays are not all the same." + # Restrict FFT output to the allowed tc_range, set others to -inf + log_i0_abs_fft = jnp.where( + (self.tc_array > self.tc_range[0]) & (self.tc_array < self.tc_range[1]), + log_i0(jnp.absolute(fft_h_inner_d)), + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, ) - self.detectors = detectors - self.frequencies = _frequencies[0] - self.duration = self.detectors[0].data.duration - self.waveform = waveform - self.trigger_time = trigger_time - self.gmst = compute_gmst(self.trigger_time) - self.kwargs = kwargs - if "marginalization" in self.kwargs: - marginalization = self.kwargs["marginalization"] - assert marginalization in [ - "phase", - "phase-time", - "time", - ], "Only support time, phase and phase+time marginalzation" - self.marginalization = marginalization - if self.marginalization == "phase-time": - self.param_func = lambda x: {**x, "phase_c": 0.0, "t_c": 0.0} - self.likelihood_function = phase_time_marginalized_likelihood - logging.info("Marginalizing over phase and time") - elif self.marginalization == "time": - self.param_func = lambda x: {**x, "t_c": 0.0} - self.likelihood_function = time_marginalized_likelihood - logging.info("Marginalizing over time") - elif self.marginalization == "phase": - self.param_func = lambda x: {**x, "phase_c": 0.0} - self.likelihood_function = phase_marginalized_likelihood - logging.info("Marginalizing over phase") - if "time" in self.marginalization: - fs = self.detectors[0].data.sampling_frequency - duration = self.detectors[0].data.duration - self.kwargs["tc_array"] = jnp.fft.fftfreq( - int(duration * fs / 2), 1.0 / duration - ) - self.kwargs["pad_low"] = jnp.zeros(int(self.frequencies[0] * duration)) - if jnp.isclose(self.frequencies[-1], fs / 2.0 - 1.0 / duration): - self.kwargs["pad_high"] = jnp.array([]) - else: - self.kwargs["pad_high"] = jnp.zeros( - int( - (fs / 2.0 - 1.0 / duration - self.frequencies[-1]) - * duration - ) - ) - else: - self.param_func = lambda x: x - self.likelihood_function = original_likelihood - self.marginalization = "" + # Marginalize over t_c using logsumexp + log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(self.tc_array)) + return log_likelihood - # the fixing_parameters is expected to be a dictionary - # with key as parameter name and value is the fixed value - # e.g. {'M_c': 1.1975, 't_c': 0} - if "fixing_parameters" in self.kwargs: - fixing_parameters = self.kwargs["fixing_parameters"] - print(f"Parameters are fixed {fixing_parameters}") - # check for conflict with the marginalization - assert not ( - "t_c" in fixing_parameters and "time" in self.marginalization - ), "Cannot have t_c fixed while having the marginalization of t_c turned on" - assert not ( - "phase_c" in fixing_parameters and "phase" in self.marginalization - ), ( - "Cannot have phase_c fixed while having the marginalization of phase_c turned on" - ) - # if the same key exists in both dictionary, - # the later one will overwrite the former one - self.fixing_func = lambda x: {**x, **fixing_parameters} - else: - self.fixing_func = lambda x: x - @property - def detector_names(self): - """The interferometers for the likelihood.""" - return [detector.name for detector in self.detectors] +# class HeterodynedTransientLikelihoodFD(SingleEventLikelihood): +# pass - def evaluate(self, params: dict[str, Float], data: dict) -> Float: - # TODO: Test whether we need to pass data in or with class changes is fine. - """Evaluate the likelihood for a given set of parameters.""" - params["trigger_time"] = self.trigger_time - params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - params = self.param_func(params) - # adjust the params due to fixing parameters - params = self.fixing_func(params) - # evaluate the waveform as usual - waveform_sky = self.waveform(self.frequencies, params) - return self.likelihood_function( - params, - waveform_sky, - self.detectors, # type: ignore - **self.kwargs, - ) + +class HeterodynedTimeMarginalizedLikelihoodFD(SingleEventLikelihood): + pass + + +class HeterodynedPhaseMarginalizedLikelihoodFD(SingleEventLikelihood): + pass -class HeterodynedTransientLikelihoodFD(TransientLikelihoodFD): +class HeterodynedTransientLikelihoodFD(BaseTransientLikelihoodFD): n_bins: int # Number of bins to use for the likelihood ref_params: dict # Reference parameters for the likelihood freq_grid_low: Array # Heterodyned frequency grid @@ -795,79 +769,14 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: likelihood_presets = { - "TransientLikelihoodFD": TransientLikelihoodFD, + "BaseTransientLikelihoodFD": BaseTransientLikelihoodFD, + "TimeMarginalizedLikelihoodFD": TimeMarginalizedLikelihoodFD, + "PhaseMarginalizedLikelihoodFD": PhaseMarginalizedLikelihoodFD, + "PhaseTimeMarginalizedLikelihoodFD": PhaseTimeMarginalizedLikelihoodFD, "HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD, } -def phase_marginalized_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - complex_d_inner_h = 0.0 + 0.0j - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - complex_d_inner_h += complex_inner_product(h_dec, data, psd, df) - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += -optimal_SNR / 2 - - log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) - return log_likelihood - - -def phase_time_marginalized_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - complex_h_inner_d = 0.0 + 0.0j - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - # using instead of - complex_h_inner_d += complex_inner_product(data, h_dec, psd, df) - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += -optimal_SNR / 2 - duration = detectors[0].data.duration - - # fetch the tc range tc_array, lower padding and higher padding - tc_range = kwargs["tc_range"] - fs = kwargs["sampling_rate"] - tc_array = _get_tc_array(duration, fs) - pad_low, pad_high = _get_frequencies_pads(detectors[0], fs=fs) - - # padding the complex_h_inner_d - # this array is the hd*/S for f in [0, fs / 2 - df] - complex_h_inner_d_positive_f = jnp.concatenate( - (pad_low, complex_h_inner_d, pad_high) - ) - - # make use of the fft - # which then return the exp(-i2pift_c) - # w.r.t. the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - - # set the values to -inf when it is outside the tc range - # so that they will disappear after the logsumexp - log_i0_abs_fft = jnp.where( - (tc_array > tc_range[0]) & (tc_array < tc_range[1]), - log_i0(jnp.absolute(fft_h_inner_d)), - jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, - ) - - # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array)) - return log_likelihood - - def original_relative_binning_likelihood( params, A0_array, diff --git a/test/unit/test_likelhood.py b/test/unit/test_likelhood.py index 042514358..749e49e57 100644 --- a/test/unit/test_likelhood.py +++ b/test/unit/test_likelhood.py @@ -5,6 +5,7 @@ from jimgw.core.single_event.waveform import RippleIMRPhenomD from jimgw.core.single_event.data import Data + class TestBaseTransientLikelihoodFD: """ Organized tests for BaseTransientLikelihoodFD using real detector and waveform implementations. @@ -42,7 +43,9 @@ def GW150912_likelihood(self) -> BaseTransientLikelihoodFD: ) return likelihood - def test_likelihood_initialization(self, GW150912_likelihood: BaseTransientLikelihoodFD): + def test_likelihood_initialization( + self, GW150912_likelihood: BaseTransientLikelihoodFD + ): """ Test initialization and attributes of BaseTransientLikelihoodFD with realistic setup. """ @@ -52,7 +55,9 @@ def test_likelihood_initialization(self, GW150912_likelihood: BaseTransientLikel assert likelihood.trigger_time == 1126259462.4 assert hasattr(likelihood, "gmst") - def test_likelihood_evaluation(self, GW150912_likelihood: BaseTransientLikelihoodFD): + def test_likelihood_evaluation( + self, GW150912_likelihood: BaseTransientLikelihoodFD + ): """ Test the evaluation of the likelihood with realistic parameters. """ @@ -73,4 +78,4 @@ def test_likelihood_evaluation(self, GW150912_likelihood: BaseTransientLikelihoo "psi": 0.0, } log_likelihood = likelihood.evaluate(params, {}) - assert np.isfinite(log_likelihood), "Log likelihood should be finite" \ No newline at end of file + assert np.isfinite(log_likelihood), "Log likelihood should be finite" From 5eb4ffd5825dc0b0233905628bfb883a0fb615cb Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 12 Jul 2025 13:55:43 -0400 Subject: [PATCH 09/16] Update workbench.py --- example/workbench.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/example/workbench.py b/example/workbench.py index fb225fc57..fe294ce88 100644 --- a/example/workbench.py +++ b/example/workbench.py @@ -15,7 +15,7 @@ UniformSpherePrior, ) from jimgw.core.single_event.detector import get_H1, get_L1, get_V1 -from jimgw.core.single_event.likelihood import TransientLikelihoodFD +from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD from jimgw.core.single_event.data import Data from jimgw.core.single_event.waveform import RippleIMRPhenomPv2 from jimgw.core.transforms import BoundToUnbound @@ -131,7 +131,9 @@ gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax ), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform( + tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0] + ), SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound( name_mapping=(["M_c"], ["M_c_unbounded"]), @@ -207,13 +209,12 @@ ] -likelihood = TransientLikelihoodFD( +likelihood = BaseTransientLikelihoodFD( ifos, waveform=waveform, trigger_time=gps, f_min=fmin, f_max=fmax, - # marginalization="time", ) jim = Jim( From 7f2bfe995cecb27f713babbddcee12b2a6654b4c Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 14 Jul 2025 08:33:31 -0400 Subject: [PATCH 10/16] Refactor likelihood evaluation methods to use _likelihood Add abstract _likelihood method to SingleEventLikelihood and refactor subclasses to implement core likelihood logic in _likelihood. Move parameter updates and fixed values to evaluate methods for clarity. --- src/jimgw/core/single_event/likelihood.py | 72 ++++++++++++++++------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 56e942371..1c7f36aed 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -17,6 +17,7 @@ ) import logging from typing import Sequence +from abc import abstractmethod class SingleEventLikelihood(LikelihoodBase): @@ -43,12 +44,25 @@ def __init__( self.waveform = waveform self.fixed_parameters = fixed_parameters if fixed_parameters is not None else {} + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the likelihood for a given set of parameters. + + This is a template method that calls the core likelihood evaluation method + """ + params.update(self.fixed_parameters) + return self._likelihood(params, data) + + @abstractmethod + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + """Core likelihood evaluation method to be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement this method.") + class ZeroLikelihood(LikelihoodBase): def __init__(self): pass - def evaluate(self, params: dict[str, Float], data: dict) -> Float: + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: return 0.0 @@ -125,6 +139,11 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: params.update(self.fixed_parameters) params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst + log_likelihood = self._likelihood(params, data) + return log_likelihood + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + """Core likelihood evaluation method for frequency-domain transient events.""" waveform_sky = self.waveform(self.frequencies, params) log_likelihood = 0.0 df = ( @@ -218,26 +237,28 @@ def __init__( ) def evaluate(self, params: dict[str, Float], data: dict) -> Float: - """Evaluate the time-marginalized likelihood for a given set of parameters. + params.update(self.fixed_parameters) + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + params["t_c"] = 0.0 # Fixing t_c to 0 for time marginalization + log_likelihood = self._likelihood(params, data) + return log_likelihood + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the time-marginalized likelihood for a given set of parameters. Computes the log-likelihood marginalized over coalescence time by: - Calculating the frequency-domain inner product between the model and data for each detector. - Padding the inner product array to cover the full frequency range. - Applying FFT to obtain the likelihood as a function of coalescence time. - Restricting the FFT output to the specified `tc_range`. - Marginalizing using logsumexp over the allowed coalescence times. - Args: params (dict[str, Float]): Dictionary of model parameters. data (dict): Dictionary containing data (not used in this implementation). - Returns: Float: The marginalized log-likelihood value. """ - params.update(self.fixed_parameters) - params["trigger_time"] = self.trigger_time - params["gmst"] = self.gmst - params["t_c"] = 0.0 # Fixing t_c to 0 for time marginalization + log_likelihood = 0.0 complex_h_inner_d = jnp.zeros_like(self.detectors[0].sliced_frequencies) df = ( @@ -256,22 +277,22 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: complex_h_inner_d += 4 * h_dec * jnp.conj(ifo_data) / psd * df optimal_SNR = inner_product(h_dec, h_dec, psd, df) log_likelihood += -optimal_SNR / 2 - + # Padding the complex_h_inner_d to cover the full frequency range complex_h_inner_d_positive_f = jnp.concatenate( (self.pad_low, complex_h_inner_d, self.pad_high) ) - + # FFT to obtain exp(-i2πf t_c) as a function of t_c fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - + # Restrict FFT output to the allowed tc_range, set others to -inf fft_h_inner_d = jnp.where( (self.tc_array > self.tc_range[0]) & (self.tc_array < self.tc_range[1]), fft_h_inner_d.real, jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, ) - + # Marginalize over t_c using logsumexp log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(self.tc_array)) return log_likelihood @@ -281,12 +302,17 @@ class PhaseMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): """This has not been tested by a human yet.""" def evaluate(self, params: dict[str, Float], data: dict) -> Float: - log_likelihood = 0.0 - complex_d_inner_h = 0.0 + 0.0j params.update(self.fixed_parameters) params["phase_c"] = 0.0 # Fixing phase_c to 0 for phase marginalization params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst + log_likelihood = self._likelihood(params, data) + return log_likelihood + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + log_likelihood = 0.0 + complex_d_inner_h = 0.0 + 0.0j + waveform_sky = self.waveform(self.frequencies, params) df = ( self.detectors[0].sliced_frequencies[1] @@ -309,21 +335,24 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: class PhaseTimeMarginalizedLikelihoodFD(TimeMarginalizedLikelihoodFD): """This has not been tested by a human yet.""" - + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + params.update(self.fixed_parameters) + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + params["t_c"] = 0.0 # Fix t_c for marginalization + params["phase_c"] = 0.0 + return self._likelihood(params, data) + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: # Refactored: use self.detectors, self.frequencies, self.tc_array, self.pad_low, self.pad_high, self.tc_range log_likelihood = 0.0 complex_h_inner_d = 0.0 + 0.0j - params.update(self.fixed_parameters) df = ( self.detectors[0].sliced_frequencies[1] - self.detectors[0].sliced_frequencies[0] ) - params["trigger_time"] = self.trigger_time - params["gmst"] = self.gmst - params["t_c"] = 0.0 # Fix t_c for marginalization - params["phase_c"] = 0.0 waveform_sky = self.waveform(self.frequencies, params) for ifo in self.detectors: freqs, ifo_data, psd = ( @@ -407,8 +436,7 @@ def __init__( prior: Optional[Prior] = None, sample_transforms: list[BijectiveTransform] = [], likelihood_transforms: list[NtoMTransform] = [], - **kwargs, - ) -> None: + super().__init__(detectors, waveform, f_min, f_max, trigger_time) logging.info("Initializing heterodyned likelihood..") From 85a3c15758ef5e60376d85832ca5699673d2b5b3 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 14 Jul 2025 15:17:48 -0400 Subject: [PATCH 11/16] update heterodyne likelihood --- src/jimgw/core/single_event/likelihood.py | 244 ++++++---------------- 1 file changed, 61 insertions(+), 183 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 1c7f36aed..e0f15c1ae 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -384,19 +384,6 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(self.tc_array)) return log_likelihood - -# class HeterodynedTransientLikelihoodFD(SingleEventLikelihood): -# pass - - -class HeterodynedTimeMarginalizedLikelihoodFD(SingleEventLikelihood): - pass - - -class HeterodynedPhaseMarginalizedLikelihoodFD(SingleEventLikelihood): - pass - - class HeterodynedTransientLikelihoodFD(BaseTransientLikelihoodFD): n_bins: int # Number of bins to use for the likelihood ref_params: dict # Reference parameters for the likelihood @@ -427,8 +414,8 @@ def __init__( waveform: Waveform, f_min: Float = 0, f_max: Float = float("inf"), - n_bins: int = 100, trigger_time: float = 0, + n_bins: int = 100, popsize: int = 100, n_steps: int = 2000, ref_params: dict = {}, @@ -436,6 +423,7 @@ def __init__( prior: Optional[Prior] = None, sample_transforms: list[BijectiveTransform] = [], likelihood_transforms: list[NtoMTransform] = [], + ): super().__init__(detectors, waveform, f_min, f_max, trigger_time) @@ -445,44 +433,6 @@ def __init__( if reference_waveform is None: reference_waveform = waveform - self.kwargs = kwargs - if "marginalization" in self.kwargs: - marginalization = self.kwargs["marginalization"] - assert marginalization in [ - "phase", - ], "Heterodyned likelihood only support phase marginalzation" - self.marginalization = marginalization - if self.marginalization == "phase": - self.param_func = lambda x: {**x, "phase_c": 0.0} - self.likelihood_function = phase_marginalized_likelihood - self.rb_likelihood_function = ( - phase_marginalized_relative_binning_likelihood - ) - logging.info("Marginalizing over phase") - else: - self.param_func = lambda x: x - self.likelihood_function = original_likelihood - self.rb_likelihood_function = original_relative_binning_likelihood - self.marginalization = "" - - # the fixing_parameters is expected to be a dictionary - # with key as parameter name and value is the fixed value - # e.g. {'M_c': 1.1975, 't_c': 0} - if "fixing_parameters" in self.kwargs: - fixing_parameters = self.kwargs["fixing_parameters"] - logging.info(f"Parameters are fixed {fixing_parameters}") - # check for conflict with the marginalization - assert not ( - "t_c" in fixing_parameters and "time" in self.marginalization - ), "Cannot have t_c fixed while marginalizing over t_c" - assert not ( - "phase_c" in fixing_parameters and "phase" in self.marginalization - ), "Cannot have phase_c fixed while marginalizing over phase_c" - # if the same key exists in both dictionary, - # the later one will overwrite the former one - self.fixing_func = lambda x: {**x, **fixing_parameters} - else: - self.fixing_func = lambda x: x # Get the original frequency grid frequency_original = self.frequencies @@ -522,10 +472,6 @@ def __init__( self.ref_params["trigger_time"] = self.trigger_time self.ref_params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - self.ref_params = self.param_func(self.ref_params) - # adjust the params due to fixing parameters - self.ref_params = self.fixing_func(self.ref_params) self.waveform_low_ref = {} self.waveform_center_ref = {} @@ -589,57 +535,39 @@ def __init__( self.B1_array[detector.name] = B1[mask_heterodyne_center] def evaluate(self, params: dict[str, Float], data: dict) -> Float: - frequencies_low = self.freq_grid_low - frequencies_center = self.freq_grid_center params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - params = self.param_func(params) - # adjust the params due to fixing parameters - params = self.fixing_func(params) + params.update(self.fixed_parameters) # evaluate the waveforms as usual + return self._likelihood(params, data) + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + frequencies_low = self.freq_grid_low + frequencies_center = self.freq_grid_center + log_likelihood = 0.0 waveform_sky_low = self.waveform(frequencies_low, params) waveform_sky_center = self.waveform(frequencies_center, params) - log_likelihood = self.rb_likelihood_function( - params, - self.A0_array, - self.A1_array, - self.B0_array, - self.B1_array, - waveform_sky_low, - waveform_sky_center, - self.waveform_low_ref, - self.waveform_center_ref, - self.detectors, - frequencies_low, - frequencies_center, - **self.kwargs, - ) + for detector in self.detectors: + waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) + waveform_center = detector.fd_response( + frequencies_low, waveform_sky_center, params + ) + + r0 = waveform_center / self.waveform_center_ref[detector.name] + r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( + frequencies_low - frequencies_center + ) + match_filter_SNR = jnp.sum( + self.A0_array[detector.name] * r0.conj() + self.A1_array[detector.name] * r1.conj() + ) + optimal_SNR = jnp.sum( + self.B0_array[detector.name] * jnp.abs(r0) ** 2 + + 2 * self.B1_array[detector.name] * (r0 * r1.conj()).real + ) + log_likelihood += (match_filter_SNR - optimal_SNR / 2).real + return log_likelihood - def evaluate_original( - self, params: dict[str, Float], data: dict - ) -> ( - Float - ): # TODO: Test whether we need to pass data in or with class changes is fine. - """ - Evaluate the likelihood for a given set of parameters. - """ - params["trigger_time"] = self.trigger_time - params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - params = self.param_func(params) - # adjust the params due to fixing parameters - params = self.fixing_func(params) - # evaluate the waveform as usual - waveform_sky = self.waveform(self.frequencies, params) - return self.likelihood_function( - params, - waveform_sky, - self.detectors, # type: ignore - **self.kwargs, - ) - @staticmethod def max_phase_diff( f: Float[Array, " n_freq"], @@ -748,7 +676,7 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: named_params = transform.backward(named_params) for transform in likelihood_transforms: named_params = transform.forward(named_params) - return -self.evaluate_original(named_params, data) + return -super(HeterodynedTransientLikelihoodFD, self).evaluate(named_params, data) print("Starting the optimizer") @@ -794,7 +722,37 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: for transform in likelihood_transforms: named_params = transform.forward(named_params) return named_params + +class HeterodynedPhaseMarginalizedLikelihoodFD(HeterodynedTransientLikelihoodFD): + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + frequencies_low = self.freq_grid_low + frequencies_center = self.freq_grid_center + waveform_sky_low = self.waveform(frequencies_low, params) + waveform_sky_center = self.waveform(frequencies_center, params) + log_likelihood = 0.0 + complex_d_inner_h = 0.0 + + for detector in self.detectors: + waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) + waveform_center = detector.fd_response( + frequencies_center, waveform_sky_center, params + ) + r0 = waveform_center / self.waveform_center_ref[detector.name] + r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( + frequencies_low - frequencies_center + ) + complex_d_inner_h += jnp.sum( + self.A0_array[detector.name] * r0.conj() + self.A1_array[detector.name] * r1.conj() + ) + optimal_SNR = jnp.sum( + self.B0_array[detector.name] * jnp.abs(r0) ** 2 + + 2 * self.B1_array[detector.name] * (r0 * r1.conj()).real + ) + log_likelihood += -optimal_SNR.real / 2 + log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) + + return log_likelihood likelihood_presets = { "BaseTransientLikelihoodFD": BaseTransientLikelihoodFD, @@ -802,85 +760,5 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: "PhaseMarginalizedLikelihoodFD": PhaseMarginalizedLikelihoodFD, "PhaseTimeMarginalizedLikelihoodFD": PhaseTimeMarginalizedLikelihoodFD, "HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD, -} - - -def original_relative_binning_likelihood( - params, - A0_array, - A1_array, - B0_array, - B1_array, - waveform_sky_low, - waveform_sky_center, - waveform_low_ref, - waveform_center_ref, - detectors, - frequencies_low, - frequencies_center, - **kwargs, -): - log_likelihood = 0.0 - - for detector in detectors: - waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) - waveform_center = detector.fd_response( - frequencies_low, waveform_sky_center, params - ) - - r0 = waveform_center / waveform_center_ref[detector.name] - r1 = (waveform_low / waveform_low_ref[detector.name] - r0) / ( - frequencies_low - frequencies_center - ) - match_filter_SNR = jnp.sum( - A0_array[detector.name] * r0.conj() + A1_array[detector.name] * r1.conj() - ) - optimal_SNR = jnp.sum( - B0_array[detector.name] * jnp.abs(r0) ** 2 - + 2 * B1_array[detector.name] * (r0 * r1.conj()).real - ) - log_likelihood += (match_filter_SNR - optimal_SNR / 2).real - - return log_likelihood - - -def phase_marginalized_relative_binning_likelihood( - params, - A0_array, - A1_array, - B0_array, - B1_array, - waveform_sky_low, - waveform_sky_center, - waveform_low_ref, - waveform_center_ref, - detectors, - frequencies_low, - frequencies_center, - **kwargs, -): - log_likelihood = 0.0 - complex_d_inner_h = 0.0 - - for detector in detectors: - waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) - waveform_center = detector.fd_response( - frequencies_center, waveform_sky_center, params - ) - - r0 = waveform_center / waveform_center_ref[detector.name] - r1 = (waveform_low / waveform_low_ref[detector.name] - r0) / ( - frequencies_low - frequencies_center - ) - complex_d_inner_h += jnp.sum( - A0_array[detector.name] * r0.conj() + A1_array[detector.name] * r1.conj() - ) - optimal_SNR = jnp.sum( - B0_array[detector.name] * jnp.abs(r0) ** 2 - + 2 * B1_array[detector.name] * (r0 * r1.conj()).real - ) - log_likelihood += -optimal_SNR.real / 2 - - log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) - - return log_likelihood + "PhaseMarginalizedHeterodynedLikelihoodFD": HeterodynedPhaseMarginalizedLikelihoodFD, +} \ No newline at end of file From e38b8e7ec0c71797121e5f3e20b3965b26abb569 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 14 Jul 2025 15:18:31 -0400 Subject: [PATCH 12/16] ruff formatting --- src/jimgw/core/prior.py | 12 ++--- src/jimgw/core/single_event/data.py | 24 ++++----- src/jimgw/core/single_event/detector.py | 6 +-- src/jimgw/core/single_event/likelihood.py | 64 +++++++++++++---------- src/jimgw/run/cli/execute_single_run.py | 6 +-- src/jimgw/run/run_manager.py | 6 +-- 6 files changed, 64 insertions(+), 54 deletions(-) diff --git a/src/jimgw/core/prior.py b/src/jimgw/core/prior.py index 6feb1bcb9..3f44eedbe 100644 --- a/src/jimgw/core/prior.py +++ b/src/jimgw/core/prior.py @@ -145,9 +145,9 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) - assert self.n_dims == 1, ( - "StandardNormalDistribution needs to be 1D distributions" - ) + assert ( + self.n_dims == 1 + ), "StandardNormalDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int @@ -193,9 +193,9 @@ def __init__( base_prior: list[Prior], transforms: list[BijectiveTransform], ): - assert len(base_prior) == 1, ( - "SequentialTransformPrior only takes one base prior" - ) + assert ( + len(base_prior) == 1 + ), "SequentialTransformPrior only takes one base prior" self.base_prior = base_prior self.transforms = transforms self.parameter_names = base_prior[0].parameter_names diff --git a/src/jimgw/core/single_event/data.py b/src/jimgw/core/single_event/data.py index a93f402d8..998fa35f3 100644 --- a/src/jimgw/core/single_event/data.py +++ b/src/jimgw/core/single_event/data.py @@ -295,9 +295,9 @@ def from_fd( Returns: Data: Data object with the Fourier and time domain data. """ - assert len(fd) == len(frequencies), ( - "Frequency and data arrays must have the same length" - ) + assert len(fd) == len( + frequencies + ), "Frequency and data arrays must have the same length" # form full frequency array delta_f = frequencies[1] - frequencies[0] fnyq = frequencies[-1] @@ -315,9 +315,9 @@ def from_fd( delta_t = 1 / (2 * fnyq) data_td_full = jnp.fft.irfft(data_fd_full) / delta_t # check frequencies - assert jnp.allclose(f, jnp.fft.rfftfreq(len(data_td_full), delta_t)), ( - "Generated frequencies do not match the input frequencies" - ) + assert jnp.allclose( + f, jnp.fft.rfftfreq(len(data_td_full), delta_t) + ), "Generated frequencies do not match the input frequencies" # create a Data object data = cls(data_td_full, delta_t, epoch=epoch, name=name) data.fd = data_fd_full @@ -326,9 +326,9 @@ def from_fd( # represents the input FD data. d_new, f_new = data.frequency_slice(frequencies[0], frequencies[-1]) assert all(jnp.equal(d_new, fd)), "Data do not match after slicing" - assert all(jnp.equal(f_new, frequencies)), ( - "Frequencies do not match after slicing" - ) + assert all( + jnp.equal(f_new, frequencies) + ), "Frequencies do not match after slicing" return data @classmethod @@ -448,9 +448,9 @@ def __init__( # NOTE: Are we sure the values and frequencies start from 0? self.values = values self.frequencies = frequencies - assert self.n_freq == len(self.frequencies), ( - "Values and frequencies must have the same length" - ) + assert self.n_freq == len( + self.frequencies + ), "Values and frequencies must have the same length" self.name = name or "" def __repr__(self) -> str: diff --git a/src/jimgw/core/single_event/detector.py b/src/jimgw/core/single_event/detector.py index de3a78412..3eb9d93cb 100644 --- a/src/jimgw/core/single_event/detector.py +++ b/src/jimgw/core/single_event/detector.py @@ -142,9 +142,9 @@ def set_frequency_bounds( data, freqs_1 = self.data.frequency_slice(*self.frequency_bounds) psd, freqs_2 = self.psd.frequency_slice(*self.frequency_bounds) - assert all(freqs_1 == freqs_2), ( - f"The {self.name} data and PSD must have same frequencies" - ) + assert all( + freqs_1 == freqs_2 + ), f"The {self.name} data and PSD must have same frequencies" self._sliced_frequencies = freqs_1 self._sliced_fd_data = data diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index e0f15c1ae..f9e0d2f05 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from flowMC.strategy.optimization import AdamOptimization from jax.scipy.special import logsumexp -from jaxtyping import Array, Float, Complex +from jaxtyping import Array, Float from typing import Optional from scipy.interpolate import interp1d from jimgw.core.utils import log_i0 @@ -46,12 +46,12 @@ def __init__( def evaluate(self, params: dict[str, Float], data: dict) -> Float: """Evaluate the likelihood for a given set of parameters. - + This is a template method that calls the core likelihood evaluation method """ params.update(self.fixed_parameters) return self._likelihood(params, data) - + @abstractmethod def _likelihood(self, params: dict[str, Float], data: dict) -> Float: """Core likelihood evaluation method to be implemented by subclasses.""" @@ -116,9 +116,9 @@ def __init__( detector.set_frequency_bounds(f_min, f_max) _frequencies.append(detector.sliced_frequencies) _frequencies = jnp.array(_frequencies) - assert jnp.all(jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:]), ( - "The frequency arrays are not all the same." - ) + assert jnp.all( + jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:] + ), "The frequency arrays are not all the same." self.frequencies = _frequencies[0] self.trigger_time = trigger_time self.gmst = compute_gmst(self.trigger_time) @@ -141,7 +141,7 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: params["gmst"] = self.gmst log_likelihood = self._likelihood(params, data) return log_likelihood - + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: """Core likelihood evaluation method for frequency-domain transient events.""" waveform_sky = self.waveform(self.frequencies, params) @@ -221,9 +221,9 @@ def __init__( super().__init__( detectors, waveform, fixed_parameters, f_min, f_max, trigger_time ) - assert "t_c" not in self.fixed_parameters, ( - "Cannot have t_c fixed while marginalizing over t_c" - ) + assert ( + "t_c" not in self.fixed_parameters + ), "Cannot have t_c fixed while marginalizing over t_c" self.tc_range = tc_range fs = self.detectors[0].data.sampling_frequency duration = self.detectors[0].data.duration @@ -277,22 +277,22 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: complex_h_inner_d += 4 * h_dec * jnp.conj(ifo_data) / psd * df optimal_SNR = inner_product(h_dec, h_dec, psd, df) log_likelihood += -optimal_SNR / 2 - + # Padding the complex_h_inner_d to cover the full frequency range complex_h_inner_d_positive_f = jnp.concatenate( (self.pad_low, complex_h_inner_d, self.pad_high) ) - + # FFT to obtain exp(-i2πf t_c) as a function of t_c fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - + # Restrict FFT output to the allowed tc_range, set others to -inf fft_h_inner_d = jnp.where( (self.tc_array > self.tc_range[0]) & (self.tc_array < self.tc_range[1]), fft_h_inner_d.real, jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, ) - + # Marginalize over t_c using logsumexp log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(self.tc_array)) return log_likelihood @@ -335,7 +335,7 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: class PhaseTimeMarginalizedLikelihoodFD(TimeMarginalizedLikelihoodFD): """This has not been tested by a human yet.""" - + def evaluate(self, params: dict[str, Float], data: dict) -> Float: params.update(self.fixed_parameters) params["trigger_time"] = self.trigger_time @@ -384,6 +384,7 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(self.tc_array)) return log_likelihood + class HeterodynedTransientLikelihoodFD(BaseTransientLikelihoodFD): n_bins: int # Number of bins to use for the likelihood ref_params: dict # Reference parameters for the likelihood @@ -433,7 +434,6 @@ def __init__( if reference_waveform is None: reference_waveform = waveform - # Get the original frequency grid frequency_original = self.frequencies # Get the grid of the relative binning scheme (contains the final endpoint) @@ -548,24 +548,27 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: waveform_sky_low = self.waveform(frequencies_low, params) waveform_sky_center = self.waveform(frequencies_center, params) for detector in self.detectors: - waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) + waveform_low = detector.fd_response( + frequencies_low, waveform_sky_low, params + ) waveform_center = detector.fd_response( frequencies_low, waveform_sky_center, params ) - + r0 = waveform_center / self.waveform_center_ref[detector.name] r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center ) match_filter_SNR = jnp.sum( - self.A0_array[detector.name] * r0.conj() + self.A1_array[detector.name] * r1.conj() + self.A0_array[detector.name] * r0.conj() + + self.A1_array[detector.name] * r1.conj() ) optimal_SNR = jnp.sum( self.B0_array[detector.name] * jnp.abs(r0) ** 2 + 2 * self.B1_array[detector.name] * (r0 * r1.conj()).real ) log_likelihood += (match_filter_SNR - optimal_SNR / 2).real - + return log_likelihood @staticmethod @@ -676,7 +679,9 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: named_params = transform.backward(named_params) for transform in likelihood_transforms: named_params = transform.forward(named_params) - return -super(HeterodynedTransientLikelihoodFD, self).evaluate(named_params, data) + return -super(HeterodynedTransientLikelihoodFD, self).evaluate( + named_params, data + ) print("Starting the optimizer") @@ -722,7 +727,8 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: for transform in likelihood_transforms: named_params = transform.forward(named_params) return named_params - + + class HeterodynedPhaseMarginalizedLikelihoodFD(HeterodynedTransientLikelihoodFD): def _likelihood(self, params: dict[str, Float], data: dict) -> Float: frequencies_low = self.freq_grid_low @@ -731,9 +737,11 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: waveform_sky_center = self.waveform(frequencies_center, params) log_likelihood = 0.0 complex_d_inner_h = 0.0 - + for detector in self.detectors: - waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) + waveform_low = detector.fd_response( + frequencies_low, waveform_sky_low, params + ) waveform_center = detector.fd_response( frequencies_center, waveform_sky_center, params ) @@ -742,7 +750,8 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: frequencies_low - frequencies_center ) complex_d_inner_h += jnp.sum( - self.A0_array[detector.name] * r0.conj() + self.A1_array[detector.name] * r1.conj() + self.A0_array[detector.name] * r0.conj() + + self.A1_array[detector.name] * r1.conj() ) optimal_SNR = jnp.sum( self.B0_array[detector.name] * jnp.abs(r0) ** 2 @@ -751,9 +760,10 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: log_likelihood += -optimal_SNR.real / 2 log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) - + return log_likelihood + likelihood_presets = { "BaseTransientLikelihoodFD": BaseTransientLikelihoodFD, "TimeMarginalizedLikelihoodFD": TimeMarginalizedLikelihoodFD, @@ -761,4 +771,4 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: "PhaseTimeMarginalizedLikelihoodFD": PhaseTimeMarginalizedLikelihoodFD, "HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD, "PhaseMarginalizedHeterodynedLikelihoodFD": HeterodynedPhaseMarginalizedLikelihoodFD, -} \ No newline at end of file +} diff --git a/src/jimgw/run/cli/execute_single_run.py b/src/jimgw/run/cli/execute_single_run.py index ebe06a9fa..e9d8c3f79 100644 --- a/src/jimgw/run/cli/execute_single_run.py +++ b/src/jimgw/run/cli/execute_single_run.py @@ -17,9 +17,9 @@ ) args = parser.parse_args() - assert args.run_definition.endswith(".yaml"), ( - "Run definition file must be a YAML file." - ) + assert args.run_definition.endswith( + ".yaml" + ), "Run definition file must be a YAML file." definitions_name = yaml.safe_load(open(args.run_definition))["definition_name"] diff --git a/src/jimgw/run/run_manager.py b/src/jimgw/run/run_manager.py index 03fd063ce..a569c7bdb 100644 --- a/src/jimgw/run/run_manager.py +++ b/src/jimgw/run/run_manager.py @@ -24,9 +24,9 @@ def __init__(self, run: RunDefinition | str): else: logging.ERROR("Run object or path not given.") - assert isinstance(run, RunDefinition), ( - "Run object or path not given. Please provide a Run object or a path to a serialized Run object." - ) + assert isinstance( + run, RunDefinition + ), "Run object or path not given. Please provide a Run object or a path to a serialized Run object." # Initialize the jim objects needed for the run run.initialize_jim_objects() From d69c0d264dbf70fbae56a9e396f953a7f0dde336 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 14 Jul 2025 15:34:08 -0400 Subject: [PATCH 13/16] update run interface --- src/jimgw/core/single_event/likelihood.py | 5 ++- .../run/library/IMRPhenomPv2_standard_cbc.py | 38 +++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index f9e0d2f05..d8672331f 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -61,8 +61,9 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: class ZeroLikelihood(LikelihoodBase): def __init__(self): pass - - def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the likelihood, which is always zero.""" return 0.0 diff --git a/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py b/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py index fdf2fef55..4f2234778 100644 --- a/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py +++ b/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py @@ -13,7 +13,7 @@ from jimgw.core.single_event.data import Data, PowerSpectrum from jimgw.core.single_event.detector import get_detector_preset -from jimgw.core.single_event.likelihood import TransientLikelihoodFD, ZeroLikelihood +from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD, ZeroLikelihood from jimgw.core.single_event.waveform import RippleIMRPhenomPv2 from jimgw.core.transforms import BoundToUnbound, BijectiveTransform, NtoMTransform from jimgw.core.single_event.transforms import ( @@ -22,6 +22,7 @@ MassRatioToSymmetricMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform ) from typing import Optional, Sequence, Self @@ -54,7 +55,7 @@ def __init__( max_s2: float, iota_range: tuple[float, float], dL_range: tuple[float, float], - # t_c_range: tuple[float, float], + t_c_range: tuple[float, float], phase_c_range: tuple[float, float], psi_range: tuple[float, float], ra_range: tuple[float, float], @@ -69,7 +70,7 @@ def __init__( self.max_s2 = max_s2 self.iota_range = iota_range self.dL_range = dL_range - # self.t_c_range = t_c_range + self.t_c_range = t_c_range self.phase_c_range = phase_c_range self.psi_range = psi_range self.ra_range = ra_range @@ -85,7 +86,7 @@ def initialize_jim_objects(self): def initialize_likelihood( self, local_data_prefix: Optional[str] = None - ) -> TransientLikelihoodFD: + ) -> BaseTransientLikelihoodFD: logging.info("Initializing likelihood...") gps = self.gps @@ -123,13 +124,12 @@ def initialize_likelihood( waveform = RippleIMRPhenomPv2(f_ref=self.f_ref) - likelihood = TransientLikelihoodFD( + likelihood = BaseTransientLikelihoodFD( detectors=self.ifos, waveform=waveform, trigger_time=gps, f_min=self.f_min, f_max=self.f_max, - marginalization="time", ) return likelihood @@ -152,9 +152,9 @@ def initialize_prior(self) -> CombinePrior: 2.0, parameter_names=["d_L"], ) - # t_c_prior = UniformPrior( - # self.t_c_range[0], self.t_c_range[1], parameter_names=["t_c"] - # ) + t_c_prior = UniformPrior( + self.t_c_range[0], self.t_c_range[1], parameter_names=["t_c"] + ) phase_c_prior = UniformPrior( self.phase_c_range[0], self.phase_c_range[1], parameter_names=["phase_c"] ) @@ -173,7 +173,7 @@ def initialize_prior(self) -> CombinePrior: s2_prior, iota_prior, dL_prior, - # t_c_prior, + t_c_prior, phase_c_prior, psi_prior, ra_prior, @@ -202,12 +202,12 @@ def initialize_sample_transforms(self) -> Sequence[BijectiveTransform]: GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( gps_time=self.gps, ifo=self.ifos[0] ), - # GeocentricArrivalTimeToDetectorArrivalTimeTransform( - # tc_min=self.t_c_range[0], - # tc_max=self.t_c_range[1], - # gps_time=self.gps, - # ifo=self.ifos[0], - # ), + GeocentricArrivalTimeToDetectorArrivalTimeTransform( + tc_min=self.t_c_range[0], + tc_max=self.t_c_range[1], + gps_time=self.gps, + ifo=self.ifos[0], + ), SkyFrameToDetectorFrameSkyPositionTransform( gps_time=self.gps, ifos=self.ifos ), @@ -289,7 +289,7 @@ def serialize(self, path: str = "./") -> dict: "max_s2": self.max_s2, "iota_range": list(self.iota_range), "dL_range": list(self.dL_range), - # "t_c_range": list(self.t_c_range), + "t_c_range": list(self.t_c_range), "phase_c_range": list(self.phase_c_range), "psi_range": list(self.psi_range), "ra_range": list(self.ra_range), @@ -322,7 +322,7 @@ def deserialize(cls, path: str) -> Self: max_s2=run_dict["max_s2"], iota_range=tuple(run_dict["iota_range"]), dL_range=tuple(run_dict["dL_range"]), - # t_c_range=tuple(run_dict["t_c_range"]), + t_c_range=tuple(run_dict["t_c_range"]), phase_c_range=tuple(run_dict["phase_c_range"]), psi_range=tuple(run_dict["psi_range"]), ra_range=tuple(run_dict["ra_range"]), @@ -354,7 +354,7 @@ def __init__(self): max_s2=0.99, iota_range=(0.0, jnp.pi), dL_range=(1.0, 10000.0), - # t_c_range=(-0.05, 0.05), + t_c_range=(-0.05, 0.05), phase_c_range=(0.0, 2 * jnp.pi), psi_range=(0.0, jnp.pi), ra_range=(0.0, 2 * jnp.pi), From e6ae3d64eacc1cded58757b91ce55fc4d6543605 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 14 Jul 2025 15:41:28 -0400 Subject: [PATCH 14/16] Add evaluate method to HeterodynedPhaseMarginalizedLikelihoodFD Expand unit tests for likelihood classes and add fixtures --- src/jimgw/core/single_event/likelihood.py | 9 + test/unit/test_likelhood.py | 192 ++++++++++++++-------- 2 files changed, 137 insertions(+), 64 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index d8672331f..1e3afe266 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -731,6 +731,15 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: class HeterodynedPhaseMarginalizedLikelihoodFD(HeterodynedTransientLikelihoodFD): + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + params.update(self.fixed_parameters) + params["phase_c"] = 0.0 + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + log_likelihood = self._likelihood(params, data) + return log_likelihood + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: frequencies_low = self.freq_grid_low frequencies_center = self.freq_grid_center diff --git a/test/unit/test_likelhood.py b/test/unit/test_likelhood.py index 749e49e57..0bb2566e8 100644 --- a/test/unit/test_likelhood.py +++ b/test/unit/test_likelhood.py @@ -1,81 +1,145 @@ import pytest import numpy as np -from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD +from jimgw.core.single_event.likelihood import ( + SingleEventLikelihood, + ZeroLikelihood, + BaseTransientLikelihoodFD, + TimeMarginalizedLikelihoodFD, + PhaseMarginalizedLikelihoodFD, + PhaseTimeMarginalizedLikelihoodFD, + HeterodynedTransientLikelihoodFD, + HeterodynedPhaseMarginalizedLikelihoodFD, +) from jimgw.core.single_event.detector import get_H1, get_L1 from jimgw.core.single_event.waveform import RippleIMRPhenomD from jimgw.core.single_event.data import Data +@pytest.fixture +def detectors_and_waveform(): + gps = 1126259462.4 + start = gps - 2 + end = gps + 2 + psd_start = gps - 2048 + psd_end = gps + 2048 + fmin = 20.0 + fmax = 1024.0 + ifos = [get_H1(), get_L1()] + for ifo in ifos: + data = Data.from_gwosc(ifo.name, start, end) + ifo.set_data(data) + psd_data = Data.from_gwosc(ifo.name, psd_start, psd_end) + psd_fftlength = data.duration * data.sampling_frequency + ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength)) + waveform = RippleIMRPhenomD(f_ref=20.0) + return ifos, waveform, fmin, fmax, gps + + +def example_params(gmst): + return { + "M_c": 30.0, + "eta": 0.249, + "s1_z": 0.0, + "s2_z": 0.0, + "d_L": 400.0, + "phase_c": 0.0, + "t_c": 0.0, + "iota": 0.0, + "ra": 1.375, + "dec": -1.2108, + "gmst": gmst, + "psi": 0.0, + } + + +class TestZeroLikelihood: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = ZeroLikelihood() + assert isinstance(likelihood, ZeroLikelihood) + params = example_params(gps) + result = likelihood.evaluate(params, {}) + assert result == 0.0 + + class TestBaseTransientLikelihoodFD: - """ - Organized tests for BaseTransientLikelihoodFD using real detector and waveform implementations. - """ - - @pytest.fixture - def GW150912_likelihood(self) -> BaseTransientLikelihoodFD: - """ - Fixture to set up a realistic BaseTransientLikelihoodFD instance using GWOSC data and power spectral density. - """ - gps = 1126259462.4 - start = gps - 2 - end = gps + 2 - psd_start = gps - 2048 - psd_end = gps + 2048 - fmin = 20.0 - fmax = 1024.0 - - # Initialize detectors and set data/PSD - ifos = [get_H1(), get_L1()] - for ifo in ifos: - data = Data.from_gwosc(ifo.name, start, end) - ifo.set_data(data) - psd_data = Data.from_gwosc(ifo.name, psd_start, psd_end) - psd_fftlength = data.duration * data.sampling_frequency - ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength)) - - waveform = RippleIMRPhenomD(f_ref=20.0) + def test_initialization(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform likelihood = BaseTransientLikelihoodFD( - detectors=ifos, - waveform=waveform, - f_min=fmin, - f_max=fmax, - trigger_time=gps, + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps ) - return likelihood - - def test_likelihood_initialization( - self, GW150912_likelihood: BaseTransientLikelihoodFD - ): - """ - Test initialization and attributes of BaseTransientLikelihoodFD with realistic setup. - """ - likelihood = GW150912_likelihood assert isinstance(likelihood, BaseTransientLikelihoodFD) assert np.allclose(likelihood.frequencies, [20.0, (20.0 + 1024.0) / 2, 1024.0]) assert likelihood.trigger_time == 1126259462.4 assert hasattr(likelihood, "gmst") - def test_likelihood_evaluation( - self, GW150912_likelihood: BaseTransientLikelihoodFD - ): - """ - Test the evaluation of the likelihood with realistic parameters. - """ - likelihood = GW150912_likelihood - # Example parameters for testing - params = { - "M_c": 30.0, - "eta": 0.249, - "s1_z": 0.0, - "s2_z": 0.0, - "d_L": 400.0, - "phase_c": 0.0, - "t_c": 0.0, - "iota": 0.0, - "ra": 1.375, - "dec": -1.2108, - "gmst": likelihood.gmst, - "psi": 0.0, - } + def test_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = BaseTransientLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps + ) + params = example_params(likelihood.gmst) log_likelihood = likelihood.evaluate(params, {}) assert np.isfinite(log_likelihood), "Log likelihood should be finite" + + +class TestTimeMarginalizedLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = TimeMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, tc_range=(-0.15, 0.15) + ) + assert isinstance(likelihood, TimeMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestPhaseMarginalizedLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = PhaseMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps + ) + assert isinstance(likelihood, PhaseMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestPhaseTimeMarginalizedLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = PhaseTimeMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, tc_range=(-0.15, 0.15) + ) + assert isinstance(likelihood, PhaseTimeMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestHeterodynedTransientLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = HeterodynedTransientLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, ref_params=example_params(gps) + ) + assert isinstance(likelihood, HeterodynedTransientLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestHeterodynedPhaseMarginalizedLikelihoodFD: + def test_initialization_and_likelihood(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = HeterodynedPhaseMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, ref_params=example_params(gps) + ) + assert isinstance(likelihood, HeterodynedPhaseMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + +# Need to add tests for running the heterodyned likelihood with different parameters From cbd7d820a326d04d5b383889e6847512e823f424 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 15 Jul 2025 14:34:53 +0800 Subject: [PATCH 15/16] Fix HeterodynedTransientLikelihoodFD initialization to include fixed_parameters --- src/jimgw/core/single_event/likelihood.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 1e3afe266..b46d23f90 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -414,6 +414,7 @@ def __init__( self, detectors: Sequence[Detector], waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, f_min: Float = 0, f_max: Float = float("inf"), trigger_time: float = 0, @@ -427,7 +428,7 @@ def __init__( likelihood_transforms: list[NtoMTransform] = [], ): - super().__init__(detectors, waveform, f_min, f_max, trigger_time) + super().__init__(detectors, waveform, fixed_parameters, f_min, f_max, trigger_time) logging.info("Initializing heterodyned likelihood..") From ecd0a923350b831951c3cd6b8149260a7223a6ad Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Thu, 17 Jul 2025 08:57:58 -0400 Subject: [PATCH 16/16] Refactor likelihood phase difference calculation and clean up code --- src/jimgw/core/single_event/likelihood.py | 29 ++++++++++--------- .../run/library/IMRPhenomPv2_standard_cbc.py | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index b46d23f90..2b4a7bc06 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -61,7 +61,7 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: class ZeroLikelihood(LikelihoodBase): def __init__(self): pass - + def evaluate(self, params: dict[str, Float], data: dict) -> Float: """Evaluate the likelihood, which is always zero.""" return 0.0 @@ -428,7 +428,9 @@ def __init__( likelihood_transforms: list[NtoMTransform] = [], ): - super().__init__(detectors, waveform, fixed_parameters, f_min, f_max, trigger_time) + super().__init__( + detectors, waveform, fixed_parameters, f_min, f_max, trigger_time + ) logging.info("Initializing heterodyned likelihood..") @@ -575,7 +577,7 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: @staticmethod def max_phase_diff( - f: Float[Array, " n_freq"], + freqs: Float[Array, " n_freq"], f_low: float, f_high: float, chi: float = 1.0, @@ -583,9 +585,11 @@ def max_phase_diff( """ Compute the maximum phase difference between the frequencies in the array. + See Eq.(7) in arXiv:2302.05333. + Parameters ---------- - f: Float[Array, "n_dims"] + freqs: Float[Array, "n_freq"] Array of frequencies to be binned. f_low: float Lower frequency bound. @@ -596,18 +600,15 @@ def max_phase_diff( Returns ------- - Float[Array, "n_dims"] + Float[Array, "n_freq"] Maximum phase difference between the frequencies in the array. """ gamma = jnp.arange(-5, 6) / 3.0 - f_2D = jnp.broadcast_to(f, (f.size, gamma.size)) + # Promotes freqs to 2D with shape (n_freq, 10) for later f/f_star + freq_2D = jax.lax.broadcast_in_dim(freqs, (freqs.size, gamma.size), [0]) f_star = jnp.where(gamma >= 0, f_high, f_low) - return ( - 2 - * jnp.pi - * chi - * jnp.sum((f_2D / f_star) ** gamma * jnp.sign(gamma), axis=1) - ) + summand = (freq_2D / f_star) ** gamma * jnp.sign(gamma) + return 2 * jnp.pi * chi * jnp.sum(summand, axis=1) def make_binning_scheme( self, freqs: Float[Array, " n_freq"], n_bins: int, chi: float = 1 @@ -732,7 +733,7 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: class HeterodynedPhaseMarginalizedLikelihoodFD(HeterodynedTransientLikelihoodFD): - + def evaluate(self, params: dict[str, Float], data: dict) -> Float: params.update(self.fixed_parameters) params["phase_c"] = 0.0 @@ -740,7 +741,7 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: params["gmst"] = self.gmst log_likelihood = self._likelihood(params, data) return log_likelihood - + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: frequencies_low = self.freq_grid_low frequencies_center = self.freq_grid_center diff --git a/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py b/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py index 4f2234778..81498e405 100644 --- a/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py +++ b/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py @@ -22,7 +22,7 @@ MassRatioToSymmetricMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, - GeocentricArrivalTimeToDetectorArrivalTimeTransform + GeocentricArrivalTimeToDetectorArrivalTimeTransform, ) from typing import Optional, Sequence, Self