diff --git a/optimal-linesearch/src/backtracking_line_search/high_level.rs b/optimal-linesearch/src/backtracking_line_search/high_level.rs
index ea615de..bd75faa 100644
--- a/optimal-linesearch/src/backtracking_line_search/high_level.rs
+++ b/optimal-linesearch/src/backtracking_line_search/high_level.rs
@@ -1,7 +1,4 @@
-use std::{
- iter::Sum,
- ops::{RangeInclusive, Sub},
-};
+use std::{iter::Sum, ops::RangeInclusive};
use derive_builder::Builder;
use derive_getters::{Dissolve, Getters};
@@ -282,7 +279,12 @@ pub struct BacktrackingLineSearchWith {
pub initial_point: Vec,
}
-impl BacktrackingLineSearchWith {
+impl BacktrackingLineSearchWith
+where
+ A: Sum + Signed + Float + ScalarOperand + LinalgScalar,
+ F: Fn(&[A]) -> A,
+ FFD: Fn(&[A]) -> (A, Vec),
+{
/// Return a point that attempts to minimize the given objective function.
pub fn argmin(self) -> Vec
where
@@ -298,18 +300,17 @@ impl BacktrackingLineSearchWith {
StepDirection::Bfgs { .. } => StepDirState::Bfgs(BfgsIteration::First),
};
let mut step_size = self.problem.agnostic.initial_step_size;
- let mut point = Array1::from_vec(self.initial_point.clone());
+ let mut point = self.initial_point.clone();
match self.problem.agnostic.stopping_criteria {
BacktrackingLineSearchStoppingCriteria::Iteration(i) => {
for _ in 0..i {
- let (value, derivatives) =
- (self.problem.obj_func_and_d)(point.as_slice().unwrap());
+ let (value, derivatives) = (self.problem.obj_func_and_d)(&point);
(step_dir_state, step_size, point) =
self.step(step_dir_state, step_size, point, value, derivatives);
}
}
BacktrackingLineSearchStoppingCriteria::NearMinima => loop {
- let (value, derivatives) = (self.problem.obj_func_and_d)(point.as_slice().unwrap());
+ let (value, derivatives) = (self.problem.obj_func_and_d)(&point);
if is_near_minima(value, derivatives.iter().copied()) {
break;
}
@@ -324,24 +325,20 @@ impl BacktrackingLineSearchWith {
&self,
step_dir_state: StepDirState,
step_size: StepSize,
- point: Array1,
+ point: Vec,
value: A,
derivatives: Vec,
- ) -> (StepDirState, StepSize, Array1)
- where
- A: Sum + Float + ScalarOperand + LinalgScalar,
- F: Fn(&[A]) -> A,
- FFD: Fn(&[A]) -> (A, Vec),
- {
- let derivatives = Array1::from_vec(derivatives);
-
- let (direction, step_dir_partial_state) = match step_dir_state {
- StepDirState::Steepest => (
- steepest_descent(derivatives.iter().cloned()).collect(),
- StepDirPartialState::Steepest,
- ),
+ ) -> (StepDirState, StepSize, Vec) {
+ let (step_dir_state, step_size, new_point) = match step_dir_state {
+ StepDirState::Steepest => {
+ let direction = steepest_descent(derivatives.iter().cloned()).collect();
+ let (step_size, point) =
+ self.search(step_size, point, value, derivatives, direction);
+ (StepDirState::Steepest, step_size, point)
+ }
StepDirState::Bfgs(bfgs_state) => {
- let (direction, partial_state) = match bfgs_state {
+ let derivatives = Array1::from_vec(derivatives);
+ let (direction, approx_inv_second_derivatives) = match bfgs_state {
BfgsIteration::First => {
let approx_inv_snd_derivatives =
initial_approx_inv_snd_derivatives_identity(point.len());
@@ -351,10 +348,8 @@ impl BacktrackingLineSearchWith {
direction,
match &self.problem.agnostic.direction {
StepDirection::Bfgs { initializer } => match initializer {
- BfgsInitializer::Identity => BfgsPartialIteration::Other {
- approx_inv_snd_derivatives,
- },
- BfgsInitializer::Gamma => BfgsPartialIteration::Second,
+ BfgsInitializer::Identity => Some(approx_inv_snd_derivatives),
+ BfgsInitializer::Gamma => None,
},
_ => panic!(),
},
@@ -371,12 +366,7 @@ impl BacktrackingLineSearchWith {
);
let direction =
bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view());
- (
- direction,
- BfgsPartialIteration::Other {
- approx_inv_snd_derivatives,
- },
- )
+ (direction, Some(approx_inv_snd_derivatives))
}
BfgsIteration::Other {
prev_derivatives,
@@ -391,57 +381,42 @@ impl BacktrackingLineSearchWith {
);
let direction =
bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view());
- (
- direction,
- BfgsPartialIteration::Other {
- approx_inv_snd_derivatives,
- },
- )
+ (direction, Some(approx_inv_snd_derivatives))
}
};
- (direction.to_vec(), StepDirPartialState::Bfgs(partial_state))
- }
- };
- // The compiler should remove these clones
- // if they are not necessary
- // for the type of step-direction used.
- let (step_size, new_point) = BacktrackingSearcher::new(
- self.problem.agnostic.c_1,
- point.clone().to_vec(),
- value,
- derivatives.clone(),
- direction,
- )
- .search(
- self.problem.agnostic.backtracking_rate,
- &self.problem.obj_func,
- step_size,
- );
- let new_point = Array1::from_vec(new_point);
+ let (step_size, new_point) = self.search(
+ step_size,
+ point.clone(),
+ value,
+ derivatives.iter().cloned(),
+ direction.to_vec(),
+ );
- let step_dir_state = match step_dir_partial_state {
- StepDirPartialState::Steepest => StepDirState::Steepest,
- StepDirPartialState::Bfgs(partial_state) => {
let prev_derivatives = derivatives;
// We could theoretically get `prev_step` directly from line-search,
// but then we would need to calculate each point of line-search
// less efficiently,
// calculating step and point separately.
- let prev_step = new_point.view().sub(point);
- StepDirState::Bfgs(match partial_state {
- BfgsPartialIteration::Second => BfgsIteration::Second {
+ let prev_step = new_point
+ .iter()
+ .cloned()
+ .zip(point)
+ .map(|(x, y)| x - y)
+ .collect();
+ let step_dir_state = match approx_inv_second_derivatives {
+ None => BfgsIteration::Second {
prev_derivatives,
prev_step,
},
- BfgsPartialIteration::Other {
- approx_inv_snd_derivatives,
- } => BfgsIteration::Other {
+ Some(prev_approx_inv_snd_derivatives) => BfgsIteration::Other {
prev_derivatives,
- prev_approx_inv_snd_derivatives: approx_inv_snd_derivatives,
+ prev_approx_inv_snd_derivatives,
prev_step,
},
- })
+ };
+
+ (StepDirState::Bfgs(step_dir_state), step_size, new_point)
}
};
@@ -451,6 +426,28 @@ impl BacktrackingLineSearchWith {
(step_dir_state, step_size, new_point)
}
+
+ fn search(
+ &self,
+ step_size: StepSize,
+ point: Vec,
+ value: A,
+ derivatives: impl IntoIterator- ,
+ direction: Vec,
+ ) -> (StepSize, Vec) {
+ BacktrackingSearcher::new(
+ self.problem.agnostic.c_1,
+ point,
+ value,
+ derivatives,
+ direction,
+ )
+ .search(
+ self.problem.agnostic.backtracking_rate,
+ &self.problem.obj_func,
+ step_size,
+ )
+ }
}
enum StepDirState {
@@ -470,15 +467,3 @@ enum BfgsIteration {
prev_step: Array1,
},
}
-
-enum StepDirPartialState {
- Steepest,
- Bfgs(BfgsPartialIteration),
-}
-
-enum BfgsPartialIteration {
- Second,
- Other {
- approx_inv_snd_derivatives: Array2,
- },
-}