diff --git a/include/pybind11/detail/holder_caster_foreign_helpers.h b/include/pybind11/detail/holder_caster_foreign_helpers.h index f636618e9f..cae571b65c 100644 --- a/include/pybind11/detail/holder_caster_foreign_helpers.h +++ b/include/pybind11/detail/holder_caster_foreign_helpers.h @@ -31,13 +31,31 @@ struct holder_caster_foreign_helpers { PyObject *o; }; + // Downcast shared_ptr from the enable_shared_from_this base to the target type. + // SFINAE probe: use static_pointer_cast when the static downcast is valid (common case), + // fall back to dynamic_pointer_cast when it isn't (virtual inheritance — issue #5989). + // We can't use dynamic_pointer_cast unconditionally because it requires polymorphic types; + // we can't use is_polymorphic to choose because that's orthogonal to virtual inheritance. + // (The implementation uses the "tag dispatch via overload priority" trick.) + template + static auto esft_downcast(const std::shared_ptr &existing, int /*preferred*/) + -> decltype(static_cast(std::declval()), std::shared_ptr()) { + return std::static_pointer_cast(existing); + } + + template + static std::shared_ptr esft_downcast(const std::shared_ptr &existing, + ... /*fallback*/) { + return std::dynamic_pointer_cast(existing); + } + template static auto set_via_shared_from_this(type *value, std::shared_ptr *holder_out) -> decltype(value->shared_from_this(), bool()) { // object derives from enable_shared_from_this; // try to reuse an existing shared_ptr if one is known if (auto existing = try_get_shared_from_this(value)) { - *holder_out = std::static_pointer_cast(existing); + *holder_out = esft_downcast(existing, 0); return true; } return false; diff --git a/tests/test_smart_ptr.cpp b/tests/test_smart_ptr.cpp index 0ac1a41bd9..f5036eac38 100644 --- a/tests/test_smart_ptr.cpp +++ b/tests/test_smart_ptr.cpp @@ -247,6 +247,28 @@ struct SharedFromThisVBase : std::enable_shared_from_this { }; struct SharedFromThisVirt : virtual SharedFromThisVBase {}; +// Issue #5989: static_pointer_cast where dynamic_pointer_cast is needed +// (virtual inheritance with shared_ptr holder) +struct SftVirtBase : std::enable_shared_from_this { + SftVirtBase() = default; + virtual ~SftVirtBase() = default; + static std::shared_ptr create() { return std::make_shared(); } + virtual std::string name() { return "SftVirtBase"; } +}; +struct SftVirtDerived : SftVirtBase { + using SftVirtBase::SftVirtBase; + static std::shared_ptr create() { return std::make_shared(); } + std::string name() override { return "SftVirtDerived"; } +}; +struct SftVirtDerived2 : virtual SftVirtDerived { + using SftVirtDerived::SftVirtDerived; + static std::shared_ptr create() { + return std::make_shared(); + } + std::string name() override { return "SftVirtDerived2"; } + std::string call_name(const std::shared_ptr &d2) { return d2->name(); } +}; + // test_move_only_holder struct C { C() { print_created(this); } @@ -522,6 +544,21 @@ TEST_SUBMODULE(smart_ptr, m) { py::class_>(m, "SharedFromThisVirt") .def_static("get", []() { return sft.get(); }); + // Issue #5989: static_pointer_cast where dynamic_pointer_cast is needed + py::class_>(m, "SftVirtBase") + .def(py::init<>(&SftVirtBase::create)) + .def("name", &SftVirtBase::name); + py::class_>(m, "SftVirtDerived") + .def(py::init<>(&SftVirtDerived::create)); + py::class_>( + m, "SftVirtDerived2") + .def(py::init<>(&SftVirtDerived2::create)) + // TODO: Remove this once inherited methods work through virtual bases. + // Without it, d2.name() segfaults because pybind11 uses an incorrect + // pointer offset when dispatching through the virtual inheritance chain. + .def("name", &SftVirtDerived2::name) + .def("call_name", &SftVirtDerived2::call_name, py::arg("d2")); + // test_move_only_holder py::class_>(m, "TypeWithMoveOnlyHolder") .def_static("make", []() { return custom_unique_ptr(new C); }) diff --git a/tests/test_smart_ptr.py b/tests/test_smart_ptr.py index 2d48aac78d..76ebd8cf20 100644 --- a/tests/test_smart_ptr.py +++ b/tests/test_smart_ptr.py @@ -251,6 +251,19 @@ def test_shared_ptr_from_this_and_references(): assert y is z +def test_shared_from_this_virt_shared_ptr_arg(): + """Issue #5989: static_pointer_cast fails with virtual inheritance.""" + b = m.SftVirtBase() + assert b.name() == "SftVirtBase" + + d = m.SftVirtDerived() + assert d.name() == "SftVirtDerived" + + d2 = m.SftVirtDerived2() + assert d2.name() == "SftVirtDerived2" + assert d2.call_name(d2) == "SftVirtDerived2" + + @pytest.mark.skipif("env.GRAALPY", reason="Cannot reliably trigger GC") def test_move_only_holder(): a = m.TypeWithMoveOnlyHolder.make()