Skip to content

Commit

Permalink
optimal-linesearch: refactor BacktrackingLineSearchWith::step
Browse files Browse the repository at this point in the history
  • Loading branch information
justinlovinger committed Mar 18, 2024
1 parent 3633870 commit bef8dc9
Showing 1 changed file with 67 additions and 82 deletions.
149 changes: 67 additions & 82 deletions optimal-linesearch/src/backtracking_line_search/high_level.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use std::{
iter::Sum,
ops::{RangeInclusive, Sub},
};
use std::{iter::Sum, ops::RangeInclusive};

use derive_builder::Builder;
use derive_getters::{Dissolve, Getters};
Expand Down Expand Up @@ -282,7 +279,12 @@ pub struct BacktrackingLineSearchWith<A, F, FFD> {
pub initial_point: Vec<A>,
}

impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD> {
impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD>
where
A: Sum + Signed + Float + ScalarOperand + LinalgScalar,
F: Fn(&[A]) -> A,
FFD: Fn(&[A]) -> (A, Vec<A>),
{
/// Return a point that attempts to minimize the given objective function.
pub fn argmin(self) -> Vec<A>
where
Expand All @@ -298,18 +300,17 @@ impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD> {
StepDirection::Bfgs { .. } => StepDirState::Bfgs(BfgsIteration::First),
};
let mut step_size = self.problem.agnostic.initial_step_size;
let mut point = Array1::from_vec(self.initial_point.clone());
let mut point = self.initial_point.clone();
match self.problem.agnostic.stopping_criteria {
BacktrackingLineSearchStoppingCriteria::Iteration(i) => {
for _ in 0..i {
let (value, derivatives) =
(self.problem.obj_func_and_d)(point.as_slice().unwrap());
let (value, derivatives) = (self.problem.obj_func_and_d)(&point);
(step_dir_state, step_size, point) =
self.step(step_dir_state, step_size, point, value, derivatives);
}
}
BacktrackingLineSearchStoppingCriteria::NearMinima => loop {
let (value, derivatives) = (self.problem.obj_func_and_d)(point.as_slice().unwrap());
let (value, derivatives) = (self.problem.obj_func_and_d)(&point);
if is_near_minima(value, derivatives.iter().copied()) {
break;
}
Expand All @@ -324,24 +325,20 @@ impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD> {
&self,
step_dir_state: StepDirState<A>,
step_size: StepSize<A>,
point: Array1<A>,
point: Vec<A>,
value: A,
derivatives: Vec<A>,
) -> (StepDirState<A>, StepSize<A>, Array1<A>)
where
A: Sum + Float + ScalarOperand + LinalgScalar,
F: Fn(&[A]) -> A,
FFD: Fn(&[A]) -> (A, Vec<A>),
{
let derivatives = Array1::from_vec(derivatives);

let (direction, step_dir_partial_state) = match step_dir_state {
StepDirState::Steepest => (
steepest_descent(derivatives.iter().cloned()).collect(),
StepDirPartialState::Steepest,
),
) -> (StepDirState<A>, StepSize<A>, Vec<A>) {
let (step_dir_state, step_size, new_point) = match step_dir_state {
StepDirState::Steepest => {
let direction = steepest_descent(derivatives.iter().cloned()).collect();
let (step_size, point) =
self.search(step_size, point, value, derivatives, direction);
(StepDirState::Steepest, step_size, point)
}
StepDirState::Bfgs(bfgs_state) => {
let (direction, partial_state) = match bfgs_state {
let derivatives = Array1::from_vec(derivatives);
let (direction, approx_inv_second_derivatives) = match bfgs_state {
BfgsIteration::First => {
let approx_inv_snd_derivatives =
initial_approx_inv_snd_derivatives_identity(point.len());
Expand All @@ -351,10 +348,8 @@ impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD> {
direction,
match &self.problem.agnostic.direction {
StepDirection::Bfgs { initializer } => match initializer {
BfgsInitializer::Identity => BfgsPartialIteration::Other {
approx_inv_snd_derivatives,
},
BfgsInitializer::Gamma => BfgsPartialIteration::Second,
BfgsInitializer::Identity => Some(approx_inv_snd_derivatives),
BfgsInitializer::Gamma => None,
},
_ => panic!(),
},
Expand All @@ -371,12 +366,7 @@ impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD> {
);
let direction =
bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view());
(
direction,
BfgsPartialIteration::Other {
approx_inv_snd_derivatives,
},
)
(direction, Some(approx_inv_snd_derivatives))
}
BfgsIteration::Other {
prev_derivatives,
Expand All @@ -391,57 +381,42 @@ impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD> {
);
let direction =
bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view());
(
direction,
BfgsPartialIteration::Other {
approx_inv_snd_derivatives,
},
)
(direction, Some(approx_inv_snd_derivatives))
}
};
(direction.to_vec(), StepDirPartialState::Bfgs(partial_state))
}
};

// The compiler should remove these clones
// if they are not necessary
// for the type of step-direction used.
let (step_size, new_point) = BacktrackingSearcher::new(
self.problem.agnostic.c_1,
point.clone().to_vec(),
value,
derivatives.clone(),
direction,
)
.search(
self.problem.agnostic.backtracking_rate,
&self.problem.obj_func,
step_size,
);
let new_point = Array1::from_vec(new_point);
let (step_size, new_point) = self.search(
step_size,
point.clone(),
value,
derivatives.iter().cloned(),
direction.to_vec(),
);

let step_dir_state = match step_dir_partial_state {
StepDirPartialState::Steepest => StepDirState::Steepest,
StepDirPartialState::Bfgs(partial_state) => {
let prev_derivatives = derivatives;
// We could theoretically get `prev_step` directly from line-search,
// but then we would need to calculate each point of line-search
// less efficiently,
// calculating step and point separately.
let prev_step = new_point.view().sub(point);
StepDirState::Bfgs(match partial_state {
BfgsPartialIteration::Second => BfgsIteration::Second {
let prev_step = new_point
.iter()
.cloned()
.zip(point)
.map(|(x, y)| x - y)
.collect();
let step_dir_state = match approx_inv_second_derivatives {
None => BfgsIteration::Second {
prev_derivatives,
prev_step,
},
BfgsPartialIteration::Other {
approx_inv_snd_derivatives,
} => BfgsIteration::Other {
Some(prev_approx_inv_snd_derivatives) => BfgsIteration::Other {
prev_derivatives,
prev_approx_inv_snd_derivatives: approx_inv_snd_derivatives,
prev_approx_inv_snd_derivatives,
prev_step,
},
})
};

(StepDirState::Bfgs(step_dir_state), step_size, new_point)
}
};

Expand All @@ -451,6 +426,28 @@ impl<A, F, FFD> BacktrackingLineSearchWith<A, F, FFD> {

(step_dir_state, step_size, new_point)
}

/// Run one backtracking search using this problem's configuration.
///
/// Builds a `BacktrackingSearcher` from the problem's `c_1`
/// and the given `point`, `value`, `derivatives`, and `direction`,
/// then calls its `search` with the problem's `backtracking_rate`,
/// a reference to the problem's `obj_func`,
/// and the given `step_size`.
///
/// Returns whatever the searcher returns:
/// a `(StepSize<A>, Vec<A>)` pair.
fn search(
    &self,
    step_size: StepSize<A>,
    point: Vec<A>,
    value: A,
    derivatives: impl IntoIterator<Item = A>,
    direction: Vec<A>,
) -> (StepSize<A>, Vec<A>) {
    // Name the shared configuration once
    // instead of repeating the full field path.
    let agnostic = &self.problem.agnostic;
    let objective = &self.problem.obj_func;

    BacktrackingSearcher::new(agnostic.c_1, point, value, derivatives, direction)
        .search(agnostic.backtracking_rate, objective, step_size)
}
}

enum StepDirState<A> {
Expand All @@ -470,15 +467,3 @@ enum BfgsIteration<A> {
prev_step: Array1<A>,
},
}

/// Intermediate step-direction state
/// produced alongside the step direction
/// and later matched to build the next `StepDirState`.
enum StepDirPartialState<A> {
/// Maps to `StepDirState::Steepest`; carries no data.
Steepest,
/// Maps to `StepDirState::Bfgs`,
/// with the wrapped partial iteration deciding which `BfgsIteration` is built.
Bfgs(BfgsPartialIteration<A>),
}

/// The part of a BFGS iteration known before `prev_derivatives` and `prev_step`
/// are filled in to form a full `BfgsIteration`.
enum BfgsPartialIteration<A> {
/// No approximate inverse second derivatives are carried forward;
/// becomes `BfgsIteration::Second`.
Second,
/// Carries the approximate inverse second derivatives used for this step;
/// becomes `BfgsIteration::Other`,
/// where the matrix is stored as `prev_approx_inv_snd_derivatives`.
Other {
approx_inv_snd_derivatives: Array2<A>,
},
}

0 comments on commit bef8dc9

Please sign in to comment.