Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion include/pybind11/detail/holder_caster_foreign_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,31 @@ struct holder_caster_foreign_helpers {
PyObject *o;
};

// Downcast shared_ptr from the enable_shared_from_this base to the target type.
// SFINAE probe: use static_pointer_cast when the static downcast is valid (common case),
// fall back to dynamic_pointer_cast when it isn't (virtual inheritance — issue #5989).
// We can't use dynamic_pointer_cast unconditionally because it requires polymorphic types;
// we can't use is_polymorphic to choose because that's orthogonal to virtual inheritance.
// (The implementation uses the "tag dispatch via overload priority" trick.)
template <typename type, typename esft_base>
static auto esft_downcast(const std::shared_ptr<esft_base> &existing, int /*preferred*/)
-> decltype(static_cast<type *>(std::declval<esft_base *>()), std::shared_ptr<type>()) {
return std::static_pointer_cast<type>(existing);
}

template <typename type, typename esft_base>
static std::shared_ptr<type> esft_downcast(const std::shared_ptr<esft_base> &existing,
... /*fallback*/) {
return std::dynamic_pointer_cast<type>(existing);
}

template <typename type>
static auto set_via_shared_from_this(type *value, std::shared_ptr<type> *holder_out)
-> decltype(value->shared_from_this(), bool()) {
// object derives from enable_shared_from_this;
// try to reuse an existing shared_ptr if one is known
if (auto existing = try_get_shared_from_this(value)) {
*holder_out = std::static_pointer_cast<type>(existing);
*holder_out = esft_downcast<type>(existing, 0);
return true;
}
return false;
Expand Down
37 changes: 37 additions & 0 deletions tests/test_smart_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,28 @@ struct SharedFromThisVBase : std::enable_shared_from_this<SharedFromThisVBase> {
};
struct SharedFromThisVirt : virtual SharedFromThisVBase {};

// Issue #5989: static_pointer_cast where dynamic_pointer_cast is needed
// (virtual inheritance with shared_ptr holder)
struct SftVirtBase : std::enable_shared_from_this<SftVirtBase> {
SftVirtBase() = default;
virtual ~SftVirtBase() = default;
static std::shared_ptr<SftVirtBase> create() { return std::make_shared<SftVirtBase>(); }
virtual std::string name() { return "SftVirtBase"; }
};
struct SftVirtDerived : SftVirtBase {
using SftVirtBase::SftVirtBase;
static std::shared_ptr<SftVirtDerived> create() { return std::make_shared<SftVirtDerived>(); }
std::string name() override { return "SftVirtDerived"; }
};
struct SftVirtDerived2 : virtual SftVirtDerived {
using SftVirtDerived::SftVirtDerived;
static std::shared_ptr<SftVirtDerived2> create() {
return std::make_shared<SftVirtDerived2>();
}
std::string name() override { return "SftVirtDerived2"; }
std::string call_name(const std::shared_ptr<SftVirtDerived2> &d2) { return d2->name(); }
};

// test_move_only_holder
struct C {
C() { print_created(this); }
Expand Down Expand Up @@ -522,6 +544,21 @@ TEST_SUBMODULE(smart_ptr, m) {
py::class_<SharedFromThisVirt, std::shared_ptr<SharedFromThisVirt>>(m, "SharedFromThisVirt")
.def_static("get", []() { return sft.get(); });

// Issue #5989: static_pointer_cast where dynamic_pointer_cast is needed
py::class_<SftVirtBase, std::shared_ptr<SftVirtBase>>(m, "SftVirtBase")
.def(py::init<>(&SftVirtBase::create))
.def("name", &SftVirtBase::name);
py::class_<SftVirtDerived, SftVirtBase, std::shared_ptr<SftVirtDerived>>(m, "SftVirtDerived")
.def(py::init<>(&SftVirtDerived::create));
py::class_<SftVirtDerived2, SftVirtDerived, std::shared_ptr<SftVirtDerived2>>(
m, "SftVirtDerived2")
.def(py::init<>(&SftVirtDerived2::create))
// TODO: Remove this once inherited methods work through virtual bases.
// Without it, d2.name() segfaults because pybind11 uses an incorrect
// pointer offset when dispatching through the virtual inheritance chain.
.def("name", &SftVirtDerived2::name)
.def("call_name", &SftVirtDerived2::call_name, py::arg("d2"));

// test_move_only_holder
py::class_<C, custom_unique_ptr<C>>(m, "TypeWithMoveOnlyHolder")
.def_static("make", []() { return custom_unique_ptr<C>(new C); })
Expand Down
13 changes: 13 additions & 0 deletions tests/test_smart_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,19 @@ def test_shared_ptr_from_this_and_references():
assert y is z


def test_shared_from_this_virt_shared_ptr_arg():
"""Issue #5989: static_pointer_cast fails with virtual inheritance."""
b = m.SftVirtBase()
assert b.name() == "SftVirtBase"

d = m.SftVirtDerived()
assert d.name() == "SftVirtDerived"

d2 = m.SftVirtDerived2()
assert d2.name() == "SftVirtDerived2"
assert d2.call_name(d2) == "SftVirtDerived2"


@pytest.mark.skipif("env.GRAALPY", reason="Cannot reliably trigger GC")
def test_move_only_holder():
a = m.TypeWithMoveOnlyHolder.make()
Expand Down
Loading