-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path ode_utils.c
More file actions
90 lines (69 loc) · 2.15 KB
/
ode_utils.c
File metadata and controls
90 lines (69 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include "ode_utils.h"
#include <math.h>
#include <stdlib.h>
/* Offset used by deriv() for its central-difference stencil. Kept as a
 * mutable global with external linkage, as before.
 * NOTE(review): other translation units may read or tune it — confirm
 * against ode_utils.h before narrowing its scope. */
double eps = 1.0e-9;
/*
 * Linearly interpolate a tabulated function at x.
 *
 * Sample i is taken to lie at r_start + i * step, where
 * step = (r_end - r_start) / size. This is the same grid ode_solve() fills.
 * Between two samples the result is ordinary linear interpolation.
 *
 * Outside the sampled span the nearest segment is extended linearly, so no
 * access ever goes past either end of data[]. This includes the final
 * interval up to r_end, which has no right-hand sample.
 *
 * Parameters:
 *   table - tabulated function; data[] must hold `size` samples.
 *   x     - abscissa to evaluate.
 * Returns:
 *   - the interpolated/extrapolated value;
 *   - NAN if x is NaN or the table is empty;
 *   - data[0] for a one-sample table.
 */
double eval_func(const FunctionTable *table, const double x) {
    if (isnan(x) || table->size < 1) {
        return NAN;
    }
    if (table->size == 1) {
        return table->data[0];
    }

    const double step = (table->r_end - table->r_start) / table->size;
    const double pos = (x - table->r_start) / step;

    /* Left index of the segment to use, clamped to [0, size-2] so that both
     * data[idx] and data[idx + 1] exist. The clamp happens in floating point,
     * before the cast: converting a huge or NaN double to int is undefined. */
    const double last = (double)(table->size - 2);
    const double fl = floor(pos);
    double base;
    if (!(fl >= 0.0)) {          /* negative (or NaN from a zero step) */
        base = 0.0;
    } else if (fl > last) {
        base = last;
    } else {
        base = fl;
    }
    const int idx = (int)base;

    /* t lies in [0, 1) inside the table; outside it, t < 0 or t >= 1
     * extrapolates along the edge segment. */
    const double t = pos - base;
    return table->data[idx] * (1 - t) + table->data[idx + 1] * t;
}
/*
 * Estimate f'(x) with a central difference.
 *
 * The offset is the global eps scaled by max(1, |x|). A fixed absolute
 * offset vanishes in rounding for large |x|: with eps = 1e-9, x + eps == x
 * once |x| >= 2^24 (about 1.7e7), and the unscaled formula then returns 0.
 * For |x| <= 1 the offset is exactly eps, as before.
 *
 * Dividing by the distance actually represented (xp - xm), rather than by
 * 2*eps, removes the error introduced by rounding x +/- h.
 *
 * Parameters:
 *   f - function to differentiate.
 *   x - point of evaluation.
 * Returns:
 *   the difference quotient (f(xp) - f(xm)) / (xp - xm).
 */
double deriv(const UnaryFunc f, const double x) {
    const double h = eps * fmax(1.0, fabs(x));
    const double xp = x + h;
    const double xm = x - h;
    return (f(xp) - f(xm)) / (xp - xm);
}
LossPair var_solve(
const UnaryFunc lhs,
const UnaryFunc rhs,
const double start,
const double end,
const int amt) {
LossPair best = {0, INFINITY};
const double step = (end - start) / amt;
for (int i = 0; i <= amt; i++) {
const double x = start + i * step;
const double c_loss = fabs(rhs(x) - lhs(x));
if (c_loss < best.loss) {
best.val = x;
best.loss = c_loss;
}
}
return best;
}
FunctionTable* ode_solve(
const BiVarFunc f,
const Coord anchor,
const double r_start,
const double r_end,
const int r_amt) {
FunctionTable* table = malloc(sizeof(FunctionTable) + r_amt * sizeof(double));
table->size = r_amt;
table->r_start = r_start;
table->r_end = r_end;
const double h = (r_end - r_start) / (double) r_amt;
const int pos = lround((anchor.x - r_start) / h);
double x = anchor.x;
double y = anchor.y;
table->data[pos] = anchor.y;
for (int i = pos + 1; i < r_amt; i++) {
const double k1 = f(x, y);
const double k2 = f(x + h/2, y + h/2 * k1);
const double k3 = f(x + h/2, y + h/2 * k2);
const double k4 = f(x + h, y + h * k3);
y += h/6 * (k1 + 2*k2 + 2*k3 + k4);
table->data[i] = y;
x += h;
}
x = anchor.x;
y = anchor.y;
for (int i = pos - 1; i >= 0; i--) {
const double k1 = f(x, y);
const double k2 = f(x - h/2, y - h/2 * k1);
const double k3 = f(x - h/2, y - h/2 * k2);
const double k4 = f(x - h, y - h * k3);
y -= h/6 * (k1 + 2*k2 + 2*k3 + k4);
table->data[i] = y;
x -= h;
}
return table;
}