From bef8dc998b94d9c3dc8c1ecf665ae4c1672f0360 Mon Sep 17 00:00:00 2001 From: Justin Lovinger Date: Mon, 18 Mar 2024 13:48:30 +0000 Subject: [PATCH] optimal-linesearch: refactor `BacktrackingLineSearchWith::step` --- .../backtracking_line_search/high_level.rs | 149 ++++++++---------- 1 file changed, 67 insertions(+), 82 deletions(-) diff --git a/optimal-linesearch/src/backtracking_line_search/high_level.rs b/optimal-linesearch/src/backtracking_line_search/high_level.rs index ea615de..bd75faa 100644 --- a/optimal-linesearch/src/backtracking_line_search/high_level.rs +++ b/optimal-linesearch/src/backtracking_line_search/high_level.rs @@ -1,7 +1,4 @@ -use std::{ - iter::Sum, - ops::{RangeInclusive, Sub}, -}; +use std::{iter::Sum, ops::RangeInclusive}; use derive_builder::Builder; use derive_getters::{Dissolve, Getters}; @@ -282,7 +279,12 @@ pub struct BacktrackingLineSearchWith { pub initial_point: Vec, } -impl BacktrackingLineSearchWith { +impl BacktrackingLineSearchWith +where + A: Sum + Signed + Float + ScalarOperand + LinalgScalar, + F: Fn(&[A]) -> A, + FFD: Fn(&[A]) -> (A, Vec), +{ /// Return a point that attempts to minimize the given objective function. pub fn argmin(self) -> Vec where @@ -298,18 +300,17 @@ impl BacktrackingLineSearchWith { StepDirection::Bfgs { .. } => StepDirState::Bfgs(BfgsIteration::First), }; let mut step_size = self.problem.agnostic.initial_step_size; - let mut point = Array1::from_vec(self.initial_point.clone()); + let mut point = self.initial_point.clone(); match self.problem.agnostic.stopping_criteria { BacktrackingLineSearchStoppingCriteria::Iteration(i) => { for _ in 0..i { - let (value, derivatives) = - (self.problem.obj_func_and_d)(point.as_slice().unwrap()); + let (value, derivatives) = (self.problem.obj_func_and_d)(&point); (step_dir_state, step_size, point) = self.step(step_dir_state, step_size, point, value, derivatives); } } BacktrackingLineSearchStoppingCriteria::NearMinima => loop { - let (value, derivatives) = (self.problem.obj_func_and_d)(point.as_slice().unwrap()); + let (value, derivatives) = (self.problem.obj_func_and_d)(&point); if is_near_minima(value, derivatives.iter().copied()) { break; } @@ -324,24 +325,20 @@ impl BacktrackingLineSearchWith { &self, step_dir_state: StepDirState, step_size: StepSize, - point: Array1, + point: Vec, value: A, derivatives: Vec, - ) -> (StepDirState, StepSize, Array1) - where - A: Sum + Float + ScalarOperand + LinalgScalar, - F: Fn(&[A]) -> A, - FFD: Fn(&[A]) -> (A, Vec), - { - let derivatives = Array1::from_vec(derivatives); - - let (direction, step_dir_partial_state) = match step_dir_state { - StepDirState::Steepest => ( - steepest_descent(derivatives.iter().cloned()).collect(), - StepDirPartialState::Steepest, - ), + ) -> (StepDirState, StepSize, Vec) { + let (step_dir_state, step_size, new_point) = match step_dir_state { + StepDirState::Steepest => { + let direction = steepest_descent(derivatives.iter().cloned()).collect(); + let (step_size, point) = + self.search(step_size, point, value, derivatives, direction); + (StepDirState::Steepest, step_size, point) + } StepDirState::Bfgs(bfgs_state) => { - let (direction, partial_state) = match bfgs_state { + let derivatives = Array1::from_vec(derivatives); + let (direction, approx_inv_second_derivatives) = match bfgs_state { BfgsIteration::First => { let approx_inv_snd_derivatives = initial_approx_inv_snd_derivatives_identity(point.len()); @@ -351,10 +348,8 @@ impl BacktrackingLineSearchWith { direction, match &self.problem.agnostic.direction { StepDirection::Bfgs { initializer } => match initializer { - BfgsInitializer::Identity => BfgsPartialIteration::Other { - approx_inv_snd_derivatives, - }, - BfgsInitializer::Gamma => BfgsPartialIteration::Second, + BfgsInitializer::Identity => Some(approx_inv_snd_derivatives), + BfgsInitializer::Gamma => None, }, _ => panic!(), }, @@ -371,12 +366,7 @@ impl BacktrackingLineSearchWith { ); let direction = bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view()); - ( - direction, - BfgsPartialIteration::Other { - approx_inv_snd_derivatives, - }, - ) + (direction, Some(approx_inv_snd_derivatives)) } BfgsIteration::Other { prev_derivatives, @@ -391,57 +381,42 @@ impl BacktrackingLineSearchWith { ); let direction = bfgs_direction(approx_inv_snd_derivatives.view(), derivatives.view()); - ( - direction, - BfgsPartialIteration::Other { - approx_inv_snd_derivatives, - }, - ) + (direction, Some(approx_inv_snd_derivatives)) } }; - (direction.to_vec(), StepDirPartialState::Bfgs(partial_state)) - } - }; - // The compiler should remove these clones - // if they are not necessary - // for the type of step-direction used. - let (step_size, new_point) = BacktrackingSearcher::new( - self.problem.agnostic.c_1, - point.clone().to_vec(), - value, - derivatives.clone(), - direction, - ) - .search( - self.problem.agnostic.backtracking_rate, - &self.problem.obj_func, - step_size, - ); - let new_point = Array1::from_vec(new_point); + let (step_size, new_point) = self.search( + step_size, + point.clone(), + value, + derivatives.iter().cloned(), + direction.to_vec(), + ); - let step_dir_state = match step_dir_partial_state { - StepDirPartialState::Steepest => StepDirState::Steepest, - StepDirPartialState::Bfgs(partial_state) => { let prev_derivatives = derivatives; // We could theoretically get `prev_step` directly from line-search, // but then we would need to calculate each point of line-search // less efficiently, // calculating step and point separately. - let prev_step = new_point.view().sub(point); - StepDirState::Bfgs(match partial_state { - BfgsPartialIteration::Second => BfgsIteration::Second { + let prev_step = new_point + .iter() + .cloned() + .zip(point) + .map(|(x, y)| x - y) + .collect(); + let step_dir_state = match approx_inv_second_derivatives { + None => BfgsIteration::Second { prev_derivatives, prev_step, }, - BfgsPartialIteration::Other { - approx_inv_snd_derivatives, - } => BfgsIteration::Other { + Some(prev_approx_inv_snd_derivatives) => BfgsIteration::Other { prev_derivatives, - prev_approx_inv_snd_derivatives: approx_inv_snd_derivatives, + prev_approx_inv_snd_derivatives, prev_step, }, - }) + }; + + (StepDirState::Bfgs(step_dir_state), step_size, new_point) } }; @@ -451,6 +426,28 @@ impl BacktrackingLineSearchWith { (step_dir_state, step_size, new_point) } + + fn search( + &self, + step_size: StepSize, + point: Vec, + value: A, + derivatives: impl IntoIterator, + direction: Vec, + ) -> (StepSize, Vec) { + BacktrackingSearcher::new( + self.problem.agnostic.c_1, + point, + value, + derivatives, + direction, + ) + .search( + self.problem.agnostic.backtracking_rate, + &self.problem.obj_func, + step_size, + ) + } } enum StepDirState { @@ -470,15 +467,3 @@ enum BfgsIteration { prev_step: Array1, }, } - -enum StepDirPartialState { - Steepest, - Bfgs(BfgsPartialIteration), -} - -enum BfgsPartialIteration { - Second, - Other { - approx_inv_snd_derivatives: Array2, - }, -}