diff --git a/flare/learners/lmpotf.py b/flare/learners/lmpotf.py index a6f7817a9..f390e6213 100644 --- a/flare/learners/lmpotf.py +++ b/flare/learners/lmpotf.py @@ -76,7 +76,7 @@ def __init__( rcut: float, type2number: Union[(int, List[int])], dftcalc: object, - energy_correction: List[float] = 0.0, + energy_correction: Union[(float, List[float])] = 0.0, force_training=True, energy_training=True, stress_training=True,