Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proposal: distributor request fn #321

Merged
merged 9 commits into from
Apr 13, 2021
84 changes: 53 additions & 31 deletions src/bastion/examples/distributor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,19 @@ struct ConferenceSchedule {
misc: String,
}

/// cargo r --features=tokio-runtime distributor
/// cargo r --features=tokio-runtime --example distributor
#[cfg(feature = "tokio-runtime")]
#[tokio::main]
async fn main() -> AnyResult<()> {
run()
}

#[cfg(not(feature = "tokio-runtime"))]
fn main() -> AnyResult<()> {
run()
}

fn run() -> AnyResult<()> {
let subscriber = tracing_subscriber::fmt()
.with_max_level(Level::INFO)
.finish();
Expand Down Expand Up @@ -119,47 +129,48 @@ async fn main() -> AnyResult<()> {
Bastion::start();

// Wait a bit until everyone is ready
// std::thread::sleep(std::time::Duration::from_secs(1));
sleep(std::time::Duration::from_secs(5));

let staff = Distributor::named("staff");
let enthusiasts = Distributor::named("enthusiasts");
let attendees = Distributor::named("attendees");

// Enthusiast -> Ask one of the staff members "when is the conference going to happen ?"
let answer = staff.ask_one("when is the next conference going to happen?")?;
MessageHandler::new(
answer
let reply: Result<String, SendError> = run!(async {
staff
.request("when is the next conference going to happen?")
.await
.expect("coulnd't find out when the next conference is going to happen :("),
)
.on_tell(|reply: String, _sender_addr| {
tracing::info!("received a reply to my message:\n{}", reply);
.expect("couldn't receive reply")
});

tracing::error!("{:?}", reply); // Ok("Next month!")

// "hey conference <awesomeconference> is going to happen. will you be there?"
// Broadcast / Question -> if people reply with YES => fill the 3rd group
let answers = enthusiasts
.ask_everyone("hey, the conference is going to happen, will you be there?")
.expect("couldn't ask everyone");

for answer in answers.into_iter() {
MessageHandler::new(answer.await.expect("couldn't receive reply"))
.on_tell(|rsvp: RSVP, _| {
if rsvp.attends {
tracing::info!("{:?} will be there! :)", rsvp.child_ref.id());
attendees
.subscribe(rsvp.child_ref)
.expect("couldn't subscribe attendee");
} else {
tracing::error!("{:?} won't make it :(", rsvp.child_ref.id());
}
})
.on_fallback(|unknown, _sender_addr| {
tracing::error!(
"distributor_test: uh oh, I received a message I didn't understand\n {:?}",
unknown
);
});
run!(async move {
MessageHandler::new(answer.await.expect("couldn't receive reply"))
.on_tell(|rsvp: RSVP, _| {
if rsvp.attends {
tracing::info!("{:?} will be there! :)", rsvp.child_ref.id());
attendees
.subscribe(rsvp.child_ref)
.expect("couldn't subscribe attendee");
} else {
tracing::error!("{:?} won't make it :(", rsvp.child_ref.id());
}
})
.on_fallback(|unknown, _sender_addr| {
tracing::error!(
"distributor_test: uh oh, I received a message I didn't understand\n {:?}",
unknown
);
});
});
}

// Ok now that attendees have subscribed, let's send information around!
Expand All @@ -176,14 +187,15 @@ async fn main() -> AnyResult<()> {
tracing::error!("total number of attendees: {}", total_sent.len());

tracing::info!("the conference is running!");
tokio::time::sleep(std::time::Duration::from_secs(10)).await;

// Let's wait until the conference is over 8D
sleep(std::time::Duration::from_secs(5));

// An attendee sends a thank you note to one staff member (and not bother everyone)
staff
.tell_one("the conference was amazing thank you so much!")
.context("couldn't thank the staff members :(")?;

tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// And we're done!
Bastion::stop();

Expand All @@ -198,9 +210,7 @@ async fn organize_the_event(ctx: BastionContext) -> Result<(), ()> {
MessageHandler::new(ctx.recv().await?)
.on_question(|message: &str, sender| {
tracing::info!("received a question: \n{}", message);
sender
.reply("uh i think it will be next month!".to_string())
.unwrap();
sender.reply("Next month!".to_string()).unwrap();
})
.on_tell(|message: &str, _| {
tracing::info!("received a message: \n{}", message);
Expand Down Expand Up @@ -243,3 +253,15 @@ async fn be_interested_in_the_conference(ctx: BastionContext) -> Result<(), ()>
});
}
}

#[cfg(feature = "tokio-runtime")]
fn sleep(duration: std::time::Duration) {
run!(async {
tokio::time::sleep(duration).await;
});
}

#[cfg(not(feature = "tokio-runtime"))]
fn sleep(duration: std::time::Duration) {
std::thread::sleep(duration);
}
100 changes: 88 additions & 12 deletions src/bastion/src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) enum CallbackType {
AfterStop,
BeforeRestart,
BeforeStart,
AfterStart,
}

#[derive(Default, Clone)]
Expand All @@ -22,12 +23,12 @@ pub(crate) enum CallbackType {
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # fn run() {
Expand Down Expand Up @@ -60,6 +61,7 @@ pub(crate) enum CallbackType {
/// [`Children`]: crate::children::Children
pub struct Callbacks {
before_start: Option<Arc<dyn Fn() + Send + Sync>>,
after_start: Option<Arc<dyn Fn() + Send + Sync>>,
before_restart: Option<Arc<dyn Fn() + Send + Sync>>,
after_restart: Option<Arc<dyn Fn() + Send + Sync>>,
after_stop: Option<Arc<dyn Fn() + Send + Sync>>,
Expand All @@ -77,12 +79,12 @@ impl Callbacks {
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # fn run() {
Expand Down Expand Up @@ -136,12 +138,12 @@ impl Callbacks {
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # fn run() {
Expand Down Expand Up @@ -191,6 +193,74 @@ impl Callbacks {
self
}

/// Sets the method that will get called right after the [`Supervisor`]
/// or [`Children`] is launched.
/// This method will be called after the child has subscribed to its distributors and dispatchers.
///
/// Once the callback has run, the child has caught up it's message backlog,
/// and is waiting for new messages to process.
///
/// # Example
///
/// ```rust
/// # use bastion::prelude::*;
/// #
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # }
/// #
/// # fn run() {
/// # Bastion::init();
/// #
/// # Bastion::supervisor(|supervisor| {
/// supervisor.children(|children| {
/// let callbacks = Callbacks::new()
/// .with_after_start(|| println!("Children group ready to process messages."));
///
/// children
/// .with_exec(|ctx| {
/// // -- Children group started.
/// // with_after_start called
/// async move {
/// // ...
///
/// // This will stop the children group...
/// Ok(())
/// // Note that because the children group stopped by itself,
/// // if its supervisor restarts it, its `before_start` callback
/// // will get called and not `after_restart`.
/// }
/// // -- Children group stopped.
/// })
/// .with_callbacks(callbacks)
/// })
/// # }).unwrap();
/// #
/// # Bastion::start();
/// # Bastion::stop();
/// # Bastion::block_until_stopped();
/// # }
/// ```
///
/// [`Supervisor`]: crate::supervisor::Supervisor
/// [`Children`]: crate::children::Children
/// [`with_after_restart`]: Self::with_after_restart
pub fn with_after_start<C>(mut self, after_start: C) -> Self
where
C: Fn() + Send + Sync + 'static,
{
let after_start = Arc::new(after_start);
self.after_start = Some(after_start);
self
}

/// Sets the method that will get called before the [`Supervisor`]
/// or [`Children`] is reset if:
/// - the supervisor of the supervised element using this callback
Expand All @@ -208,12 +278,12 @@ impl Callbacks {
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # fn run() {
Expand Down Expand Up @@ -282,12 +352,12 @@ impl Callbacks {
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # fn run() {
Expand Down Expand Up @@ -360,12 +430,12 @@ impl Callbacks {
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// # run();
/// # run();
/// # }
/// #
/// # fn run() {
Expand Down Expand Up @@ -493,6 +563,12 @@ impl Callbacks {
}
}

pub(crate) fn after_start(&self) {
if let Some(after_start) = &self.after_start {
after_start()
}
}

pub(crate) fn before_restart(&self) {
if let Some(before_restart) = &self.before_restart {
before_restart()
Expand Down
5 changes: 4 additions & 1 deletion src/bastion/src/child.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ impl Child {
CallbackType::BeforeRestart => self.callbacks.before_restart(),
CallbackType::AfterRestart => self.callbacks.after_restart(),
CallbackType::AfterStop => self.callbacks.after_stop(),
CallbackType::AfterStart => self.callbacks.after_start(),
}
}

Expand Down Expand Up @@ -312,6 +313,8 @@ impl Child {
return;
};

self.callbacks.after_start();

loop {
#[cfg(feature = "scaling")]
self.update_stats().await;
Expand Down Expand Up @@ -416,7 +419,7 @@ impl Child {
distributors
.iter()
.map(|&distributor| {
global_dispatcher.register_recipient(distributor, child_ref.clone())
global_dispatcher.register_recipient(&distributor, child_ref.clone())
})
.collect::<AnyResult<Vec<_>>>()?;
}
Expand Down
Loading