diff --git a/core/object/class_db.cpp b/core/object/class_db.cpp index 1a3fccd6b76..5a32288d97b 100644 --- a/core/object/class_db.cpp +++ b/core/object/class_db.cpp @@ -35,9 +35,6 @@ #include "core/object/script_language.h" #include "core/version.h" -#define OBJTYPE_RLOCK RWLockRead _rw_lockr_(lock); -#define OBJTYPE_WLOCK RWLockWrite _rw_lockw_(lock); - #ifdef DEBUG_METHODS_ENABLED MethodDefinition D_METHODP(const char *p_name, const char *const **p_args, uint32_t p_argcount) { @@ -238,13 +235,13 @@ bool ClassDB::_is_parent_class(const StringName &p_class, const StringName &p_in } bool ClassDB::is_parent_class(const StringName &p_class, const StringName &p_inherits) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); return _is_parent_class(p_class, p_inherits); } void ClassDB::get_class_list(List *p_classes) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); for (const KeyValue &E : classes) { p_classes->push_back(E.key); @@ -255,7 +252,7 @@ void ClassDB::get_class_list(List *p_classes) { #ifdef TOOLS_ENABLED void ClassDB::get_extensions_class_list(List *p_classes) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); for (const KeyValue &E : classes) { if (E.value.api != API_EXTENSION && E.value.api != API_EDITOR_EXTENSION) { @@ -268,7 +265,7 @@ void ClassDB::get_extensions_class_list(List *p_classes) { } void ClassDB::get_extension_class_list(const Ref &p_extension, List *p_classes) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); for (const KeyValue &E : classes) { if (E.value.api != API_EXTENSION && E.value.api != API_EDITOR_EXTENSION) { @@ -285,7 +282,7 @@ void ClassDB::get_extension_class_list(const Ref &p_extension, List #endif void ClassDB::get_inheriters_from_class(const StringName &p_class, List *p_classes) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); for (const KeyValue &E : classes) { if (E.key != p_class && _is_parent_class(E.key, p_class)) { @@ -295,7 +292,7 @@ void ClassDB::get_inheriters_from_class(const StringName &p_class, List *p_classes) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); for (const KeyValue &E : classes) { if (E.value.inherits == p_class) { @@ -305,7 +302,7 @@ void ClassDB::get_direct_inheriters_from_class(const StringName &p_class, List &r_result) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *start = classes.getptr(p_class); if (!start) { @@ -356,13 +353,13 @@ StringName ClassDB::_get_parent_class(const StringName &p_class) { } StringName ClassDB::get_parent_class(const StringName &p_class) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); return _get_parent_class(p_class); } ClassDB::APIType ClassDB::get_api_type(const StringName &p_class) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); @@ -372,7 +369,7 @@ ClassDB::APIType ClassDB::get_api_type(const StringName &p_class) { uint32_t ClassDB::get_api_hash(APIType p_api) { #ifdef DEBUG_METHODS_ENABLED - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); if (api_hashes_cache.has(p_api)) { return api_hashes_cache[p_api]; @@ -520,12 +517,12 @@ uint32_t ClassDB::get_api_hash(APIType p_api) { } bool ClassDB::class_exists(const StringName &p_class) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); return classes.has(p_class); } void ClassDB::add_compatibility_class(const StringName &p_class, const StringName &p_fallback) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); compat_classes[p_class] = p_fallback; } @@ -539,7 +536,7 @@ StringName ClassDB::get_compatibility_class(const StringName &p_class) { Object *ClassDB::_instantiate_internal(const StringName &p_class, bool p_require_real_class, bool p_notify_postinitialize) { ClassInfo *ti; { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ti = classes.getptr(p_class); if (!_can_instantiate(ti)) { if (compat_classes.has(p_class)) { @@ -645,7 +642,7 @@ ObjectGDExtension *ClassDB::get_placeholder_extension(const StringName &p_class) ClassInfo *ti; { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ti = classes.getptr(p_class); if (!_can_instantiate(ti)) { if (compat_classes.has(p_class)) { @@ -730,7 +727,7 @@ void ClassDB::set_object_extension_instance(Object *p_object, const StringName & ERR_FAIL_NULL(p_object); ClassInfo *ti; { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ti = classes.getptr(p_class); if (!_can_instantiate(ti)) { if (compat_classes.has(p_class)) { @@ -755,7 +752,7 @@ void ClassDB::set_object_extension_instance(Object *p_object, const StringName & bool ClassDB::can_instantiate(const StringName &p_class) { String script_path; { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); if (!ti) { @@ -781,7 +778,7 @@ use_script: bool ClassDB::is_abstract(const StringName &p_class) { String script_path; { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); if (!ti) { @@ -813,7 +810,7 @@ use_script: bool ClassDB::is_virtual(const StringName &p_class) { String script_path; { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); if (!ti) { @@ -837,7 +834,7 @@ use_script: } void ClassDB::_add_class2(const StringName &p_class, const StringName &p_inherits) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); const StringName &name = p_class; @@ -880,7 +877,7 @@ static MethodInfo info_from_bind(MethodBind *p_method) { } void ClassDB::get_method_list(const StringName &p_class, List *p_methods, bool p_no_inheritance, bool p_exclude_from_properties) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -926,7 +923,7 @@ void ClassDB::get_method_list(const StringName &p_class, List *p_met } void ClassDB::get_method_list_with_compatibility(const StringName &p_class, List> *p_methods, bool p_no_inheritance, bool p_exclude_from_properties) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -986,7 +983,7 @@ void ClassDB::get_method_list_with_compatibility(const StringName &p_class, List } bool ClassDB::get_method_info(const StringName &p_class, const StringName &p_method, MethodInfo *r_info, bool p_no_inheritance, bool p_exclude_from_properties) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1036,7 +1033,7 @@ bool ClassDB::get_method_info(const StringName &p_class, const StringName &p_met } MethodBind *ClassDB::get_method(const StringName &p_class, const StringName &p_name) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1051,7 +1048,7 @@ MethodBind *ClassDB::get_method(const StringName &p_class, const StringName &p_n } Vector ClassDB::get_method_compatibility_hashes(const StringName &p_class, const StringName &p_name) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1070,7 +1067,7 @@ Vector ClassDB::get_method_compatibility_hashes(const StringName &p_cl } MethodBind *ClassDB::get_method_with_compatibility(const StringName &p_class, const StringName &p_name, uint64_t p_hash, bool *r_method_exists, bool *r_is_deprecated) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1105,7 +1102,7 @@ MethodBind *ClassDB::get_method_with_compatibility(const StringName &p_class, co } void ClassDB::bind_integer_constant(const StringName &p_class, const StringName &p_enum, const StringName &p_name, int64_t p_constant, bool p_is_bitfield) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); @@ -1142,7 +1139,7 @@ void ClassDB::bind_integer_constant(const StringName &p_class, const StringName } void ClassDB::get_integer_constant_list(const StringName &p_class, List *p_constants, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1167,7 +1164,7 @@ void ClassDB::get_integer_constant_list(const StringName &p_class, List } int64_t ClassDB::get_integer_constant(const StringName &p_class, const StringName &p_name, bool *p_success) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1191,7 +1188,7 @@ int64_t ClassDB::get_integer_constant(const StringName &p_class, const StringNam } bool ClassDB::has_integer_constant(const StringName &p_class, const StringName &p_name, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1210,7 +1207,7 @@ bool ClassDB::has_integer_constant(const StringName &p_class, const StringName & } StringName ClassDB::get_integer_constant_enum(const StringName &p_class, const StringName &p_name, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1234,7 +1231,7 @@ StringName ClassDB::get_integer_constant_enum(const StringName &p_class, const S } void ClassDB::get_enum_list(const StringName &p_class, List *p_enums, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1252,7 +1249,7 @@ void ClassDB::get_enum_list(const StringName &p_class, List *p_enums } void ClassDB::get_enum_constants(const StringName &p_class, const StringName &p_enum, List *p_constants, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1275,7 +1272,7 @@ void ClassDB::get_enum_constants(const StringName &p_class, const StringName &p_ void ClassDB::set_method_error_return_values(const StringName &p_class, const StringName &p_method, const Vector &p_values) { #ifdef DEBUG_METHODS_ENABLED - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL(type); @@ -1286,7 +1283,7 @@ void ClassDB::set_method_error_return_values(const StringName &p_class, const St Vector ClassDB::get_method_error_return_values(const StringName &p_class, const StringName &p_method) { #ifdef DEBUG_METHODS_ENABLED - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL_V(type, Vector()); @@ -1301,7 +1298,7 @@ Vector ClassDB::get_method_error_return_values(const StringName &p_class, } bool ClassDB::has_enum(const StringName &p_class, const StringName &p_name, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1320,7 +1317,7 @@ bool ClassDB::has_enum(const StringName &p_class, const StringName &p_name, bool } bool ClassDB::is_enum_bitfield(const StringName &p_class, const StringName &p_name, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1339,7 +1336,7 @@ bool ClassDB::is_enum_bitfield(const StringName &p_class, const StringName &p_na } void ClassDB::add_signal(const StringName &p_class, const MethodInfo &p_signal) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL(type); @@ -1358,7 +1355,7 @@ void ClassDB::add_signal(const StringName &p_class, const MethodInfo &p_signal) } void ClassDB::get_signal_list(const StringName &p_class, List *p_signals, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL(type); @@ -1379,7 +1376,7 @@ void ClassDB::get_signal_list(const StringName &p_class, List *p_sig } bool ClassDB::has_signal(const StringName &p_class, const StringName &p_signal, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); ClassInfo *check = type; while (check) { @@ -1396,7 +1393,7 @@ bool ClassDB::has_signal(const StringName &p_class, const StringName &p_signal, } bool ClassDB::get_signal(const StringName &p_class, const StringName &p_signal, MethodInfo *r_signal) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); ClassInfo *check = type; while (check) { @@ -1413,7 +1410,7 @@ bool ClassDB::get_signal(const StringName &p_class, const StringName &p_signal, } void ClassDB::add_property_group(const StringName &p_class, const String &p_name, const String &p_prefix, int p_indent_depth) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL(type); @@ -1426,7 +1423,7 @@ void ClassDB::add_property_group(const StringName &p_class, const String &p_name } void ClassDB::add_property_subgroup(const StringName &p_class, const String &p_name, const String &p_prefix, int p_indent_depth) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL(type); @@ -1443,7 +1440,7 @@ void ClassDB::add_property_array_count(const StringName &p_class, const String & } void ClassDB::add_property_array(const StringName &p_class, const StringName &p_path, const String &p_array_element_prefix) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL(type); @@ -1452,9 +1449,9 @@ void ClassDB::add_property_array(const StringName &p_class, const StringName &p_ // NOTE: For implementation simplicity reasons, this method doesn't allow setters to have optional arguments at the end. void ClassDB::add_property(const StringName &p_class, const PropertyInfo &p_pinfo, const StringName &p_setter, const StringName &p_getter, int p_index) { - lock.read_lock(); + Locker::Lock lock(Locker::STATE_WRITE); + ClassInfo *type = classes.getptr(p_class); - lock.read_unlock(); ERR_FAIL_NULL(type); @@ -1486,8 +1483,6 @@ void ClassDB::add_property(const StringName &p_class, const PropertyInfo &p_pinf ERR_FAIL_COND_MSG(type->property_setget.has(p_pinfo.name), vformat("Object '%s' already has property '%s'.", p_class, p_pinfo.name)); #endif - OBJTYPE_WLOCK - type->property_list.push_back(p_pinfo); type->property_map[p_pinfo.name] = p_pinfo; #ifdef DEBUG_METHODS_ENABLED @@ -1518,7 +1513,7 @@ void ClassDB::set_property_default_value(const StringName &p_class, const String void ClassDB::add_linked_property(const StringName &p_class, const String &p_property, const String &p_linked_property) { #ifdef TOOLS_ENABLED - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); ERR_FAIL_NULL(type); @@ -1534,7 +1529,7 @@ void ClassDB::add_linked_property(const StringName &p_class, const String &p_pro } void ClassDB::get_property_list(const StringName &p_class, List *p_list, bool p_no_inheritance, const Object *p_validator) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); ClassInfo *check = type; @@ -1577,7 +1572,7 @@ void ClassDB::get_linked_properties_info(const StringName &p_class, const String } bool ClassDB::get_property_info(const StringName &p_class, const StringName &p_property, PropertyInfo *r_info, bool p_no_inheritance, const Object *p_validator) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *check = classes.getptr(p_class); while (check) { @@ -1800,7 +1795,7 @@ bool ClassDB::has_property(const StringName &p_class, const StringName &p_proper } void ClassDB::set_method_flags(const StringName &p_class, const StringName &p_method, int p_flags) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ClassInfo *type = classes.getptr(p_class); ClassInfo *check = type; ERR_FAIL_NULL(check); @@ -1825,7 +1820,7 @@ bool ClassDB::has_method(const StringName &p_class, const StringName &p_method, } int ClassDB::get_method_argument_count(const StringName &p_class, const StringName &p_method, bool *r_is_valid, bool p_no_inheritance) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -1864,7 +1859,7 @@ void ClassDB::_bind_compatibility(ClassInfo *type, MethodBind *p_method) { } void ClassDB::_bind_method_custom(const StringName &p_class, MethodBind *p_method, bool p_compatibility) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); StringName method_name = p_method->get_name(); @@ -1933,7 +1928,7 @@ MethodBind *ClassDB::bind_methodfi(uint32_t p_flags, MethodBind *p_bind, bool p_ StringName mdname = StaticCString::create(method_name); #endif - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ERR_FAIL_NULL_V(p_bind, nullptr); p_bind->set_name(mdname); @@ -1996,7 +1991,7 @@ MethodBind *ClassDB::bind_methodfi(uint32_t p_flags, MethodBind *p_bind, bool p_ void ClassDB::add_virtual_method(const StringName &p_class, const MethodInfo &p_method, bool p_virtual, const Vector &p_arg_names, bool p_object_core) { ERR_FAIL_COND_MSG(!classes.has(p_class), vformat("Request for nonexistent class '%s'.", p_class)); - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); #ifdef DEBUG_METHODS_ENABLED MethodInfo mi = p_method; @@ -2030,7 +2025,7 @@ void ClassDB::add_virtual_method(const StringName &p_class, const MethodInfo &p_ void ClassDB::add_virtual_compatibility_method(const StringName &p_class, const MethodInfo &p_method, bool p_virtual, const Vector &p_arg_names, bool p_object_core) { ERR_FAIL_COND_MSG(!classes.has(p_class), vformat("Request for nonexistent class '%s'.", p_class)); - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); HashMap> &virtual_methods_compat = classes[p_class].virtual_methods_compat; @@ -2065,7 +2060,7 @@ void ClassDB::get_virtual_methods(const StringName &p_class, List *p } Vector ClassDB::get_virtual_method_compatibility_hashes(const StringName &p_class, const StringName &p_name) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *type = classes.getptr(p_class); @@ -2106,14 +2101,14 @@ void ClassDB::add_extension_class_virtual_method(const StringName &p_class, cons } void ClassDB::set_class_enabled(const StringName &p_class, bool p_enable) { - OBJTYPE_WLOCK; + Locker::Lock lock(Locker::STATE_WRITE); ERR_FAIL_COND_MSG(!classes.has(p_class), vformat("Request for nonexistent class '%s'.", p_class)); classes[p_class].disabled = !p_enable; } bool ClassDB::is_class_enabled(const StringName &p_class) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); if (!ti || !ti->creation_func) { @@ -2127,7 +2122,7 @@ bool ClassDB::is_class_enabled(const StringName &p_class) { } bool ClassDB::is_class_exposed(const StringName &p_class) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); ERR_FAIL_NULL_V_MSG(ti, false, vformat("Cannot get class '%s'.", String(p_class))); @@ -2135,7 +2130,7 @@ bool ClassDB::is_class_exposed(const StringName &p_class) { } bool ClassDB::is_class_reloadable(const StringName &p_class) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); ERR_FAIL_NULL_V_MSG(ti, false, vformat("Cannot get class '%s'.", String(p_class))); @@ -2143,7 +2138,7 @@ bool ClassDB::is_class_reloadable(const StringName &p_class) { } bool ClassDB::is_class_runtime(const StringName &p_class) { - OBJTYPE_RLOCK; + Locker::Lock lock(Locker::STATE_READ); ClassInfo *ti = classes.getptr(p_class); ERR_FAIL_NULL_V_MSG(ti, false, vformat("Cannot get class '%s'.", String(p_class))); @@ -2341,8 +2336,6 @@ uint64_t ClassDB::get_native_struct_size(const StringName &p_name) { return native_structs[p_name].struct_size; } -RWLock ClassDB::lock; - void ClassDB::cleanup_defaults() { default_values.clear(); default_values_cached.clear(); @@ -2378,3 +2371,32 @@ bool ClassDB::is_default_array_arg(const Array &p_array) { } // + +ClassDB::Locker::Lock::Lock(Locker::State p_state) { + DEV_ASSERT(p_state != STATE_UNLOCKED); + if (p_state == STATE_READ) { + if (Locker::thread_state == STATE_UNLOCKED) { + state = STATE_READ; + Locker::thread_state = STATE_READ; + Locker::lock.read_lock(); + } + } else if (p_state == STATE_WRITE) { + if (Locker::thread_state == STATE_UNLOCKED) { + state = STATE_WRITE; + Locker::thread_state = STATE_WRITE; + Locker::lock.write_lock(); + } else if (Locker::thread_state == STATE_READ) { + CRASH_NOW_MSG("Lock can't be upgraded from read to write."); + } + } +} + +ClassDB::Locker::Lock::~Lock() { + if (state == STATE_READ) { + Locker::lock.read_unlock(); + Locker::thread_state = STATE_UNLOCKED; + } else if (state == STATE_WRITE) { + Locker::lock.write_unlock(); + Locker::thread_state = STATE_UNLOCKED; + } +} diff --git a/core/object/class_db.h b/core/object/class_db.h index 27750b99072..e97ed9aeeb1 100644 --- a/core/object/class_db.h +++ b/core/object/class_db.h @@ -152,7 +152,31 @@ public: return ret; } - static RWLock lock; + // We need a recursive r/w lock because there are various code paths + // that may in turn invoke other entry points with require locking. + class Locker { + public: + enum State { + STATE_UNLOCKED, + STATE_READ, + STATE_WRITE, + }; + + private: + inline static RWLock lock; + inline thread_local static State thread_state = STATE_UNLOCKED; + + public: + class Lock { + State state = STATE_UNLOCKED; + + public: + explicit Lock(State p_state); + ~Lock(); + }; + }; + inline static Locker locker; + static HashMap classes; static HashMap resource_base_extensions; static HashMap compat_classes; @@ -206,7 +230,7 @@ public: template static void register_class(bool p_virtual = false) { - GLOBAL_LOCK_FUNCTION; + Locker::Lock lock(Locker::STATE_WRITE); static_assert(std::is_same_v, "Class not declared properly, please use GDCLASS."); T::initialize_class(); ClassInfo *t = classes.getptr(T::get_class_static()); @@ -221,7 +245,7 @@ public: template static void register_abstract_class() { - GLOBAL_LOCK_FUNCTION; + Locker::Lock lock(Locker::STATE_WRITE); static_assert(std::is_same_v, "Class not declared properly, please use GDCLASS."); T::initialize_class(); ClassInfo *t = classes.getptr(T::get_class_static()); @@ -234,7 +258,7 @@ public: template static void register_internal_class() { - GLOBAL_LOCK_FUNCTION; + Locker::Lock lock(Locker::STATE_WRITE); static_assert(std::is_same_v, "Class not declared properly, please use GDCLASS."); T::initialize_class(); ClassInfo *t = classes.getptr(T::get_class_static()); @@ -249,7 +273,7 @@ public: template static void register_runtime_class() { - GLOBAL_LOCK_FUNCTION; + Locker::Lock lock(Locker::STATE_WRITE); static_assert(std::is_same_v, "Class not declared properly, please use GDCLASS."); T::initialize_class(); ClassInfo *t = classes.getptr(T::get_class_static()); @@ -274,7 +298,7 @@ public: template static void register_custom_instance_class() { - GLOBAL_LOCK_FUNCTION; + Locker::Lock lock(Locker::STATE_WRITE); static_assert(std::is_same_v, "Class not declared properly, please use GDCLASS."); T::initialize_class(); ClassInfo *t = classes.getptr(T::get_class_static()); @@ -390,7 +414,7 @@ public: template static MethodBind *bind_vararg_method(uint32_t p_flags, const StringName &p_name, M p_method, const MethodInfo &p_info = MethodInfo(), const Vector &p_default_args = Vector(), bool p_return_nil_is_variant = true) { - GLOBAL_LOCK_FUNCTION; + Locker::Lock lock(Locker::STATE_WRITE); MethodBind *bind = create_vararg_method_bind(p_method, p_info, p_return_nil_is_variant); ERR_FAIL_NULL_V(bind, nullptr); @@ -403,7 +427,7 @@ public: template static MethodBind *bind_compatibility_vararg_method(uint32_t p_flags, const StringName &p_name, M p_method, const MethodInfo &p_info = MethodInfo(), const Vector &p_default_args = Vector(), bool p_return_nil_is_variant = true) { - GLOBAL_LOCK_FUNCTION; + Locker::Lock lock(Locker::STATE_WRITE); MethodBind *bind = create_vararg_method_bind(p_method, p_info, p_return_nil_is_variant); ERR_FAIL_NULL_V(bind, nullptr);