Skip to content

Commit 2b48bac

Browse files
committed
Add ability to suspend threads until a signal is called
For GDScript users that run code on threads, it often happens they want to `await` until something happens, the continue the thread execution. Unfortunately, signals will most likely execute on the main thread, hence this means that the code being run on the thread will continue running on the main thread by default. It has been discussed whether await should be smart about this and simply suspend the thread if running on one. The problem with this is, that users may also be willing to emit the signal that will resume from the thread itself and, at the time of awaiting, there is **no way for the interpreter to know on which thread the function will be resumed**. Additionally, suspending the thread is a very different operation than awaiting. Awaiting saves the local function stack and returns immediately, while suspending stops the whole thread until another resumes it. Mixing and matching both seems, ultimately, undesired as they are not the same. To solve this, a new utility function is added by this pull request, which suspends a thread until a signal is emitted (no matter in which other thread). It works like this. ```GDScript # Suspends a thread until a button is pressed. Thread.suspend( button.pressed ) ``` If code must do this and can run on both the main thread or a dedicated thread, it can do as follows: ```GDScript # Suspends a thread until a button is pressed. if (Thread.is_main_thread()): await button.pressed else: Thread.suspend( button.pressed ) ``` For this, the `Thread.is_main_thread()` function (already existing in the internal Thread object) has been exposed to the engine API. **Q**: Are you sure this can't be done with await transparently? **A**: No, please read again. There is no way for await to know beforehand how the function will be resumed. Only the user knows. **Q**: Does it not make more sense to have an `await_thread` or something where the user will pass the argument? **A**: I think its better to have a dedicated utility function for this. Not only the intention is more explicit, but additionally, as this makes it non exclusive to GDScript. It can be used from other languages that use the engine.
1 parent cae3d72 commit 2b48bac

8 files changed

+118
-0
lines changed

core/core_bind.cpp

+87
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,10 @@ bool Thread::is_alive() const {
13911391
return running.is_set();
13921392
}
13931393

1394+
bool Thread::is_main_thread() {
1395+
return ::Thread::is_main_thread();
1396+
}
1397+
13941398
Variant Thread::wait_to_finish() {
13951399
ERR_FAIL_COND_V_MSG(!is_started(), Variant(), "Thread must have been started to wait for its completion.");
13961400
thread.wait_to_finish();
@@ -1405,6 +1409,87 @@ void Thread::set_thread_safety_checks_enabled(bool p_enabled) {
14051409
set_current_thread_safe_for_nodes(!p_enabled);
14061410
}
14071411

1412+
class CallableCustomSuspend : public CallableCustom {
1413+
Semaphore *semaphore = nullptr;
1414+
Variant *return_value = nullptr;
1415+
1416+
// Never really going to execute since disconnection is automatic.
1417+
static bool _equal_func(const CallableCustom *p_a, const CallableCustom *p_b) {
1418+
const CallableCustomSuspend *A = static_cast<const CallableCustomSuspend *>(p_a);
1419+
const CallableCustomSuspend *B = static_cast<const CallableCustomSuspend *>(p_b);
1420+
1421+
return A->semaphore == B->semaphore;
1422+
}
1423+
1424+
// Never really going to execute since disconnection is automatic.
1425+
static bool _less_func(const CallableCustom *p_a, const CallableCustom *p_b) {
1426+
const CallableCustomSuspend *A = static_cast<const CallableCustomSuspend *>(p_a);
1427+
const CallableCustomSuspend *B = static_cast<const CallableCustomSuspend *>(p_b);
1428+
1429+
return A->semaphore < B->semaphore;
1430+
}
1431+
1432+
public:
1433+
//for every type that inherits, these must always be the same for this type
1434+
virtual uint32_t hash() const override {
1435+
return size_t(semaphore);
1436+
}
1437+
1438+
virtual String get_as_text() const override {
1439+
return "SemaphoreCallable";
1440+
}
1441+
1442+
virtual CompareEqualFunc get_compare_equal_func() const override {
1443+
return _equal_func;
1444+
}
1445+
1446+
virtual CompareLessFunc get_compare_less_func() const override {
1447+
return _less_func;
1448+
}
1449+
1450+
virtual ObjectID get_object() const override {
1451+
return ObjectID();
1452+
}
1453+
1454+
virtual void call(const Variant **p_arguments, int p_argcount, Variant &r_return_value, Callable::CallError &r_call_error) const override {
1455+
semaphore->post();
1456+
if (p_argcount == 1) { // If passed one argument, will be returned.
1457+
if (return_value) {
1458+
*return_value = *p_arguments[0];
1459+
}
1460+
} else if (p_argcount > 1) {
1461+
r_call_error.error = Callable::CallError::CALL_ERROR_TOO_MANY_ARGUMENTS;
1462+
r_call_error.argument = p_argcount;
1463+
r_call_error.expected = 1;
1464+
return;
1465+
}
1466+
1467+
r_call_error.error = Callable::CallError::CALL_OK;
1468+
}
1469+
1470+
CallableCustomSuspend(Semaphore *p_semaphore, Variant *r_return_value) {
1471+
semaphore = p_semaphore;
1472+
return_value = r_return_value;
1473+
}
1474+
};
1475+
1476+
Variant Thread::suspend(Signal p_resume_signal, bool p_connect_on_main_thread) {
1477+
Semaphore semaphore;
1478+
Variant return_value;
1479+
1480+
Callable callable(memnew(CallableCustomSuspend(&semaphore, &return_value)));
1481+
if (p_connect_on_main_thread) {
1482+
ObjectID obj_id = p_resume_signal.get_object_id();
1483+
MessageQueue::get_singleton()->push_call(obj_id, SNAME("connect"), p_resume_signal.get_name(), callable, CONNECT_ONE_SHOT);
1484+
} else {
1485+
p_resume_signal.connect(callable, Object::CONNECT_ONE_SHOT);
1486+
}
1487+
1488+
semaphore.wait();
1489+
1490+
return return_value;
1491+
}
1492+
14081493
void Thread::_bind_methods() {
14091494
ClassDB::bind_method(D_METHOD("start", "callable", "priority"), &Thread::start, DEFVAL(PRIORITY_NORMAL));
14101495
ClassDB::bind_method(D_METHOD("get_id"), &Thread::get_id);
@@ -1413,6 +1498,8 @@ void Thread::_bind_methods() {
14131498
ClassDB::bind_method(D_METHOD("wait_to_finish"), &Thread::wait_to_finish);
14141499

14151500
ClassDB::bind_static_method("Thread", D_METHOD("set_thread_safety_checks_enabled", "enabled"), &Thread::set_thread_safety_checks_enabled);
1501+
ClassDB::bind_static_method("Thread", D_METHOD("is_main_thread"), &Thread::is_main_thread);
1502+
ClassDB::bind_static_method("Thread", D_METHOD("suspend", "resume_signal", "connect_on_main_thread"), &Thread::suspend, DEFVAL(true));
14161503

14171504
BIND_ENUM_CONSTANT(PRIORITY_LOW);
14181505
BIND_ENUM_CONSTANT(PRIORITY_NORMAL);

core/core_bind.h

+2
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ class Thread : public RefCounted {
453453
bool is_started() const;
454454
bool is_alive() const;
455455
Variant wait_to_finish();
456+
static bool is_main_thread();
457+
static Variant suspend(Signal p_resume_signal, bool p_connect_on_main_thread = true);
456458

457459
static void set_thread_safety_checks_enabled(bool p_enabled);
458460
};

core/variant/variant_utility.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "core/io/marshalls.h"
3434
#include "core/object/ref_counted.h"
3535
#include "core/os/os.h"
36+
#include "core/os/semaphore.h"
3637
#include "core/templates/oa_hash_map.h"
3738
#include "core/templates/rid.h"
3839
#include "core/templates/rid_owner.h"

doc/classes/@GlobalScope.xml

+7
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,13 @@
13901390
[/codeblock]
13911391
</description>
13921392
</method>
1393+
<method name="thread_suspend">
1394+
<return type="Variant" />
1395+
<param index="0" name="on_signal" type="Signal" />
1396+
<description>
1397+
Suspend the caller thread until the signal passed as argument is emitted.
1398+
</description>
1399+
</method>
13931400
<method name="type_convert">
13941401
<return type="Variant" />
13951402
<param index="0" name="variant" type="Variant" />

doc/classes/Thread.xml

+6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
To check if a [Thread] is joinable, use [method is_started].
3131
</description>
3232
</method>
33+
<method name="is_main_thread" qualifiers="static">
34+
<return type="bool" />
35+
<description>
36+
Returns [code]true[/code] if the function is called from the main thread of the engine. This is the thread that runs game code by default.
37+
</description>
38+
</method>
3339
<method name="is_started" qualifiers="const">
3440
<return type="bool" />
3541
<description>

modules/gdscript/gdscript_function.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
#include "gdscript_function.h"
3232

33+
#include "core/config/project_settings.h"
3334
#include "gdscript.h"
3435

3536
Variant GDScriptFunction::get_constant(int p_idx) const {
@@ -187,7 +188,14 @@ bool GDScriptFunctionState::is_valid(bool p_extended_check) const {
187188
return true;
188189
}
189190

191+
bool GDScriptFunctionState::warn_on_cross_thread_await = true;
192+
190193
Variant GDScriptFunctionState::resume(const Variant &p_arg) {
194+
#ifdef DEBUG_ENABLED
195+
if (warn_on_cross_thread_await && await_id != Thread::get_caller_id()) {
196+
WARN_PRINT("Function " + function->get_name() + " was resumed on a different thread than it was awaited. If you really intend to do this, Use Thread.suspend() for cross thread suspension or disable the 'warn_on_cross_thread_await' warning on project settings.");
197+
}
198+
#endif
191199
ERR_FAIL_NULL_V(function, Variant());
192200
{
193201
MutexLock lock(GDScriptLanguage::singleton->mutex);
@@ -275,6 +283,8 @@ void GDScriptFunctionState::_bind_methods() {
275283
ClassDB::bind_vararg_method(METHOD_FLAGS_DEFAULT, "_signal_callback", &GDScriptFunctionState::_signal_callback, MethodInfo("_signal_callback"));
276284

277285
ADD_SIGNAL(MethodInfo("completed", PropertyInfo(Variant::NIL, "result", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NIL_IS_VARIANT)));
286+
287+
warn_on_cross_thread_await = GLOBAL_DEF("threads/warnings/warn_on_cross_thread_await", true);
278288
}
279289

280290
GDScriptFunctionState::GDScriptFunctionState() :

modules/gdscript/gdscript_function.h

+3
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,9 @@ class GDScriptFunctionState : public RefCounted {
608608
SelfList<GDScriptFunctionState> scripts_list;
609609
SelfList<GDScriptFunctionState> instances_list;
610610

611+
static bool warn_on_cross_thread_await;
612+
Thread::ID await_id = Thread::UNASSIGNED_ID;
613+
611614
protected:
612615
static void _bind_methods();
613616

modules/gdscript/gdscript_vm.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -2542,6 +2542,8 @@ Variant GDScriptFunction::call(GDScriptInstance *p_instance, const Variant **p_a
25422542
gdfs->state.ip = ip + 2;
25432543
gdfs->state.line = line;
25442544
gdfs->state.script = _script;
2545+
gdfs->await_id = Thread::get_caller_id();
2546+
25452547
{
25462548
MutexLock lock(GDScriptLanguage::get_singleton()->mutex);
25472549
_script->pending_func_states.add(&gdfs->scripts_list);

0 commit comments

Comments
 (0)