Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: integrate one step with const integrators #22663

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions systems/analysis/explicit_euler_integrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,23 @@ class ExplicitEulerIntegrator final : public IntegratorBase<T> {
int get_error_estimate_order() const override { return 0; }

private:
bool DoStep(const T& h) override;
bool DoStepConst(const T& h, Context<T>* context) const override;
};

/**
* Integrates the system forward in time by h, starting at the current time t₀.
* This value of h is determined by IntegratorBase::Step().
*/
template <class T>
bool ExplicitEulerIntegrator<T>::DoStep(const T& h) {
Context<T>& context = *this->get_mutable_context();

bool ExplicitEulerIntegrator<T>::DoStepConst(const T& h,
Context<T>* context) const {
// CAUTION: This is performance-sensitive inner loop code that uses dangerous
// long-lived references into state and cache to avoid unnecessary copying and
// cache invalidation. Be careful not to insert calls to methods that could
// invalidate any of these references before they are used.

// Evaluate derivative xcdot₀ ← xcdot(t₀, x(t₀), u(t₀)).
const ContinuousState<T>& xc_deriv = this->EvalTimeDerivatives(context);
const ContinuousState<T>& xc_deriv = this->EvalTimeDerivatives(*context);
const VectorBase<T>& xcdot0 = xc_deriv.get_vector();

// Cache: xcdot0 references the live derivative cache value, currently
Expand All @@ -78,8 +77,8 @@ bool ExplicitEulerIntegrator<T>::DoStep(const T& h) {

// Update continuous state and time. This call marks t- and xc-dependent
// cache entries out of date, including xcdot0.
VectorBase<T>& xc = context.SetTimeAndGetMutableContinuousStateVector(
context.get_time() + h); // t ← t₀ + h
VectorBase<T>& xc = context->SetTimeAndGetMutableContinuousStateVector(
context->get_time() + h); // t ← t₀ + h

// Cache: xcdot0 still references the derivative cache value, which is
// unchanged, although it is marked out of date.
Expand Down
24 changes: 16 additions & 8 deletions systems/analysis/integrator_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,8 @@ class IntegratorBase {
integrator convergence)
- Takes only a single step forward.
*/
[[nodiscard]] bool IntegrateWithSingleFixedStepToTime(const T& t_target) {
[[nodiscard]] bool IntegrateWithSingleFixedStepToTime(
const T& t_target) const {
using std::max;
using std::abs;

Expand All @@ -1063,10 +1064,10 @@ class IntegratorBase {
throw std::logic_error("IntegrateWithSingleFixedStepToTime() requires "
"fixed stepping.");

if (!Step(h))
if (!DoStepConst(h, context_))
return false;

UpdateStepStatistics(h);
// UpdateStepStatistics(h);

if constexpr (scalar_predicate<T>::is_bool) {
// Correct any round-off error that has occurred. Formula below requires
Expand Down Expand Up @@ -1318,7 +1319,8 @@ class IntegratorBase {
Subclasses should call this function rather than calling
system.EvalTimeDerivatives() directly.
*/
const ContinuousState<T>& EvalTimeDerivatives(const Context<T>& context) {
const ContinuousState<T>& EvalTimeDerivatives(
const Context<T>& context) const {
return EvalTimeDerivatives(get_system(), context); // See below.
}

Expand All @@ -1330,8 +1332,8 @@ class IntegratorBase {
function evaluations.
*/
template <typename U>
const ContinuousState<U>& EvalTimeDerivatives(const System<U>& system,
const Context<U>& context) {
const ContinuousState<U>& EvalTimeDerivatives(
const System<U>& system, const Context<U>& context) const {
const CacheEntry& entry = system.get_time_derivatives_cache_entry();
const CacheEntryValue& value = entry.get_cache_entry_value(context);
const int64_t serial_number_before = value.serial_number();
Expand Down Expand Up @@ -1460,7 +1462,7 @@ class IntegratorBase {
example, by switching to an algorithm not subject to convergence
failures (e.g., explicit Euler) for very small step sizes.
*/
virtual bool DoStep(const T& h) = 0;
virtual bool DoStep(const T& h) { return DoStepConst(h, context_); }

// TODO(russt): Allow subclasses to override the interpolation scheme used, as
// the 'optimal' dense output scheme is only known by the specific integration
Expand Down Expand Up @@ -1516,6 +1518,12 @@ class IntegratorBase {
return true;
}

virtual bool DoStepConst(const T& h, Context<T>* context) const {
unused(h, context);
throw std::runtime_error(
"This integrator does not (yet) implement the const DoStep() variant.");
}

/**
* Gets an error estimate of the state variables recorded by the last call
* to StepOnceFixedSize(). If the integrator does not support error
Expand Down Expand Up @@ -1628,7 +1636,7 @@ class IntegratorBase {
T smallest_adapted_step_size_taken_{nan()};
T largest_step_size_taken_{nan()};
int64_t num_steps_taken_{0};
int64_t num_ode_evals_{0};
mutable int64_t num_ode_evals_{0};
int64_t num_shrinkages_from_error_control_{0};
int64_t num_shrinkages_from_substep_failures_{0};
int64_t num_substep_failures_{0};
Expand Down
7 changes: 3 additions & 4 deletions systems/analysis/runge_kutta2_integrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class RungeKutta2Integrator final : public IntegratorBase<T> {
int get_error_estimate_order() const override { return 0; }

private:
bool DoStep(const T& h) override;
bool DoStepConst(const T& h, Context<T>* context) const override;

// A pre-allocated temporary for use by integration.
std::unique_ptr<ContinuousState<T>> derivs0_;
Expand All @@ -67,9 +67,8 @@ class RungeKutta2Integrator final : public IntegratorBase<T> {
* </pre>
*/
template <class T>
bool RungeKutta2Integrator<T>::DoStep(const T& h) {
Context<T>* const context = IntegratorBase<T>::get_mutable_context();

bool RungeKutta2Integrator<T>::DoStepConst(const T& h,
Context<T>* context) const {
// CAUTION: This is performance-sensitive inner loop code that uses dangerous
// long-lived references into state and cache to avoid unnecessary copying and
// cache invalidation. Be careful not to insert calls to methods that could
Expand Down
89 changes: 89 additions & 0 deletions systems/analysis/runge_kutta3_integrator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,95 @@ bool RungeKutta3Integrator<T>::DoStep(const T& h) {
return true;
}

template <class T>
bool RungeKutta3Integrator<T>::DoStepConst(const T& h, Context<T>* context) const {
using std::abs;
const T t0 = context->get_time();
const T t1 = t0 + h;

// CAUTION: This is performance-sensitive inner loop code that uses dangerous
// long-lived references into state and cache to avoid unnecessary copying and
// cache invalidation. Be careful not to insert calls to methods that could
// invalidate any of these references before they are used.

// TODO(sherm1) Consider moving this notation description to IntegratorBase
// when it is more widely adopted.
// Notation: we're using numeric subscripts for times t₀ and t₁, and
// lower-case letter superscripts like t⁽ᵃ⁾ and t⁽ᵇ⁾ to indicate values
// for intermediate stages of which there are two here, a and b.
// State x₀ = {xc₀, xd₀, xa₀}. We modify only t and xc here, but
// derivative calculations depend on everything in the context, including t,
// x and inputs u (which may depend on t and x).
// Define x⁽ᵃ⁾ ≜ {xc⁽ᵃ⁾, xd₀, xa₀} and u⁽ᵃ⁾ ≜ u(t⁽ᵃ⁾, x⁽ᵃ⁾).

// Evaluate derivative xcdot₀ ← xcdot(t₀, x(t₀), u(t₀)). Copy the result
// into a temporary since we'll be calculating more derivatives below.
derivs0_->get_mutable_vector().SetFrom(
this->EvalTimeDerivatives(*context).get_vector());
const VectorBase<T>& xcdot0 = derivs0_->get_vector();

// Cache: xcdot0 references a *copy* of the derivative result so is immune
// to subsequent evaluations.

// Compute the first intermediate state and derivative
// (at t⁽ᵃ⁾=t₀+h/2, x⁽ᵃ⁾, u⁽ᵃ⁾).

// This call marks t- and xc-dependent cache entries out of date, including
// the derivative cache entry. Note that xc is a live reference into the
// context -- subsequent changes through that reference are unobservable so
// will require manual out-of-date notifications.
VectorBase<T>& xc = context->SetTimeAndGetMutableContinuousStateVector(
t0 + h / 2); // t⁽ᵃ⁾ ← t₀ + h/2
xc.CopyToPreSizedVector(&save_xc0_); // Save xc₀ while we can.
xc.PlusEqScaled(h / 2, xcdot0); // xc⁽ᵃ⁾ ← xc₀ + h/2 xcdot₀

derivs1_->get_mutable_vector().SetFrom(
this->EvalTimeDerivatives(*context).get_vector());
const VectorBase<T>& xcdot_a = derivs1_->get_vector(); // xcdot⁽ᵃ⁾

// Cache: xcdot_a references a *copy* of the derivative result so is immune
// to subsequent evaluations.

// Compute the second intermediate state and derivative
// (at t⁽ᵇ⁾=t₁, x⁽ᵇ⁾, u⁽ᵇ⁾).

// This call marks t- and xc-dependent cache entries out of date, including
// the derivative cache entry. (We already have the xc reference but must
// issue the out-of-date notification here since we're about to change it.)
context->SetTimeAndNoteContinuousStateChange(t1);

// xcⱼ ← xc₀ - h xcdot₀ + 2 h xcdot⁽ᵃ⁾
xc.SetFromVector(save_xc0_); // Restore xc ← xc₀.
xc.PlusEqScaled({{-h, xcdot0}, {2 * h, xcdot_a}});

const VectorBase<T>& xcdot_b = // xcdot⁽ᵇ⁾
this->EvalTimeDerivatives(*context).get_vector();

// Cache: xcdot_b references the live derivative cache value, currently
// up to date but about to be marked out of date. We do not want to make
// an unnecessary copy of this data.

// Cache: we're about to write through the xc reference again, so need to
// mark xc-dependent cache entries out of date, including xcdot_b; time
// doesn't change here.
context->NoteContinuousStateChange();

// Calculate the final O(h³) state at t₁.
// xc₁ ← xc₀ + h/6 xcdot₀ + 2/3 h xcdot⁽ᵃ⁾ + h/6 xcdot⁽ᵇ⁾
xc.SetFromVector(save_xc0_); // Restore xc ← xc₀.
const T h6 = h / 6.0;

// Cache: xcdot_b still references the derivative cache value, which is
// unchanged, although it is marked out of date. xcdot0 and xcdot_a are
// unaffected.
xc.PlusEqScaled({{h6, xcdot0},
{4 * h6, xcdot_a},
{h6, xcdot_b}});

// RK3 always succeeds in taking its desired step.
return true;
}

} // namespace systems
} // namespace drake

Expand Down
5 changes: 3 additions & 2 deletions systems/analysis/runge_kutta3_integrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,18 @@ class RungeKutta3Integrator final : public IntegratorBase<T> {
private:
void DoInitialize() override;
bool DoStep(const T& h) override;
bool DoStepConst(const T& h, Context<T>* context) const override;

// Vector used in error estimate calculations.
VectorX<T> err_est_vec_;

// Vector used to save initial value of xc.
VectorX<T> save_xc0_;
mutable VectorX<T> save_xc0_;

// These are pre-allocated temporaries for use by integration. They store
// the derivatives computed at various points within the integration
// interval.
std::unique_ptr<ContinuousState<T>> derivs0_, derivs1_;
mutable std::unique_ptr<ContinuousState<T>> derivs0_, derivs1_;
};

} // namespace systems
Expand Down
17 changes: 8 additions & 9 deletions systems/analysis/semi_explicit_euler_integrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,28 +104,27 @@ class SemiExplicitEulerIntegrator final : public IntegratorBase<T> {
bool supports_error_estimation() const override { return false; }

private:
bool DoStep(const T& h) override;
bool DoStepConst(const T& h, Context<T>* context) const override;

// This is a pre-allocated temporary for use by integration
BasicVector<T> qdot_;
mutable BasicVector<T> qdot_;
};

/**
* Integrates the system forward in time by h. This value is determined
* by IntegratorBase::StepOnce().
*/
template <class T>
bool SemiExplicitEulerIntegrator<T>::DoStep(const T& h) {
bool SemiExplicitEulerIntegrator<T>::DoStepConst(const T& h,
Context<T>* context) const {
const System<T>& system = this->get_system();
Context<T>& context = *this->get_mutable_context();

// CAUTION: This is performance-sensitive inner loop code that uses dangerous
// long-lived references into state and cache to avoid unnecessary copying and
// cache invalidation. Be careful not to insert calls to methods that could
// invalidate any of these references before they are used.

// Evaluate derivative xcdot(t₀) ← xcdot(t₀, x(t₀), u(t₀)).
const ContinuousState<T>& xc_deriv = this->EvalTimeDerivatives(context);
const ContinuousState<T>& xc_deriv = this->EvalTimeDerivatives(*context);
// Retrieve the accelerations and auxiliary variable derivatives.
const VectorBase<T>& vdot = xc_deriv.get_generalized_velocity();
const VectorBase<T>& zdot = xc_deriv.get_misc_continuous_state();
Expand All @@ -137,7 +136,7 @@ bool SemiExplicitEulerIntegrator<T>::DoStep(const T& h) {
// This invalidates computations that are dependent on v or z.
// Marks v- and z-dependent cache entries out of date, including vdot and
// zdot; time doesn't change here.
std::pair<VectorBase<T>*, VectorBase<T>*> vz = context.GetMutableVZVectors();
std::pair<VectorBase<T>*, VectorBase<T>*> vz = context->GetMutableVZVectors();
VectorBase<T>& v = *vz.first;
VectorBase<T>& z = *vz.second;

Expand All @@ -151,13 +150,13 @@ bool SemiExplicitEulerIntegrator<T>::DoStep(const T& h) {
// Convert the updated generalized velocity to the time derivative of
// generalized coordinates. Note that this mapping is q-dependent and
// hasn't been invalidated if it was pre-computed.
system.MapVelocityToQDot(context, v, &qdot_);
system.MapVelocityToQDot(*context, v, &qdot_);

// Now set time and q to their final values. This marks time- and
// q-dependent cache entries out of date. That includes the derivative
// cache entry though we don't need it again here.
VectorBase<T>& q =
context.SetTimeAndGetMutableQVector(context.get_time() + h);
context->SetTimeAndGetMutableQVector(context->get_time() + h);
q.PlusEqScaled(h, qdot_);

// This integrator always succeeds at taking the step.
Expand Down