Skip to content

Commit

Permalink
optimal-linesearch: simplify using BFGS
Browse files Browse the repository at this point in the history
  • Loading branch information
justinlovinger committed Mar 18, 2024
1 parent bef8dc9 commit 0e9c60e
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 230 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion optimal-linesearch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ name = "benchmark"
harness = false

[features]
serde = ["dep:serde"]
serde = ["dep:serde", "ndarray/serde"]

[dependencies]
derive_builder = "0.13.0"
Expand Down
110 changes: 23 additions & 87 deletions optimal-linesearch/src/backtracking_line_search/high_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{iter::Sum, ops::RangeInclusive};

use derive_builder::Builder;
use derive_getters::{Dissolve, Getters};
use ndarray::{Array1, Array2, LinalgScalar, ScalarOperand};
use ndarray::{Array1, LinalgScalar, ScalarOperand};
use num_traits::{AsPrimitive, Float, Signed};
use rand::{
distributions::{uniform::SampleUniform, Uniform},
Expand All @@ -13,10 +13,7 @@ use crate::{
initial_step_size::IncrRate,
is_near_minima,
step_direction::{
bfgs::{
approx_inv_snd_derivatives, bfgs_direction, initial_approx_inv_snd_derivatives_gamma,
initial_approx_inv_snd_derivatives_identity,
},
bfgs::{BfgsIteration, BfgsIterationGamma, BfgsIterationIdentity},
steepest_descent,
},
StepSize,
Expand Down Expand Up @@ -73,9 +70,9 @@ pub enum StepDirection {
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum BfgsInitializer {
/// Initialize using [`initial_approx_inv_snd_derivatives_identity`].
/// Initialize using [`crate::step_direction::bfgs::initial_approx_inv_snd_derivatives_identity`].
Identity,
/// Initialize using [`initial_approx_inv_snd_derivatives_gamma`].
/// Initialize using [`crate::step_direction::bfgs::initial_approx_inv_snd_derivatives_gamma`].
#[default]
Gamma,
}
Expand Down Expand Up @@ -297,7 +294,14 @@ where
// We use them so we can call `self.step`.
let mut step_dir_state = match &self.problem.agnostic.direction {
StepDirection::Steepest => StepDirState::Steepest,
StepDirection::Bfgs { .. } => StepDirState::Bfgs(BfgsIteration::First),
StepDirection::Bfgs { initializer } => match initializer {
BfgsInitializer::Identity => {
StepDirState::Bfgs(BfgsIteration::Identity(BfgsIterationIdentity::default()))
}
BfgsInitializer::Gamma => {
StepDirState::Bfgs(BfgsIteration::Gamma(BfgsIterationGamma::default()))
}
},
};
let mut step_size = self.problem.agnostic.initial_step_size;
let mut point = self.initial_point.clone();
Expand Down Expand Up @@ -338,85 +342,30 @@ where
}
StepDirState::Bfgs(bfgs_state) => {
let derivatives = Array1::from_vec(derivatives);
let (direction, approx_inv_second_derivatives) = match bfgs_state {
BfgsIteration::First => {
let approx_inv_snd_derivatives =
initial_approx_inv_snd_derivatives_identity(point.len());
let direction =
bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view());
(
direction,
match &self.problem.agnostic.direction {
StepDirection::Bfgs { initializer } => match initializer {
BfgsInitializer::Identity => Some(approx_inv_snd_derivatives),
BfgsInitializer::Gamma => None,
},
_ => panic!(),
},
)
}
BfgsIteration::Second {
prev_derivatives,
prev_step,
} => {
let approx_inv_snd_derivatives = initial_approx_inv_snd_derivatives_gamma(
prev_derivatives,
prev_step,
derivatives.view(),
);
let direction =
bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view());
(direction, Some(approx_inv_snd_derivatives))
}
BfgsIteration::Other {
prev_derivatives,
prev_approx_inv_snd_derivatives,
prev_step,
} => {
let approx_inv_snd_derivatives = approx_inv_snd_derivatives(
prev_approx_inv_snd_derivatives,
prev_step,
prev_derivatives,
derivatives.view(),
);
let direction =
bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view());
(direction, Some(approx_inv_snd_derivatives))
}
};
let (direction, before_iteration) = bfgs_state.direction(derivatives);

let (step_size, new_point) = self.search(
step_size,
point.clone(),
value,
derivatives.iter().cloned(),
before_iteration.derivatives().iter().cloned(),
direction.to_vec(),
);

let prev_derivatives = derivatives;
// We could get `prev_step` directly from line-search,
// but line-search would then have to compute the step and the point separately
// for every candidate it evaluates,
// which is less efficient.
let prev_step = new_point
.iter()
.cloned()
.zip(point)
.map(|(x, y)| x - y)
.collect();
let step_dir_state = match approx_inv_second_derivatives {
None => BfgsIteration::Second {
prev_derivatives,
prev_step,
},
Some(prev_approx_inv_snd_derivatives) => BfgsIteration::Other {
prev_derivatives,
prev_approx_inv_snd_derivatives,
prev_step,
},
};
let bfgs_state = before_iteration.next(
new_point
.iter()
.cloned()
.zip(point)
.map(|(x, y)| x - y)
.collect(),
);

(StepDirState::Bfgs(step_dir_state), step_size, new_point)
(StepDirState::Bfgs(bfgs_state), step_size, new_point)
}
};

Expand Down Expand Up @@ -454,16 +403,3 @@ enum StepDirState<A> {
Steepest,
Bfgs(BfgsIteration<A>),
}

/// State carried from one BFGS step to the next,
/// recording how much history is available to compute the next direction.
enum BfgsIteration<A> {
/// No previous step exists yet.
/// The direction is computed from the identity approximation,
/// and that approximation is kept for the next iteration
/// only when the initializer is `BfgsInitializer::Identity`.
First,
/// One step was taken but no approximation was kept
/// (the `BfgsInitializer::Gamma` path).
/// The approximation is built by `initial_approx_inv_snd_derivatives_gamma`
/// from these values and the current derivatives.
Second {
/// Derivatives evaluated at the previous point.
prev_derivatives: Array1<A>,
/// Previous step: new point minus previous point, element-wise.
prev_step: Array1<A>,
},
/// A previous approximation exists
/// and is updated by `approx_inv_snd_derivatives` using the current derivatives.
Other {
/// Derivatives evaluated at the previous point.
prev_derivatives: Array1<A>,
/// Approximate inverse second derivatives used for the previous direction.
prev_approx_inv_snd_derivatives: Array2<A>,
/// Previous step: new point minus previous point, element-wise.
prev_step: Array1<A>,
},
}
Loading

0 comments on commit 0e9c60e

Please sign in to comment.