Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

subscriptions now handles unsubs even if the context raises an Exception. #162

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 246 additions & 0 deletions epymorph/test/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from epymorph import util
from epymorph.data_shape import DataShapeMatcher, Shapes, SimDimensions
from epymorph.util import Event, subscriptions
from epymorph.util import match as m


Expand Down Expand Up @@ -103,3 +104,248 @@ def test_check_ndarray_03(self):
util.check_ndarray(arr, shape=DataShapeMatcher(Shapes.TxN, dim2, True))
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dtype=m.dtype(np.str_))


class TestEvent(unittest.TestCase):
def setUp(self):
"""Set up a new Event instance for each test."""
self.event = Event[int]()

def test_subscribe_adds_subscriber(self):
"""Test: subscribing adds a subscriber."""

def handler(event: int):
pass

self.event.subscribe(handler)
self.assertEqual(len(self.event._subscribers), 1)
self.assertIn(handler, self.event._subscribers)

def test_unsubscribe_removes_subscriber(self):
"""Test: unsubscribing removes the correct subscriber."""

def handler(event: int):
pass

self.assertEqual(len(self.event._subscribers), 0)

unsubscribe = self.event.subscribe(handler)
self.assertEqual(len(self.event._subscribers), 1)

unsubscribe()
self.assertEqual(len(self.event._subscribers), 0)
self.assertNotIn(handler, self.event._subscribers)

def test_publish_calls_subscriber(self):
"""Test: publish calls the subscribed handler."""
self.subscriber_called = False

def handler(event: int):
self.subscriber_called = True
self.assertEqual(event, 42)

self.event.subscribe(handler)
self.event.publish(42)
self.assertTrue(self.subscriber_called)

def test_publish_multiple_subscribers(self):
"""Test: publish calls all subscribers."""
self.subscriber1_called = False
self.subscriber2_called = False

def handler1(event: int):
self.subscriber1_called = True

def handler2(event: int):
self.subscriber2_called = True

self.event.subscribe(handler1)
self.event.subscribe(handler2)
self.event.publish(42)

self.assertTrue(self.subscriber1_called)
self.assertTrue(self.subscriber2_called)

def test_unsubscribed_handler_not_called(self):
"""Test that unsubscribed handler is not called when event is published."""
self.subscriber_called = False

def handler(event: int):
self.subscriber_called = True

unsubscribe = self.event.subscribe(handler)
unsubscribe()

self.event.publish(42)
self.assertFalse(self.subscriber_called)

def test_has_subscribers_initially_false(self):
"""Test: has_subscribers is False initially."""
self.assertFalse(self.event.has_subscribers)

def test_has_subscribers_after_subscribe(self):
"""Test: has_subscribers becomes True after subscribing."""

def handler(event: int):
pass

self.event.subscribe(handler)
self.assertTrue(self.event.has_subscribers)

def test_has_subscribers_after_unsubscribe(self):
"""Test: has_subscribers becomes False after unsubscribing all."""

def handler(event: int):
pass

unsubscribe = self.event.subscribe(handler)
unsubscribe()
self.assertFalse(self.event.has_subscribers)

def test_subscribe_multiple_times_same_handler(self):
"""Test: a handler can subscribe multiple times and all instances get called."""
call_count = 0

def handler(event: int):
nonlocal call_count
call_count += 1

self.event.subscribe(handler)
self.event.subscribe(handler)
self.event.publish(42)

self.assertEqual(call_count, 2)

def test_unsubscribe_multiple_times_same_handler(self):
"""Test: multiple subs of the same handler can be individually unsub'd."""
call_count = 0

def handler(event: int):
nonlocal call_count
call_count += 1

unsubscribe1 = self.event.subscribe(handler)
unsubscribe2 = self.event.subscribe(handler)

# Unsubscribe the first one
unsubscribe1()
self.event.publish(42)

self.assertEqual(call_count, 1)

# Unsubscribe the second one
unsubscribe2()
self.event.publish(42)

self.assertEqual(call_count, 1) # Should not be incremented again

def test_publish_with_no_subscribers(self):
"""Test: publishing with no subscribers does nothing."""
try:
self.event.publish(42)
except Exception as e: # noqa: BLE001
self.fail(f"publish raised an exception: {e}")


class TestSubscriptions(unittest.TestCase):
def setUp(self):
"""Set up a new Event instance for each test."""
self.event = Event[int]()

def test_no_subs(self):
"""Test: no subscribing happened."""
try:
with subscriptions() as _sub:
pass
except Exception as e: # noqa: BLE001
self.fail(f"subscriptions raised an exception: {e}")

def test_one_sub(self):
"""Test: one subscriber."""
acc = 0

def handler(event: int):
nonlocal acc
acc += event

# Events values published during the context will accumulate into `acc`,
# but not outside of the context.

self.event.publish(3)

with subscriptions() as sub:
sub.subscribe(self.event, handler)
self.assertTrue(self.event.has_subscribers)
self.event.publish(7)
self.event.publish(11)

self.event.publish(13)

self.assertEqual(acc, 18) # 7 + 11
self.assertFalse(self.event.has_subscribers)

def test_multiple_sub(self):
"""Test: multiple subscribers."""
acc = 0

def handler1(event: int):
nonlocal acc
acc += event

def handler2(event: int):
nonlocal acc
acc += event

self.event.publish(3)

with subscriptions() as sub:
sub.subscribe(self.event, handler1)
sub.subscribe(self.event, handler2)
self.assertTrue(self.event.has_subscribers)
self.event.publish(7)
self.event.publish(11)

self.event.publish(13)

self.assertEqual(acc, 36) # 2 * (7 + 11)
self.assertFalse(self.event.has_subscribers)

def test_before_sub(self):
"""Test: subscribers from before the context are untouched."""

acc1, acc2 = 0, 0

def handler_before(event: int):
nonlocal acc1
acc1 += event

def handler_context(event: int):
nonlocal acc2
acc2 += event

self.event.subscribe(handler_before)
self.event.publish(3)

with subscriptions() as sub:
sub.subscribe(self.event, handler_context)
self.event.publish(7)
self.assertEqual(len(self.event._subscribers), 2)

self.event.publish(13)

self.assertEqual(acc1, 23) # 3 + 7 + 13
self.assertEqual(acc2, 7) # 7
self.assertEqual(len(self.event._subscribers), 1)

def test_exception_in_context(self):
"""Test: subscribers are unsub'd even if an exception was thrown."""

def handler(event: int):
pass

with self.assertRaises(Exception):
with subscriptions() as sub:
sub.subscribe(self.event, handler)
raise Exception("ruh roh")

self.assertFalse(self.event.has_subscribers)
6 changes: 4 additions & 2 deletions epymorph/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,10 @@ def subscriptions() -> Generator[Subscriber, None, None]:
Subscriber will be automatically unsubscribed when the context closes.
"""
sub = Subscriber()
yield sub
sub.unsubscribe()
try:
yield sub
finally:
sub.unsubscribe()


# singletons
Expand Down