@@ -492,41 +492,57 @@ def add_reader(self, fd, callback, *args):
492
492
self ._ensure_fd_no_transport (fd )
493
493
return self ._add_reader (fd , callback , * args )
494
494
495
- def _add_reader (self , fd , callback , * args ):
495
+ # Local helper to factor out common logic between _add_reader/_add_writer
496
+ def _add_io_handler (self , set_handle , wait_ready , fd , callback , args ):
496
497
self ._check_closed ()
497
498
handle = ScopedHandle (callback , args , self )
498
- reader = self ._set_read_handle (fd , handle )
499
- if reader is not None :
500
- reader .cancel ()
501
- if self ._token is None :
502
- return
503
- self ._nursery .start_soon (self ._reader_loop , fd , handle )
499
+ old_handle = set_handle (fd , handle )
504
500
505
- def _set_read_handle (self , fd , handle ):
506
- try :
507
- key = self ._selector .get_key (fd )
508
- except KeyError :
509
- self ._selector .register (fd , EVENT_READ , (handle , None ))
501
+ if old_handle is not None :
502
+ old_handle .cancel ()
503
+ if self ._token is None :
510
504
return None
511
- else :
512
- mask , (reader , writer ) = key .events , key .data
513
- self ._selector .modify (fd , mask | EVENT_READ , (handle , writer ))
514
- return reader
505
+ self ._nursery .start_soon (self ._io_task , fd , handle , wait_ready )
506
+ return handle
515
507
516
- async def _reader_loop (self , fd , handle ):
508
+ async def _io_task (self , fd , handle , wait_ready ):
517
509
with handle ._scope :
518
510
try :
519
511
while True :
520
512
if handle ._cancelled :
521
513
break
522
- await _wait_readable (fd )
514
+ try :
515
+ await wait_ready (fd )
516
+ except OSError :
517
+ # maybe someone did
518
+ # h = add_reader(sock); h.cancel(); sock.close()
519
+ # without yielding to the event loop
520
+ if handle ._cancelled :
521
+ break
522
+ raise
523
523
if handle ._cancelled :
524
524
break
525
525
handle ._run ()
526
526
await self .synchronize ()
527
527
except Exception as exc :
528
528
handle ._raise (exc )
529
529
530
+ def _add_reader (self , fd , callback , * args ):
531
+ return self ._add_io_handler (
532
+ self ._set_read_handle , _wait_readable , fd , callback , args
533
+ )
534
+
535
+ def _set_read_handle (self , fd , handle ):
536
+ try :
537
+ key = self ._selector .get_key (fd )
538
+ except KeyError :
539
+ self ._selector .register (fd , EVENT_READ , (handle , None ))
540
+ return None
541
+ else :
542
+ mask , (reader , writer ) = key .events , key .data
543
+ self ._selector .modify (fd , mask | EVENT_READ , (handle , writer ))
544
+ return reader
545
+
530
546
# writing to a file descriptor
531
547
532
548
def add_writer (self , fd , callback , * args ):
@@ -546,15 +562,10 @@ def add_writer(self, fd, callback, *args):
546
562
547
563
# remove_writer: unchanged from asyncio
548
564
549
- def _add_writer (self , fd , callback , * args ):
550
- self ._check_closed ()
551
- handle = ScopedHandle (callback , args , self )
552
- writer = self ._set_write_handle (fd , handle )
553
- if writer is not None :
554
- writer .cancel ()
555
- if self ._token is None :
556
- return
557
- self ._nursery .start_soon (self ._writer_loop , fd , handle )
565
+ def _add_writer (self , fd , callback , * args , _defer_start = False ):
566
+ return self ._add_io_handler (
567
+ self ._set_write_handle , _wait_writable , fd , callback , args
568
+ )
558
569
559
570
def _set_write_handle (self , fd , handle ):
560
571
try :
@@ -566,20 +577,6 @@ def _set_write_handle(self, fd, handle):
566
577
self ._selector .modify (fd , mask | EVENT_WRITE , (reader , handle ))
567
578
return writer
568
579
569
- async def _writer_loop (self , fd , handle ):
570
- with handle ._scope :
571
- try :
572
- while True :
573
- if handle ._cancelled :
574
- break
575
- await _wait_writable (fd )
576
- if handle ._cancelled :
577
- break
578
- handle ._run ()
579
- await self .synchronize ()
580
- except Exception as exc :
581
- handle ._raise (exc )
582
-
583
580
def autoclose (self , fd ):
584
581
"""
585
582
Mark a file descriptor so that it's auto-closed along with this loop.
@@ -752,6 +749,7 @@ async def _main_loop_exit(self):
752
749
# clean core fields
753
750
self ._nursery = None
754
751
self ._task = None
752
+ self ._token = None
755
753
756
754
def is_running (self ):
757
755
if self ._stopped is None :
@@ -778,6 +776,11 @@ async def wait_stopped(self):
778
776
"""
779
777
await self ._stopped .wait ()
780
778
779
+ def _trio_io_cancel (self , cancel_scope ):
780
+ """Called when a ScopedHandle representing an I/O reader or writer
781
+ has its cancel() method called."""
782
+ cancel_scope .cancel ()
783
+
781
784
def stop (self ):
782
785
"""Halt the main loop.
783
786
0 commit comments