Skip to content

Commit

Permalink
Add optional timeout to Executor
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Jan 21, 2024
1 parent b74620c commit 0445393
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 2 deletions.
75 changes: 73 additions & 2 deletions argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub struct Executor<O, S, I> {
observers: Observers<I>,
/// Checkpoint
checkpoint: Option<Box<dyn Checkpoint<S, I>>>,
/// Timeout
timeout: Option<std::time::Duration>,
/// Indicates whether Ctrl-C functionality should be active or not
ctrlc: bool,
/// Indicates whether to time execution or not
Expand Down Expand Up @@ -66,6 +68,7 @@ where
state,
observers: Observers::new(),
checkpoint: None,
timeout: None,
ctrlc: true,
timer: true,
}
Expand Down Expand Up @@ -250,10 +253,18 @@ where
}

if self.timer {
// Increase accumulated total_time
total_time.map(|total_time| state.time(Some(total_time.elapsed())));

// If a timeout is set, check if timeout is reached
if let (Some(timeout), Some(total_time)) = (self.timeout, total_time) {
if total_time.elapsed() > timeout {
state = state.terminate_with(TerminationReason::Timeout);
}
}
}

// Check if termination occurred inside next_iter()
// Check if termination occurred in the meantime
if state.terminated() {
break;
}
Expand Down Expand Up @@ -374,6 +385,8 @@ where

/// Enables or disables timing of individual iterations (default: enabled).
///
/// Setting this to false will silently be ignored in case a timeout is set.
///
/// # Example
///
/// ```
Expand All @@ -391,7 +404,38 @@ where
/// ```
#[must_use]
pub fn timer(mut self, timer: bool) -> Self {
self.timer = timer;
if self.timeout.is_none() {
self.timer = timer;
}
self
}

/// Sets a timeout for the run.
///
/// The optimization run is stopped once the timeout is exceeded. Note that the check is
/// performed after each iteration, therefore the actual runtime can exceed the the set
/// duration.
/// This also enables time measurements.
///
/// # Example
///
/// ```
/// # use argmin::core::{Error, Executor};
/// # use argmin::core::test_utils::{TestSolver, TestProblem};
/// #
/// # fn main() -> Result<(), Error> {
/// # let solver = TestSolver::new();
/// # let problem = TestProblem::new();
/// #
/// // Create instance of `Executor` with `problem` and `solver`
/// let executor = Executor::new(problem, solver).timeout(std::time::Duration::from_secs(30));
/// # Ok(())
/// # }
/// ```
#[must_use]
pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
self.timer = true;
self.timeout = Some(timeout);
self
}
}
Expand Down Expand Up @@ -715,4 +759,31 @@ mod tests {
// Delete old checkpointing file
let _ = std::fs::remove_file(".checkpoints/init_test.arg");
}

#[test]
fn test_timeout() {
let solver = TestSolver::new();
let problem = TestProblem::new();
let timeout = std::time::Duration::from_secs(2);

let executor = Executor::new(problem, solver);
assert!(executor.timer);
assert!(executor.timeout.is_none());

let executor = Executor::new(problem, solver).timer(false);
assert!(!executor.timer);
assert!(executor.timeout.is_none());

let executor = Executor::new(problem, solver).timeout(timeout);
assert!(executor.timer);
assert_eq!(executor.timeout, Some(timeout));

let executor = Executor::new(problem, solver).timeout(timeout).timer(false);
assert!(executor.timer);
assert_eq!(executor.timeout, Some(timeout));

let executor = Executor::new(problem, solver).timer(false).timeout(timeout);
assert!(executor.timer);
assert_eq!(executor.timeout, Some(timeout));
}
}
8 changes: 8 additions & 0 deletions argmin/src/core/termination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl TerminationStatus {
/// assert!(TerminationStatus::Terminated(TerminationReason::TargetCostReached).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::SolverConverged).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::KeyboardInterrupt).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::Timeout).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::SolverExit("Exit reason".to_string())).terminated());
/// ```
pub fn terminated(&self) -> bool {
Expand Down Expand Up @@ -59,6 +60,8 @@ pub enum TerminationReason {
KeyboardInterrupt,
/// Converged
SolverConverged,
/// Timeout reached
Timeout,
/// Solver exit with given reason
SolverExit(String),
}
Expand Down Expand Up @@ -88,6 +91,10 @@ impl TerminationReason {
/// "Solver converged"
/// );
/// assert_eq!(
/// TerminationReason::Timeout.text(),
/// "Timeout reached"
/// );
/// assert_eq!(
/// TerminationReason::SolverExit("Aborted".to_string()).text(),
/// "Aborted"
/// );
Expand All @@ -98,6 +105,7 @@ impl TerminationReason {
TerminationReason::TargetCostReached => "Target cost value reached",
TerminationReason::KeyboardInterrupt => "Keyboard interrupt",
TerminationReason::SolverConverged => "Solver converged",
TerminationReason::Timeout => "Timeout reached",
TerminationReason::SolverExit(reason) => reason.as_ref(),
}
}
Expand Down
14 changes: 14 additions & 0 deletions examples/timeout/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "example-timeout"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"
publish = false

[dependencies]
argmin = { version = "*", path = "../../argmin" }
argmin-math = { version = "*", features = ["vec"], path = "../../argmin-math" }
argmin-observer-slog = { version = "*", path = "../../observers/slog/" }
argmin_testfunctions = "*"
rand = "0.8.5"
rand_xoshiro = "0.6.0"
98 changes: 98 additions & 0 deletions examples/timeout/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// This example shows how to add a timeout to an optimization run. The optimization will be
// terminated once the 3 seconds timeout is reached.

use argmin::{
core::{observers::ObserverMode, CostFunction, Error, Executor},
solver::simulatedannealing::{Anneal, SATempFunc, SimulatedAnnealing},
};
use argmin_observer_slog::SlogLogger;
use argmin_testfunctions::rosenbrock;
use rand::{distributions::Uniform, prelude::*};
use rand_xoshiro::Xoshiro256PlusPlus;
use std::sync::{Arc, Mutex};

struct Rosenbrock {
a: f64,
b: f64,
lower_bound: Vec<f64>,
upper_bound: Vec<f64>,
rng: Arc<Mutex<Xoshiro256PlusPlus>>,
}

impl Rosenbrock {
pub fn new(a: f64, b: f64, lower_bound: Vec<f64>, upper_bound: Vec<f64>) -> Self {
Rosenbrock {
a,
b,
lower_bound,
upper_bound,
rng: Arc::new(Mutex::new(Xoshiro256PlusPlus::from_entropy())),
}
}
}

impl CostFunction for Rosenbrock {
type Param = Vec<f64>;
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
Ok(rosenbrock(param, self.a, self.b))
}
}

impl Anneal for Rosenbrock {
type Param = Vec<f64>;
type Output = Vec<f64>;
type Float = f64;

fn anneal(&self, param: &Vec<f64>, temp: f64) -> Result<Vec<f64>, Error> {
let mut param_n = param.clone();
let mut rng = self.rng.lock().unwrap();
let distr = Uniform::from(0..param.len());
for _ in 0..(temp.floor() as u64 + 1) {
let idx = rng.sample(distr);
let val = rng.sample(Uniform::new_inclusive(-0.1, 0.1));
param_n[idx] += val;
param_n[idx] = param_n[idx].clamp(self.lower_bound[idx], self.upper_bound[idx]);
}
Ok(param_n)
}
}

fn run() -> Result<(), Error> {
let lower_bound: Vec<f64> = vec![-5.0, -5.0];
let upper_bound: Vec<f64> = vec![5.0, 5.0];
let operator = Rosenbrock::new(1.0, 100.0, lower_bound, upper_bound);
let init_param: Vec<f64> = vec![1.0, 1.2];
let temp = 15.0;
let solver = SimulatedAnnealing::new(temp)?.with_temp_func(SATempFunc::Boltzmann);

let res = Executor::new(operator, solver)
.configure(|state| state.param(init_param).max_iters(10_000_000))
.add_observer(SlogLogger::term(), ObserverMode::Always)
/////////////////////////////////////////////////////////////////////////////////////////
// //
// Add a timeout of 3 seconds //
// //
/////////////////////////////////////////////////////////////////////////////////////////
.timeout(std::time::Duration::from_secs(3))
.run()?;

// Print result
println!("{res}");
Ok(())
}

fn main() {
if let Err(ref e) = run() {
println!("{e}");
std::process::exit(1);
}
}
3 changes: 3 additions & 0 deletions media/book/src/running_solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,6 @@ let function_evaluation_counts = res.state().get_func_counts();
# }
```

Optionally, `Executor` allows one to terminate a run after a given timeout, which can be set with the `timeout` method of `Executor`.
The check whether the overall runtime exceeds the timeout is performed after every iteration, therefore the actual runtime can be longer than the set timeout.
In case of timeout, the run terminates with `TerminationReason::Timeout`.

0 comments on commit 0445393

Please sign in to comment.