This repository was archived by the owner on Apr 19, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathda_methods.py
More file actions
136 lines (122 loc) · 4.1 KB
/
da_methods.py
File metadata and controls
136 lines (122 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Tuple, Callable, List
import numpy as np
import scipy
import jax.numpy as jnp
from dynamical_system import DynamicalSystem
# Type alias used in the annotations below: either a NumPy or a JAX array.
Array = Union[np.ndarray, jnp.ndarray]
def da_loss_fn(
    x0: Array,
    y: Array,
    dyn_sys: DynamicalSystem,
    correlation_transform: Callable[[Array, str], Array],
    physics_transform: Callable[[Array], Array],
    observation_transform: Callable[[Array], Array],
) -> float:
  """Computes the data assimilation objective for a candidate initial state.

  ``x0`` is reshaped to ``dyn_sys.state_dim``, passed through
  ``correlation_transform`` in ``'cor'`` mode, and integrated forward for
  ``y.shape[0]`` steps. The result is the mean squared difference between the
  physics-transformed trajectory and the observation-transformed data.

  Args:
    x0: decorrelated initial state for system evolution; must be reshapeable
      to ``dyn_sys.state_dim``.
    y: observations to assimilate; its leading dimension sets the number of
      integration steps.
    dyn_sys: DynamicalSystem providing ``state_dim`` and ``integrate``.
    correlation_transform: assigns spatial correlations to the initial state.
    physics_transform: transforms physical trajectory after integration.
    observation_transform: transforms observation data.

  Returns:
    Mean squared data assimilation loss, averaged over every element of the
    transformed residual.
  """
  state_shape = dyn_sys.state_dim
  # One integration step per leading entry of the observation array.
  steps = y.shape[0]
  initial_state = correlation_transform(x0.reshape(state_shape), 'cor')
  trajectory = dyn_sys.integrate(initial_state, steps)
  predicted = physics_transform(trajectory)
  observed = observation_transform(y)
  return jnp.mean(jnp.square(predicted - observed))
def optimize_lbfgs_scipy(
    f_value_and_grad: Callable[[Array], Tuple[float, Array]],
    x: Array,
    max_iter: int,
    f_eval: Union[Callable[[Array], float], None] = None,
) -> Tuple[
    Array,
    scipy.optimize.OptimizeResult,
    List[float],
    List[float],
]:
  """Minimizes a function using scipy's L-BFGS-B.

  The optimizer works on a flat float64 copy of ``x``. Every point handed to
  ``f_value_and_grad`` or ``f_eval`` is first converted to a JAX array and
  reshaped back to ``x.shape``.

  Args:
    f_value_and_grad: returns objective function value and its gradient.
    x: initial value for the optimization.
    max_iter: maximum iterations for optimizer.
    f_eval: additional function evaluated once per completed optimizer
      iteration, at the accepted iterate. If it is None, the objective
      function value is logged instead.

  Returns:
    Tuple containing
    (
      argmin of the optimization, reshaped to ``x.shape``,
      optimization result object,
      objective function values (Python floats), one per iteration,
      evaluation function values, one per iteration,
    )
  """
  original_shape = x.shape
  # scipy may evaluate the objective several times per iteration (line
  # search). We remember only the most recent value; the callback records it
  # when an iteration completes. This relies on L-BFGS-B evaluating the
  # accepted iterate last before invoking the callback.
  last_fval = [None]
  fvals = []
  eval_vals = []

  def to_jax(x_flat):
    """Converts a flat scipy vector back to a JAX array of the input shape."""
    return jnp.asarray(x_flat).reshape(original_shape)

  def f_np_value_and_grad(x_flat):
    """Adapts the objective to scipy: float64 scalar value and flat gradient."""
    val, grad = f_value_and_grad(to_jax(x_flat))
    fval = np.asarray(val, dtype=np.float64).item()
    last_fval[0] = fval
    # np.array copies, so scipy receives its own writable float64 buffer.
    return fval, np.array(grad, dtype=np.float64).ravel()

  def callback(x_k):
    """Logs per-iteration values after every optimization step."""
    fvals.append(last_fval[0])
    if f_eval is None:
      eval_vals.append(last_fval[0])
    else:
      # Evaluate only at accepted iterates, not at every line-search trial.
      eval_vals.append(f_eval(to_jax(x_k)))

  x_np = np.array(x, dtype=np.float64).ravel()
  options = {'maxiter': max_iter, 'gtol': 1e-12}
  res = scipy.optimize.minimize(
      f_np_value_and_grad,
      x_np,
      jac=True,
      callback=callback,
      method='L-BFGS-B',
      options=options,
  )
  return (
      to_jax(res.x),
      res,
      fvals,
      eval_vals,
  )