Skip to content

Commit

Permalink
added test_time_vector function
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter230655 committed Mar 2, 2025
1 parent e534ed4 commit 21668a7
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 6 deletions.
19 changes: 13 additions & 6 deletions opty/direct_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,25 +690,32 @@ def time_vector(self, solution=None, start_time=0.0):
"""Returns the time instances of the problem as an numpy ndarray.
Parameters
----------
==========
solution : (n*N + q*N + r + s)-ndarray, optional
The solution to to problem. Needed if the interval is variable.
The solution to to problem. Needed if the time interval is variable.
start_time : float, optional
The initial time of the problem. Default is 0.0.
Returns
-------
A numpy num_collocation_nodes-array of time instances.
=======
time_vector : ndarray, shape(num_collocation_nodes,)
The array of time instances.
"""
t0 = start_time
if self.collocator._variable_duration:
if solution is None:
msg = 'Solution vector must be provided for variable duration.'
raise ValueError(msg)
elif solution[-1] <= 0:
msg = 'Time interval must be strictly greater than zero.'
raise ValueError(msg)
elif t0 >= solution[-1] * self.collocator.num_collocation_nodes:
msg = 'Start time must be less than the final time.'
raise ValueError(msg)
else:
return np.arange(t0, t0 + self.collocator.num_collocation_nodes*
solution[-1], solution[-1])
return np.arange(t0, t0 + self.collocator.num_collocation_nodes
*solution[-1], solution[-1])

else:
return np.arange(t0, t0 + self.collocator.num_collocation_nodes*
Expand Down
95 changes: 95 additions & 0 deletions opty/tests/test_direct_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,3 +1923,98 @@ def test_attributes_read_only():

with raises(AttributeError):
setattr(test, XX, 5)

def test_time_vector():
"""Test to ensure that time_vector retunrs the correct values."""

x, ux = mech.dynamicsymbols('x ux')
t = mech.dynamicsymbols._t

# just random eoms, no physical meaning.
eom = sym.Matrix([
-x.diff(t) + ux,
-ux.diff(t) + 2.0,
])

state_symbols = (x, ux)
num_nodes = 25

# A: constant time interval
t0, tf = np.random.uniform(0.0, 2.0), 10.0
interval_value = (tf - t0) / (num_nodes - 1)

def obj(free):
Fx = free[0*num_nodes:2*num_nodes]
return interval_value*np.sum(Fx**2)

def obj_grad(free):
grad = np.zeros_like(free)
grad[0:2*num_nodes] = 2.0*free[0:2*num_nodes]*interval_value
return grad

prob = Problem(
obj,
obj_grad,
eom,
state_symbols,
num_nodes,
interval_value,
time_symbol=t,
backend='numpy'
)
expected_time_vector = np.arange(t0, t0 + num_nodes*interval_value,
interval_value)
time_vector = prob.time_vector(start_time=t0)
assert np.allclose(time_vector, expected_time_vector)

solution = np.random.randn(prob.num_free)
time_vector = prob.time_vector(solution, start_time=t0)
assert np.allclose(time_vector, expected_time_vector)

# B: variable time interval
h =sym.symbols('h')
interval_value = h

def obj(free):
Fx = free[0*num_nodes:2*num_nodes]
return solution[-1]*np.sum(Fx**2)

def obj_grad(free):
grad = np.zeros_like(free)
grad[0:2*num_nodes] = 2.0*free[0:2*num_nodes]*solution[-1]
return grad

prob = Problem(
obj,
obj_grad,
eom,
state_symbols,
num_nodes,
interval_value,
time_symbol=t,
backend='numpy'
)

# solution must be given
with raises(ValueError):
time_vector = prob.time_vector(start_time=t0)
time_vector = prob.time_vector()

solution = np.random.randn(prob.num_free)
solution[-1] = np.random.uniform(2.5/(num_nodes-1), 10.0/(num_nodes-1))
time_vector = prob.time_vector(solution, start_time=t0)
expected_time_vector = np.arange(t0, t0 + num_nodes*solution[-1],
solution[-1])
assert np.allclose(time_vector, expected_time_vector)

# final time > initial time
solution[-1] = 1.e-75
expected_time_vector = np.arange(t0, t0 + num_nodes*solution[-1],
solution[-1])
with raises(ValueError):
time_vector = prob.time_vector(solution, start_time=t0)

# interval_value must be greater than zero
solution[-1] = 0.0
with raises(ValueError):
time_vector = prob.time_vector(solution, start_time=t0)

0 comments on commit 21668a7

Please sign in to comment.