forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
integrators.py
1041 lines (929 loc) · 37.3 KB
/
integrators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the implementations of the various numerical integrators.
Higher order methods mostly taken from [1].
References:
[1] Leimkuhler, Benedict and Sebastian Reich. Simulating hamiltonian dynamics.
Vol. 14. Cambridge university press, 2004.
[2] Forest, Etienne and Ronald D. Ruth. Fourth-order symplectic integration.
Physica D: Nonlinear Phenomena 43.1 (1990): 105-117.
[3] Blanes, Sergio and Per Christian Moan. Practical symplectic partitioned
Runge–Kutta and Runge–Kutta–Nyström methods. Journal of Computational and
Applied Mathematics 142.2 (2002): 313-330.
[4] McLachlan, Robert I. On the numerical integration of ordinary differential
equations by symmetric composition methods. SIAM Journal on Scientific
Computing 16.1 (1995): 151-168.
[5] Yoshida, Haruo. Construction of higher order symplectic integrators.
Physics letters A 150.5-7 (1990): 262-268.
[6] Süli, Endre; Mayers, David (2003), An Introduction to Numerical Analysis,
Cambridge University Press, ISBN 0-521-00794-1.
[7] Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), Solving
ordinary differential equations I: Nonstiff problems, Berlin, New York:
Springer-Verlag, ISBN 978-3-540-56670-0.
"""
from typing import Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
import jax
from jax import lax
from jax.experimental import ode
import jax.numpy as jnp
import numpy as np
# Type of the integrated state `y`: any pytree, e.g. a flat array or a
# `phase_space.PhaseSpace`.
M = TypeVar("M")
# Type of the tangent (time derivative) of an `M`. The integrators multiply it
# by scalars/arrays, add it to other `TM`s and add it to an `M`.
TM = TypeVar("TM")
# A time interval `(t0, tf)`, given either as a 2-tuple or as an array.
TimeInterval = Union[jnp.ndarray, Tuple[float, float]]
# _____ _
# / ____| | |
# | | __ ___ _ __ ___ _ __ __ _| |
# | | |_ |/ _ \ '_ \ / _ \ '__/ _` | |
# | |__| | __/ | | | __/ | | (_| | |
# \_____|\___|_| |_|\___|_| \__,_|_|
# _____ _ _ _
# |_ _| | | | | (_)
# | | _ __ | |_ ___ __ _ _ __ __ _| |_ _ ___ _ __
# | | | '_ \| __/ _ \/ _` | '__/ _` | __| |/ _ \| '_ \
# _| |_| | | | || __/ (_| | | | (_| | |_| | (_) | | | |
# |_____|_| |_|\__\___|\__, |_| \__,_|\__|_|\___/|_| |_|
# __/ |
# |___/
# Right-hand side `f(t, y)` of an ODE `dy/dt = f(t, y)`. It receives the time
# (which may be None: the "adaptive" path of `solve_ivp_dt` calls it with
# `t=None`) and the state, and returns the time derivative as a `TM`.
GeneralTangentFunction = Callable[
[
Optional[Union[float, jnp.ndarray]], # t
M # y
],
TM # dy_dt
]
# One integration step: given the tangent function, the current time, the
# current state and the step size, returns the state after advancing by `dt`.
GeneralIntegrator = Callable[
[
GeneralTangentFunction,
Optional[Union[float, jnp.ndarray]], # t
M, # y
jnp.ndarray, # dt
],
M # y_next
]
def solve_ivp_dt(
    fun: GeneralTangentFunction,
    y0: M,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, GeneralIntegrator],
    num_steps: Optional[int] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, M]:
  """Solve an initial value problem for a system of ODEs using explicit method.

  This function numerically integrates a system of ordinary differential
  equations given an initial value::

    dy / dt = f(t, y)
    y(t0) = y0

  Here t is a one-dimensional independent variable (time), y(t) is an
  n-dimensional vector-valued function (state), and an n-dimensional
  vector-valued function f(t, y) determines the differential equations.
  The goal is to find y(t) approximately satisfying the differential
  equations, given an initial value y(t0)=y0.

  The named solvers are explicit and use a fixed step size. This makes them
  easy to run with a fixed amount of computation and keeps the solutions
  easily differentiable. The one exception is `method="adaptive"`, which
  delegates to `jax.experimental.ode.odeint`.

  Args:
    fun: callable
      Right-hand side of the system. The calling signature is ``fun(t, y)``.
      Here `t` is a scalar representing the time instance. `y` can be any
      type `M`, including a flat array, that is registered as a
      pytree. In addition, there is a type denoted as `TM` that represents
      the tangent space to `M`. It is assumed that any element of `TM` can be
      multiplied by arrays and scalars, can be added to other `TM` instances
      and can be right added to an element of `M`, that is add(M, TM)
      exists. The function should return an element of `TM` that defines the
      time derivative of `y`.
    y0: an instance of `M`
      Initial state at `t0`.
    t0: float or array.
      The initial time point of integration.
    dt: float or array
      Consecutive increments in time. Steps run along axis 0: the size of this
      array along axis 0 is the number of steps the integrator does, and any
      trailing axes hold per-trajectory (batch) increments. When `num_steps`
      is given, `dt` is instead a single increment repeated `num_steps` times.
    method: string or `GeneralIntegrator`
      The integrator method to use. Possible values for string are:
        * general_euler - see `GeneralEuler`
        * rk2 - see `RungaKutta2`
        * rk4 - see `RungaKutta4`
        * rk38 - see `RungaKutta38`
        * adaptive - `jax.experimental.ode.odeint`. Requires `num_steps`, a
          time independent `fun` (it is called with `t=None`) and a single
          absolute step size.
    num_steps: Optional int.
      If provided the `dt` will be treated as the same per step time interval,
      applied for this many steps. In other words setting this argument is
      equivalent to replicating `dt` num_steps times and stacking over axis=0.
    steps_per_dt: int
      Number of integrator sub-steps per increment: each increment `dt[i]` is
      covered by `steps_per_dt` steps of size `dt[i] / steps_per_dt`.
    use_scan: bool
      Whether for the loop to use `lax.scan` or a python loop
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive"

  Returns:
    t: array
      Time points at which the solution is evaluated (`t0` excluded).
    y : an instance of M
      Values of the solution at `t`, stacked along axis 0.

  Raises:
    ValueError: for `method="adaptive"` when `num_steps` is not given, when
      the absolute values of a numpy `dt` differ, or when `dt` is a non-scalar
      `jax.numpy` array.
  """
  if method == "adaptive":
    if num_steps is None:
      raise ValueError("`num_steps` must be provided when method='adaptive'.")
    ndim = y0.q.ndim if isinstance(y0, phase_space.PhaseSpace) else y0.ndim
    # Direction of time per trajectory, aligned with the leading (batch) axis
    # of the state.
    signs = jnp.asarray(jnp.sign(dt))
    signs = signs.reshape([-1] + [1] * (ndim - 1))
    if np.ndim(dt) == 0:
      true_t_eval = t0 + dt * np.arange(1, num_steps + 1)
    else:
      true_t_eval = t0 + dt[None] * np.arange(1, num_steps + 1)[:, None]
    if isinstance(dt, float):
      dt = np.asarray(dt)
    if isinstance(dt, np.ndarray) and dt.ndim > 0:
      # `odeint` runs all trajectories on one shared time grid, so every entry
      # must have the same magnitude. (`np.all` here could never fire, since
      # `dt[0]` always equals itself.)
      if np.any(np.abs(dt) != np.abs(dt[0])):
        raise ValueError("Not all values of `dt` where the same.")
    elif isinstance(dt, jnp.ndarray) and dt.ndim > 0:
      raise ValueError("The code here works only when `dy_dt` is time "
                       "independent and `np.abs(dt)` is the same. For this we "
                       "allow calling this only with numpy (not jax.numpy) "
                       "arrays.")
    dt = jnp.abs(jnp.asarray(dt))
    dt = dt.reshape([-1])[0]
    t_eval = t0 + dt * np.arange(num_steps + 1)
    outputs = ode.odeint(
        func=lambda y_, t_: fun(None, y_) * signs,
        y0=y0,
        t=jnp.abs(t_eval - t0),
        **(ode_int_kwargs or dict())
    )
    # Note that we do not return the initial point
    return true_t_eval, jax.tree_util.tree_map(lambda x: x[1:], outputs)

  method = get_integrator(method)
  if num_steps is not None:
    dt = jnp.repeat(jnp.asarray(dt)[None], repeats=num_steps, axis=0)
  # End time of every step.
  t_eval = t0 + jnp.cumsum(dt, axis=0)
  # Start time of every step: `t0` for the first one, then the end time of the
  # previous one. Steps run along axis 0 (see the cumsum above and the scan
  # below), so the shift must happen along axis 0 too. Shifting along the last
  # axis would scramble batched time grids.
  t_start = jnp.ones_like(t_eval[:1]) * t0
  t = jnp.concatenate([t_start, t_eval[:-1]], axis=0)

  def loop_body(y_: M, t_dt: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[M, M]:
    t_, dt_ = t_dt
    dt_ = dt_ / steps_per_dt
    for _ in range(steps_per_dt):
      y_ = method(fun, t_, y_, dt_)
      t_ = t_ + dt_
    return y_, y_

  if use_scan:
    return t_eval, lax.scan(loop_body, init=y0, xs=(t, dt))[1]
  y = [y0]
  for t_and_dt_i in zip(t, dt):
    y.append(loop_body(y[-1], t_and_dt_i)[0])
  # Note that we do not return the initial point
  return t_eval, jax.tree_util.tree_map(
      lambda *args: jnp.stack(args, axis=0), *y[1:])
def solve_ivp_dt_two_directions(
    fun: GeneralTangentFunction,
    y0: M,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, GeneralIntegrator],
    num_steps_forward: int,
    num_steps_backward: int,
    include_y0: bool = True,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> M:
  """Unrolls `solve_ivp_dt` for a fixed number of steps in both time directions.

  Integrates backward from `t0` with `-dt` and forward with `dt`, and joins the
  results in chronological order.

  Args:
    fun: the tangent function, see `solve_ivp_dt`.
    y0: the state at `t0`.
    t0: the time of `y0`.
    dt: the per-step time increment (negated for the backward pass).
    method: the integrator, see `solve_ivp_dt`.
    num_steps_forward: number of steps to take forward in time.
    num_steps_backward: number of steps to take backward in time.
    include_y0: whether to place `y0` between the backward and forward states.
    steps_per_dt: integrator sub-steps per `dt`, see `solve_ivp_dt`.
    use_scan: whether to loop with `lax.scan`, see `solve_ivp_dt`.
    ode_int_kwargs: extra `ode.odeint` arguments, see `solve_ivp_dt`.

  Returns:
    An instance of `M` whose leading axis is time. It contains, in order:
    - the backward states, earliest first;
    - `y0`, if `include_y0` is set;
    - the forward states.
    The time points are not returned.

  Raises:
    ValueError: if no steps are requested in either direction and
      `include_y0` is False, so the trajectory would be empty.
  """
  if num_steps_backward <= 0 and num_steps_forward <= 0 and not include_y0:
    raise ValueError("The requested trajectory is empty: request at least one "
                     "forward or backward step, or set include_y0=True.")
  common_kwargs = dict(
      fun=fun,
      y0=y0,
      t0=t0,
      method=method,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs,
  )
  segments = []
  if num_steps_backward > 0:
    _, yt_bck = solve_ivp_dt(
        dt=-dt, num_steps=num_steps_backward, **common_kwargs)
    # The backward states move away from t0; flip them so time increases.
    segments.append(
        jax.tree_util.tree_map(lambda x: jnp.flip(x, axis=0), yt_bck))
  if include_y0:
    segments.append(jax.tree_util.tree_map(lambda x: x[None], y0))
  if num_steps_forward > 0:
    _, yt_fwd = solve_ivp_dt(
        dt=dt, num_steps=num_steps_forward, **common_kwargs)
    segments.append(yt_fwd)
  if len(segments) == 1:
    return segments[0]
  return jax.tree_util.tree_map(
      lambda *a: jnp.concatenate(a, axis=0), *segments)
def solve_ivp_t_eval(
    fun: GeneralTangentFunction,
    t_span: TimeInterval,
    y0: M,
    method: Union[str, GeneralIntegrator],
    t_eval: Optional[jnp.ndarray] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, M]:
  """Solve an initial value problem, reporting the state at given times.

  Integrates ``dy / dt = f(t, y)`` with ``y(t_span[0]) = y0`` and returns the
  solution at every time in `t_eval`. The time points are validated and then
  turned into per-step increments that are handed to `solve_ivp_dt`.

  Args:
    fun: callable
      Right-hand side ``fun(t, y)`` of the system. `y` is any pytree type `M`,
      and the return value is its time derivative, an element of the tangent
      type `TM`. `TM` must support scaling by scalars/arrays, addition with
      other `TM`s and right-addition to an `M`.
    t_span: 2-tuple of floats or array
      Interval of integration (t0, tf). The solver starts at t=t0 and
      integrates until it reaches t=tf. Either end can be an array, with one
      entry per trajectory.
    y0: an instance of `M`
      Initial state at `t_span[0]`.
    method: string or `GeneralIntegrator`
      The integrator method to use. Possible values for string are:
        * general_euler - see `GeneralEuler`
        * rk2 - see `RungaKutta2`
        * rk4 - see `RungaKutta4`
        * rk38 - see `RungaKutta38`
    t_eval: array or None.
      Times at which to store the computed solution. They must lie within
      `t_span` and be strictly monotonic in the direction of integration.
      If None then t_eval = [t_span[-1]]. When `t_span` is batched (rank 2),
      `t_eval` must have rank 2 too, with one column per trajectory.
    steps_per_dt: int
      Number of integrator sub-steps between consecutive evaluation times.
    use_scan: bool
      Whether for the loop to use `lax.scan` or a python loop
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive"

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of M
      Values of the solution at `t`.

  Raises:
    ValueError: if `t_eval` is outside `t_span`, is not sorted in the
      direction of integration, or has the wrong rank for a batched `t_span`.
  """
  if t_eval is None:
    # By default only the final state is reported.
    t_eval = np.asarray([t_span[-1]])

  # Bring `t_span` into a single numpy array, broadcasting a float end point
  # against an array end point when the two are mixed.
  start, stop = t_span[0], t_span[1]
  start_is_float = isinstance(start, float)
  stop_is_float = isinstance(stop, float)
  if start_is_float and stop_is_float:
    span = np.asarray(t_span)
  elif start_is_float and isinstance(stop, jnp.ndarray):
    span = np.stack((np.full_like(stop, start), stop), axis=0)
  elif stop_is_float and isinstance(start, jnp.ndarray):
    span = np.stack((start, jnp.full_like(start, stop)), axis=0)
  else:
    span = np.stack(t_span, axis=0)

  def _validate(bounds, times):
    """Raises if `times` is outside `bounds` or not sorted along the span."""
    lower, upper = bounds[0], bounds[1]
    forward = lower < upper
    if forward:
      within = np.logical_and(lower <= times, times <= upper)
    else:
      within = np.logical_and(lower >= times, times >= upper)
    if not np.all(within):
      raise ValueError("Values in `t_eval` are not within `t_span`.")
    if forward:
      monotonic = times[:-1] < times[1:]
    else:
      monotonic = times[:-1] > times[1:]
    if not np.all(monotonic):
      raise ValueError("Values in `t_eval` are not properly sorted.")

  if span.ndim == 1:
    _validate(span, t_eval)
  elif span.ndim == 2:
    if t_eval.ndim != 2:
      raise ValueError("t_eval should have rank 2.")
    for column in range(span.shape[1]):
      _validate(span[:, column], t_eval[:, column])

  # Every step starts where the previous evaluation time ended.
  step_starts = np.concatenate([span[:1], t_eval[:-1]], axis=0)
  return solve_ivp_dt(
      fun=fun,
      y0=y0,
      t0=span[0],
      dt=t_eval - step_starts,
      method=method,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs,
  )
class RungaKutta(GeneralIntegrator):
  """A general explicit Runge-Kutta integrator defined by a Butcher tableau.

  A single step with `s = len(b_tableau)` stages computes::

    k_0 = f(t, y)
    k_n = f(t + c_{n-1} * dt, y + dt * sum_i a_{n-1, i} * k_i),  n = 1..s-1
    y_next = y + dt * sum_i b_i * k_i

  Tableau entries equal to 0.0 are skipped when forming the weighted sums.

  Attributes:
    a_tableau: rows of the strictly lower-triangular Runge-Kutta matrix.
      Row n holds the weights of k_0..k_n for stage n + 1.
    b_tableau: the final weights, one per stage.
    c_tableau: the time fractions of stages 1..s-1.
    order: the nominal order of the method (only recorded, not used here).
  """

  def __init__(
      self,
      a_tableau: Sequence[Sequence[float]],
      b_tableau: Sequence[float],
      c_tableau: Sequence[float],
      order: int):
    if len(b_tableau) != len(c_tableau) + 1:
      raise ValueError("The length of b_tableau should be exactly one more than"
                       " the length of c_tableau.")
    if len(b_tableau) != len(a_tableau) + 1:
      raise ValueError("The length of b_tableau should be exactly one more than"
                       " the length of a_tableau.")
    self.a_tableau = a_tableau
    self.b_tableau = b_tableau
    self.c_tableau = c_tableau
    self.order = order

  def __call__(
      self,
      tangent_func: GeneralTangentFunction,
      t: jnp.ndarray,
      y: M,
      dt: jnp.ndarray
  ) -> M:  # pytype: disable=invalid-annotation
    """Advances `y` from time `t` by one step of size `dt`."""
    k = [tangent_func(t, y)]
    # `jax.tree_map` is deprecated and removed in recent JAX releases.
    zero = jax.tree_util.tree_map(jnp.zeros_like, k[0])
    # Append trailing singleton axes so that a per-trajectory `t`/`dt` lines up
    # with the leading (batch) axes of `y`. This is the opposite of numpy's
    # right-aligned broadcasting.
    if dt.ndim > 0:
      dt = dt.reshape(dt.shape + (1,) * (y.ndim - dt.ndim))
    if t.ndim > 0:
      t = t.reshape(t.shape + (1,) * (y.ndim - t.ndim))
    for c_n, a_n_row in zip(self.c_tableau, self.a_tableau):
      t_n = t + dt * c_n
      products = [a_i * k_i for a_i, k_i in zip(a_n_row, k) if a_i != 0.0]
      delta_n = sum(products, zero)
      y_n = y + dt * delta_n
      k.append(tangent_func(t_n, y_n))
    products = [b_i * k_i for b_i, k_i in zip(self.b_tableau, k) if b_i != 0.0]
    delta = sum(products, zero)
    return y + dt * delta
class GeneralEuler(RungaKutta):
  """The explicit (forward) Euler method: `y_next = y + dt * f(t, y)`."""

  def __init__(self):
    # A single stage: there are no intermediate stages, and f(t, y) gets
    # weight 1.
    super().__init__(a_tableau=[], b_tableau=[1.0], c_tableau=[], order=1)
class RungaKutta2(RungaKutta):
  """Second order Runge-Kutta method (explicit mid-point rule).

  The derivative is evaluated once at the start and once at the half-step, and
  only the half-step derivative is used for the update.
  """

  def __init__(self):
    half = 1.0 / 2.0
    super().__init__(
        order=2,
        c_tableau=[half],
        b_tableau=[0.0, 1.0],
        a_tableau=[[half]],
    )
class RungaKutta4(RungaKutta):
  """The classic fourth order Runge-Kutta method from [6]."""

  def __init__(self):
    half = 1.0 / 2.0
    sixth = 1.0 / 6.0
    third = 1.0 / 3.0
    matrix = [
        [half],
        [0.0, half],
        [0.0, 0.0, 1.0],
    ]
    super().__init__(
        a_tableau=matrix,
        b_tableau=[sixth, third, third, sixth],
        c_tableau=[half, half, 1.0],
        order=4,
    )
class RungaKutta38(RungaKutta):
  """Kutta's fourth order "3/8 rule" Runge-Kutta method from [7]."""

  def __init__(self):
    third = 1.0 / 3.0
    two_thirds = 2.0 / 3.0
    eighth = 1.0 / 8.0
    three_eighths = 3.0 / 8.0
    matrix = [
        [third],
        [-1.0 / 3.0, 1.0],
        [1.0, -1.0, 1.0],
    ]
    super().__init__(
        a_tableau=matrix,
        b_tableau=[eighth, three_eighths, three_eighths, eighth],
        c_tableau=[third, two_thirds, 1.0],
        order=4,
    )
# _____ _ _ _
# / ____| | | | | (_)
# | (___ _ _ _ __ ___ _ __ | | ___ ___| |_ _ ___
# \___ \| | | | '_ ` _ \| '_ \| |/ _ \/ __| __| |/ __|
# ____) | |_| | | | | | | |_) | | __/ (__| |_| | (__
# |_____/ \__, |_| |_| |_| .__/|_|\___|\___|\__|_|\___|
# __/ | | |
# |___/ |_|
# _____ _ _ _
# |_ _| | | | | (_)
# | | _ __ | |_ ___ __ _ _ __ __ _| |_ _ ___ _ __
# | | | '_ \| __/ _ \/ _` | '__/ _` | __| |/ _ \| '_ \
# _| |_| | | | || __/ (_| | | | (_| | |_| | (_) | | | |
# |_____|_| |_|\__\___|\__, |_| \__,_|\__|_|\___/|_| |_|
# __/ |
# |___/
# One step of a symplectic integrator: given the symplectic tangent function,
# the current time, the current phase-space state (q, p) and the step size,
# returns the phase-space state after advancing by `dt`.
SymplecticIntegrator = Callable[
[
phase_space.SymplecticTangentFunction,
jnp.ndarray, # t
phase_space.PhaseSpace, # (q, p)
jnp.ndarray, # dt
],
phase_space.PhaseSpace # (q_next, p_next)
]
def solve_hamiltonian_ivp_dt(
    hamiltonian: phase_space.HamiltonianFunction,
    y0: phase_space.PhaseSpace,
    t0: Union[float, jnp.ndarray],
    dt: Union[float, jnp.ndarray],
    method: Union[str, SymplecticIntegrator],
    num_steps: Optional[int] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, phase_space.PhaseSpace]:
  """Solve an initial value problem for a Hamiltonian system.

  Integrates::

    dq / dt = dH / dp
    dp / dt = - dH / dq
    q(t0), p(t0) = y0.q, y0.p

  by building the phase-space vector field of `hamiltonian` and delegating to
  `solve_ivp_dt`.

  Args:
    hamiltonian: callable
      The Hamiltonian function. The calling signature is ``h(t, s)``, where
      `s` is an instance of `PhaseSpace`.
    y0: an instance of `PhaseSpace`
      Initial state at t=t0.
    t0: float or array.
      The initial time point of integration.
    dt: float or array
      Consecutive time increments. The size along axis 0 is the number of
      steps (see `solve_ivp_dt` for the interaction with `num_steps`).
    method: string or `SymplecticIntegrator`
      The integrator method to use. Possible values for string are:
        * symp_euler - see `SymplecticEuler`
        * symp_euler_q - a `SymplecticEuler` with position_first=True
        * symp_euler_p - a `SymplecticEuler` with position_first=False
        * leap_frog - see `LeapFrog`
        * leap_frog_q - a `LeapFrog` with position_first=True
        * leap_frog_p - a `LeapFrog` with position_first=False
        * stormer_verlet - same as leap_frog
        * stormer_verlet_q - same as leap_frog_q
        * stormer_verlet_p - same as leap_frog_p
        * ruth4 - see `Ruth4`,
        * sym4 - see `Symmetric4`
        * sym6 - see `Symmetric6`
        * so4 - see `SymmetricSo4`
        * so4_q - a `SymmetricSo4` with position_first=True
        * so4_p - a `SymmetricSo4` with position_first=False
        * so6 - see `SymmetricSo6`
        * so6_q - a `SymmetricSo6` with position_first=True
        * so6_p - a `SymmetricSo6` with position_first=False
        * so8 - see `SymmetricSo8`
        * so8_q - a `SymmetricSo8` with position_first=True
        * so8_p - a `SymmetricSo8` with position_first=False
    num_steps: Optional int.
      If provided, `dt` is a single increment repeated this many times.
    steps_per_dt: int
      Number of integrator sub-steps per increment.
    use_scan: bool
      Whether for the loop to use `lax.scan` or a python loop
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive"

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of `PhaseSpace`
      Values of the solution at `t`.

  Raises:
    ValueError: if `y0` is not a `PhaseSpace`.
  """
  if not isinstance(y0, phase_space.PhaseSpace):
    raise ValueError("The initial state must be an instance of `PhaseSpace`.")
  vector_field = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
  solver_options = dict(
      method=method,
      num_steps=num_steps,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs,
  )
  return solve_ivp_dt(fun=vector_field, y0=y0, t0=t0, dt=dt, **solver_options)
def solve_hamiltonian_ivp_t_eval(
    hamiltonian: phase_space.HamiltonianFunction,
    t_span: TimeInterval,
    y0: phase_space.PhaseSpace,
    method: Union[str, SymplecticIntegrator],
    t_eval: Optional[jnp.ndarray] = None,
    steps_per_dt: int = 1,
    use_scan: bool = True,
    ode_int_kwargs: Optional[Dict[str, Union[float, int]]] = None
) -> Tuple[jnp.ndarray, phase_space.PhaseSpace]:
  """Solve a Hamiltonian initial value problem, reporting given times.

  Integrates::

    dq / dt = dH / dp
    dp / dt = - dH / dq
    q(t0), p(t0) = y0.q, y0.p

  by building the phase-space vector field of `hamiltonian` and delegating to
  `solve_ivp_t_eval`.

  Args:
    hamiltonian: callable
      The Hamiltonian function. The calling signature is ``h(t, s)``, where
      `s` is an instance of `PhaseSpace`.
    t_span: 2-tuple of floats or array
      Interval of integration (t0, tf). The solver starts with t=t0 and
      integrates until it reaches t=tf.
    y0: an instance of `PhaseSpace`
      Initial state at `t_span[0]`.
    method: string or `SymplecticIntegrator`
      The integrator method to use. Possible values for string are:
        * symp_euler - see `SymplecticEuler`
        * symp_euler_q - a `SymplecticEuler` with position_first=True
        * symp_euler_p - a `SymplecticEuler` with position_first=False
        * leap_frog - see `LeapFrog`
        * leap_frog_q - a `LeapFrog` with position_first=True
        * leap_frog_p - a `LeapFrog` with position_first=False
        * stormer_verlet - same as leap_frog
        * stormer_verlet_q - same as leap_frog_q
        * stormer_verlet_p - same as leap_frog_p
        * ruth4 - see `Ruth4`,
        * sym4 - see `Symmetric4`
        * sym6 - see `Symmetric6`
        * so4 - see `SymmetricSo4`
        * so4_q - a `SymmetricSo4` with position_first=True
        * so4_p - a `SymmetricSo4` with position_first=False
        * so6 - see `SymmetricSo6`
        * so6_q - a `SymmetricSo6` with position_first=True
        * so6_p - a `SymmetricSo6` with position_first=False
        * so8 - see `SymmetricSo8`
        * so8_q - a `SymmetricSo8` with position_first=True
        * so8_p - a `SymmetricSo8` with position_first=False
    t_eval: array or None.
      Times at which to store the computed solution. Must be sorted and lie
      within `t_span`. If None then t_eval = [t_span[-1]]
    steps_per_dt: int
      Number of integrator sub-steps between consecutive evaluation times.
    use_scan: bool
      Whether for the loop to use `lax.scan` or a python loop
    ode_int_kwargs: dict
      Extra arguments to be passed to `ode.odeint` when method="adaptive"

  Returns:
    t: array
      Time points at which the solution is evaluated.
    y : an instance of `PhaseSpace`
      Values of the solution at `t`.

  Raises:
    ValueError: if `y0` is not a `PhaseSpace`, or if `t_eval` is invalid
      (see `solve_ivp_t_eval`).
  """
  if not isinstance(y0, phase_space.PhaseSpace):
    raise ValueError("The initial state must be an instance of `PhaseSpace`.")
  vector_field = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
  if method == "adaptive":
    # The adaptive solver is handed an array-based wrapper of the vector field.
    vector_field = (
        phase_space.transform_symplectic_tangent_function_using_array(
            vector_field))
  return solve_ivp_t_eval(  # pytype: disable=bad-return-type  # jax-ndarray
      fun=vector_field,
      t_span=t_span,
      y0=y0,
      method=method,
      t_eval=t_eval,
      steps_per_dt=steps_per_dt,
      use_scan=use_scan,
      ode_int_kwargs=ode_int_kwargs,
  )
class CompositionSymplectic(SymplecticIntegrator):
  """A generalized symplectic integrator based on compositions.

  One step walks over the pairs `(c_i, d_i)` of momentum/position coefficients
  in order and, for each pair, applies::

    p <- p + c_i * dt * dp/dt(q, p)      (skipped when c_i == 0)
    q <- q + d_i * dt * dq/dt(q, p)      (skipped when d_i == 0)

  Here dp/dt and dq/dt are the `.p` and `.q` fields returned by the tangent
  function on the current state. For a Hamiltonian system these are -dH/dq and
  dH/dp. Within a pair the momentum is always updated first.

  Momentum and position updates keep separate clocks. Each starts at `t` and
  is advanced by `c_i * dt` (resp. `d_i * dt`) after the corresponding update.

  The `order` argument is only stored. It is used mainly for testing, to
  estimate the error when integrating various systems.
  """

  def __init__(
      self,
      momentum_coefficients: Sequence[float],
      position_coefficients: Sequence[float],
      order: int):
    if len(momentum_coefficients) != len(position_coefficients):
      raise ValueError("The number of momentum_coefficients and "
                       "position_coefficients must be the same.")
    # Both coefficient sets must cover exactly one full step `dt`.
    if not np.allclose(sum(position_coefficients), 1.0):
      raise ValueError("The sum of the position_coefficients "
                       "must be equal to 1.")
    if not np.allclose(sum(momentum_coefficients), 1.0):
      raise ValueError("The sum of the momentum_coefficients "
                       "must be equal to 1.")
    self.order = order
    self.position_coefficients = position_coefficients
    self.momentum_coefficients = momentum_coefficients

  def __call__(
      self,
      tangent_func: phase_space.SymplecticTangentFunction,
      t: jnp.ndarray,
      y: phase_space.PhaseSpace,
      dt: jnp.ndarray
  ) -> phase_space.PhaseSpace:
    """Advances the phase-space state `y` from time `t` by one step `dt`."""
    q, p = y.q, y.p
    # Drop `y` on purpose so the stale state cannot be used by mistake below.
    del y

    def _align(x):
      # Append trailing singleton axes so that leading (batch) axes of `x`
      # line up with those of `q`.
      if x.ndim > 0:
        return x.reshape(x.shape + (1,) * (q.ndim - x.ndim))
      return x

    dt = _align(dt)
    t = _align(t)
    momentum_clock = t
    position_clock = t
    coefficient_pairs = zip(self.momentum_coefficients,
                            self.position_coefficients)
    for kick, drift in coefficient_pairs:
      if kick != 0.0:
        momentum_rate = tangent_func(
            momentum_clock, phase_space.PhaseSpace(q, p)).p
        p = p + kick * dt * momentum_rate
        momentum_clock = momentum_clock + kick * dt
      if drift != 0.0:
        position_rate = tangent_func(
            position_clock, phase_space.PhaseSpace(q, p)).q
        q = q + drift * dt * position_rate
        position_clock = position_clock + drift * dt
    return phase_space.PhaseSpace(position=q, momentum=p)
class SymplecticEuler(CompositionSymplectic):
  """The symplectic Euler method (for Hamiltonian systems).

  If position_first = True:
    q_{t+1} = q_{t} + dH/dp(p_{t}) * dt
    p_{t+1} = p_{t} - dH/dq(q_{t+1}) * dt
  else:
    p_{t+1} = p_{t} - dH/dq(q_{t}) * dt
    q_{t+1} = q_{t} + dH/dp(p_{t+1}) * dt
  """

  def __init__(self, position_first=True):
    if not position_first:
      # A single pair: full momentum kick, then full position drift.
      momentum, position = [1.0], [1.0]
    else:
      # Two pairs: the first only drifts q, the second only kicks p.
      momentum, position = [0.0, 1.0], [1.0, 0.0]
    super().__init__(
        momentum_coefficients=momentum,
        position_coefficients=position,
        order=1,
    )
class SymmetricCompositionSymplectic(CompositionSymplectic):
  """A composition integrator that starts and ends with the same update.

  The integrators produced are always of the form:
    [update_q, update_p, ..., update_p, update_q]
  or
    [update_p, update_q, ..., update_q, update_p]
  based on the position_first argument. Whichever variable is updated first
  must have one more coefficient than the other. The shorter list is padded
  with a zero so both can be zipped pairwise by `CompositionSymplectic`.
  Symmetry of the coefficient values themselves is not checked here.
  """

  def __init__(
      self,
      momentum_coefficients: Sequence[float],
      position_coefficients: Sequence[float],
      position_first: bool,
      order: int):
    position_coefficients = list(position_coefficients)
    momentum_coefficients = list(momentum_coefficients)
    if position_first:
      if len(position_coefficients) != len(momentum_coefficients) + 1:
        raise ValueError("The number of position_coefficients must be one more "
                         "than momentum_coefficients when position_first=True.")
      # A leading zero kick makes the first pair a pure position update.
      momentum_coefficients = [0.0] + momentum_coefficients
    else:
      if len(position_coefficients) + 1 != len(momentum_coefficients):
        # The message previously (wrongly) said "position_first=True" here.
        raise ValueError("The number of momentum_coefficients must be one more "
                         "than position_coefficients when position_first=False."
                        )
      # A trailing zero drift makes the last pair a pure momentum update.
      position_coefficients = position_coefficients + [0.0]
    super().__init__(
        position_coefficients=position_coefficients,
        momentum_coefficients=momentum_coefficients,
        order=order
    )
def symmetrize_coefficients(
    coefficients: Sequence[float],
    odd_number: bool
) -> Sequence[float]:
  """Mirrors half of a coefficient list into a full list that sums to 1.

  Args:
    coefficients: the leading half of the coefficients.
    odd_number: if True, a single middle coefficient `1 - 2 * sum(half)` is
      inserted. Otherwise two equal middle coefficients `0.5 - sum(half)` are
      inserted.

  Returns:
    A new list: `coefficients`, the middle value(s), then `coefficients`
    reversed.
  """
  head = list(coefficients)
  tail = head[::-1]
  if not odd_number:
    middle_value = 0.5 - sum(head)
    middle = [middle_value, middle_value]
  else:
    middle = [1.0 - 2.0 * sum(head)]
  return head + middle + tail
class LeapFrog(SymmetricCompositionSymplectic):
  """The standard Leap-Frog method (also known as Stormer-Verlet).

  If position_first = True:
    q_half = q_{t} + dH/dp(p_{t}) * dt / 2
    p_{t+1} = p_{t} - dH/dq(q_half) * dt
    q_{t+1} = q_half + dH/dp(p_{t+1}) * dt / 2
  else:
    p_half = p_{t} - dH/dq(q_{t}) * dt / 2
    q_{t+1} = q_{t} + dH/dp(p_half) * dt
    p_{t+1} = p_half - dH/dq(q_{t+1}) * dt / 2
  """

  def __init__(self, position_first=False):
    # The variable updated first is split into two half steps around a single
    # full step of the other variable.
    halves = [0.5, 0.5]
    whole = [1.0]
    if position_first:
      position_coefficients, momentum_coefficients = halves, whole
    else:
      position_coefficients, momentum_coefficients = whole, halves
    super().__init__(
        position_coefficients=position_coefficients,
        momentum_coefficients=momentum_coefficients,
        position_first=bool(position_first),
        order=2,
    )
class Ruth4(SymmetricCompositionSymplectic):
  """The fourth order composition method from [2].

  It starts and ends with a position update: 4 position stages and 3 momentum
  stages.
  """

  def __init__(self):
    two_cbrt = float(np.cbrt(2.0))
    # Momentum stages: [theta, 1 - 2 * theta, theta].
    theta = 1.0 / (2.0 - two_cbrt)
    momentum = symmetrize_coefficients([theta], odd_number=True)
    # Position stages: [theta/2, 0.5 - theta/2, 0.5 - theta/2, theta/2].
    half_theta = 1.0 / (4.0 - 2.0 * two_cbrt)
    position = symmetrize_coefficients([half_theta], odd_number=False)
    super().__init__(
        position_coefficients=position,
        momentum_coefficients=momentum,
        position_first=True,
        order=4,
    )
class Symmetric4(SymmetricCompositionSymplectic):
  """The fourth order method from Table 6.1 in [1] (originally from [3]).

  It starts and ends with a momentum update: 7 momentum stages and 6 position
  stages.
  """

  def __init__(self):
    # Leading halves of the coefficient lists; the rest follows by symmetry.
    momentum_half = [0.0792036964311957, 0.353172906049774,
                     -0.0420650803577195]
    position_half = [0.209515106613362, -0.143851773179818]
    super().__init__(
        position_coefficients=symmetrize_coefficients(
            position_half, odd_number=False),
        momentum_coefficients=symmetrize_coefficients(
            momentum_half, odd_number=True),
        position_first=False,
        order=4,
    )
class Symmetric6(SymmetricCompositionSymplectic):
  """The sixth order method from Table 6.1 in [1] (originally from [3]).

  It starts and ends with a momentum update: 11 momentum stages and 10
  position stages.
  """

  def __init__(self):
    # Leading halves of the coefficient lists; the rest follows by symmetry.
    momentum_half = [0.0502627644003922, 0.413514300428344,
                     0.0450798897943977, -0.188054853819569,
                     0.541960678450780]
    position_half = [0.148816447901042, -0.132385865767784,
                     0.067307604692185, 0.432666402578175]
    # NOTE(review): the docstring and coefficients describe a sixth order
    # scheme, but `order` is declared as 4. Confirm whether this is a
    # deliberately conservative value (e.g. for error-estimate tests) or
    # should be 6.
    super().__init__(
        position_coefficients=symmetrize_coefficients(
            position_half, odd_number=False),
        momentum_coefficients=symmetrize_coefficients(
            momentum_half, odd_number=True),
        position_first=False,
        order=4,
    )
def coefficients_based_on_composing_second_order(
    weights: Sequence[float]
) -> Tuple[Sequence[float], Sequence[float]]:
  """Constructs the coefficients for methods based on second-order schemes.

  Composing symmetric second order (leapfrog-like) steps with weights `w_i`
  merges adjacent half steps of the outer variable.

  Args:
    weights: the composition weights `w_0, ..., w_{n-1}`; must be non-empty.

  Returns:
    A pair `(half_steps, full_steps)`:
    - `half_steps` is `[w_0 / 2, (w_0 + w_1) / 2, ..., (w_{n-2} + w_{n-1}) / 2,
      w_{n-1} / 2]`, with n + 1 entries;
    - `full_steps` is a copy of the weights, with n entries.
  """
  full_steps = list(weights)
  merged = [(current + following) / 2.0
            for current, following in zip(full_steps[:-1], full_steps[1:])]
  half_steps = [weights[0] / 2.0] + merged + [weights[-1] / 2.0]
  return half_steps, full_steps
class SymmetricSo4(SymmetricCompositionSymplectic):
  """The fourth order method from Table 6.2 in [1] (originally from [4]).

  It is a composition of 5 second order steps. The variable updated first gets
  the merged half-step coefficients.
  """

  def __init__(self, position_first: bool = False):
    # Two free weights; symmetrization completes them to 5 weights.
    weights = symmetrize_coefficients(
        [0.28, 0.62546642846767004501], odd_number=True)
    half_steps, full_steps = coefficients_based_on_composing_second_order(
        weights)
    if position_first:
      position, momentum = half_steps, full_steps
    else:
      position, momentum = full_steps, half_steps
    super().__init__(
        position_coefficients=position,
        momentum_coefficients=momentum,
        position_first=position_first,
        order=4,
    )
class SymmetricSo6(SymmetricCompositionSymplectic):
  """The sixth order method from Table 6.2 in [1] (originally from [5]).

  It is a composition of 7 second order steps. The variable updated first gets
  the merged half-step coefficients.
  """

  def __init__(self, position_first: bool = False):
    # Three free weights; symmetrization completes them to 7 weights.
    free_weights = [0.78451361047755726382, 0.23557321335935813368,
                    -1.17767998417887100695]
    weights = symmetrize_coefficients(free_weights, odd_number=True)
    half_steps, full_steps = coefficients_based_on_composing_second_order(
        weights)
    if position_first:
      position, momentum = half_steps, full_steps
    else:
      position, momentum = full_steps, half_steps
    super().__init__(
        position_coefficients=position,
        momentum_coefficients=momentum,
        position_first=position_first,
        order=6,
    )
class SymmetricSo8(SymmetricCompositionSymplectic):
  """The eighth order method from Table 6.2 in [1] (originally from [4]).

  It is a composition of 15 second order steps. The variable updated first
  gets the merged half-step coefficients.
  """

  def __init__(self, position_first: bool = False):
    # Seven free weights; symmetrization completes them to 15 weights.
    free_weights = [0.74167036435061295345, -0.40910082580003159400,
                    0.19075471029623837995, -0.57386247111608226666,
                    0.29906418130365592384, 0.33462491824529818378,
                    0.31529309239676659663]
    weights = symmetrize_coefficients(free_weights, odd_number=True)
    half_steps, full_steps = coefficients_based_on_composing_second_order(
        weights)
    if position_first:
      position, momentum = half_steps, full_steps
    else:
      position, momentum = full_steps, half_steps
    super().__init__(
        position_coefficients=position,
        momentum_coefficients=momentum,
        position_first=position_first,
        order=8,
    )
general_integrators = dict(
general_euler=GeneralEuler(),
rk2=RungaKutta2(),
rk4=RungaKutta4(),
rk38=RungaKutta38()