Skip to content

Commit b6844d3

Browse files
committed
Run mypy on pytests
1 parent ae44f85 commit b6844d3

8 files changed

Lines changed: 91 additions & 56 deletions

File tree

.github/workflows/ci.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,6 @@ jobs:
767767
- uses: dtolnay/rust-toolchain@stable
768768
with:
769769
targets: ${{ matrix.platform.rust-target }}
770-
components: rust-src
771770
- uses: actions/setup-python@v6
772771
with:
773772
python-version: "3.14"
@@ -784,6 +783,18 @@ jobs:
784783
needs: [fmt]
785784
if: ${{ !contains(github.event.pull_request.labels.*.name, 'CI-build-full') && github.event_name == 'pull_request' }}
786785
runs-on: ubuntu-latest
786+
steps:
787+
- uses: actions/checkout@v6.0.2
788+
- uses: dtolnay/rust-toolchain@stable
789+
- uses: actions/setup-python@v6
790+
with:
791+
python-version: "3.14"
792+
- run: python -m pip install --upgrade pip && pip install nox[uv]
793+
- run: nox -s test-introspection
794+
795+
mypy-pytests:
796+
needs: [fmt]
797+
runs-on: ubuntu-latest
787798
steps:
788799
- uses: actions/checkout@v6.0.2
789800
- uses: dtolnay/rust-toolchain@stable
@@ -793,7 +804,8 @@ jobs:
793804
with:
794805
python-version: "3.14"
795806
- run: python -m pip install --upgrade pip && pip install nox[uv]
796-
- run: nox -s test-introspection
807+
- run: nox -s mypy
808+
working-directory: pytests
797809

798810
conclusion:
799811
needs:

pytests/noxfile.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import shutil
2+
from pathlib import Path
3+
14
import nox
25
import sys
36
from nox.command import CommandFailed
@@ -34,3 +37,29 @@ def try_install_binary(package: str, constraint: str):
3437
def bench(session: nox.Session):
3538
session.install(".[dev]")
3639
session.run("pytest", "--benchmark-enable", "--benchmark-only", *session.posargs)
40+
41+
42+
@nox.session
43+
def mypy(session: nox.Session):
44+
session.env["MATURIN_PEP517_ARGS"] = "--profile=dev"
45+
try:
46+
# We move the stubs where maturin is expecting them to be
47+
shutil.copytree("stubs", "pyo3_pytests")
48+
(Path("pyo3_pytests") / "py.typed").touch()
49+
session.install(".[dev]")
50+
51+
# TODO: remove --disable-error-code", "override" when __eq__ and __ne__ will always take object for input
52+
# TODO: remove "--disable-error-code", "misc" when #[classattr] will be properly emitted
53+
session.run_always(
54+
"python",
55+
"-m",
56+
"mypy",
57+
"tests",
58+
"--disable-error-code",
59+
"override",
60+
"--disable-error-code",
61+
"misc",
62+
)
63+
# TODO: enable stubtest when previously listed errors will be fixed session.run_always("python", "-m", "mypy.stubtest", "pyo3_pytests")
64+
finally:
65+
shutil.rmtree("pyo3_pytests")

pytests/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ classifiers = [
2121
[project.optional-dependencies]
2222
dev = [
2323
"hypothesis>=3.55",
24+
"mypy~=1.0",
2425
"pytest-asyncio>=0.21,<2",
2526
"pytest-benchmark>=3.4",
2627
"pytest>=7",

pytests/src/datetime.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#![cfg(not(Py_LIMITED_API))]
2-
31
use pyo3::prelude::*;
42
use pyo3::types::{
53
PyDate, PyDateAccess, PyDateTime, PyDelta, PyDeltaAccess, PyTime, PyTimeAccess, PyTuple,
@@ -190,15 +188,19 @@ impl TzClass {
190188
TzClass {}
191189
}
192190

193-
fn utcoffset<'py>(&self, dt: &Bound<'py, PyDateTime>) -> PyResult<Bound<'py, PyDelta>> {
194-
PyDelta::new(dt.py(), 0, 3600, 0, true)
191+
fn utcoffset<'py>(
192+
&self,
193+
_dt: Option<&Bound<'_, PyDateTime>>,
194+
py: Python<'py>,
195+
) -> PyResult<Bound<'py, PyDelta>> {
196+
PyDelta::new(py, 0, 3600, 0, true)
195197
}
196198

197-
fn tzname(&self, _dt: &Bound<'_, PyDateTime>) -> String {
199+
fn tzname(&self, _dt: Option<&Bound<'_, PyDateTime>>) -> String {
198200
String::from("+01:00")
199201
}
200202

201-
fn dst<'py>(&self, _dt: &Bound<'py, PyDateTime>) -> Option<Bound<'py, PyDelta>> {
203+
fn dst(&self, _dt: Option<&Bound<'_, PyDateTime>>) -> Option<Bound<'static, PyDelta>> {
202204
None
203205
}
204206
}

pytests/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod awaitable;
55
mod buf_and_str;
66
mod comparisons;
77
mod consts;
8+
#[cfg(any(not(Py_LIMITED_API), Py_3_12))]
89
mod datetime;
910
mod dict_iter;
1011
mod enums;
@@ -26,7 +27,7 @@ mod pyo3_pytests {
2627
#[pymodule_export]
2728
use buf_and_str::buf_and_str;
2829

29-
#[cfg(not(Py_LIMITED_API))]
30+
#[cfg(any(not(Py_LIMITED_API), Py_3_12))]
3031
#[pymodule_export]
3132
use datetime::datetime;
3233

pytests/stubs/datetime.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ from typing import final
44
@final
55
class TzClass(tzinfo):
66
def __new__(cls, /) -> TzClass: ...
7-
def dst(self, /, _dt: datetime) -> timedelta | None: ...
8-
def tzname(self, /, _dt: datetime) -> str: ...
9-
def utcoffset(self, /, dt: datetime) -> timedelta: ...
7+
def dst(self, /, _dt: datetime | None) -> timedelta | None: ...
8+
def tzname(self, /, _dt: datetime | None) -> str: ...
9+
def utcoffset(self, /, _dt: datetime | None) -> timedelta: ...
1010

1111
def date_from_timestamp(timestamp: int) -> date: ...
1212
def datetime_from_timestamp(ts: float, tz: tzinfo | None = None) -> datetime: ...

pytests/tests/test_comparisons.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Type, Union
1+
from typing import Any
22

33
import sys
44
import pytest
@@ -24,7 +24,7 @@ def __eq__(self, other: object) -> bool:
2424
else:
2525
return NotImplemented
2626

27-
def __ne__(self, other: Self) -> bool:
27+
def __ne__(self, other: object) -> bool:
2828
if isinstance(other, self.__class__):
2929
return self.x != other.x
3030
else:
@@ -33,13 +33,13 @@ def __ne__(self, other: Self) -> bool:
3333

3434
@pytest.mark.skipif(
3535
sys.implementation.name == "graalpy"
36-
and __graalpython__.get_graalvm_version().startswith("24.1"), # noqa: F821
36+
and __graalpython__.get_graalvm_version().startswith("24.1"), # type: ignore[name-defined] # noqa: F821
3737
reason="Bug in GraalPy 24.1",
3838
)
3939
@pytest.mark.parametrize(
4040
"ty", (Eq, EqDerived, PyEq), ids=("rust", "rust-derived", "python")
4141
)
42-
def test_eq(ty: Type[Union[Eq, EqDerived, PyEq]]):
42+
def test_eq(ty: Any):
4343
a = ty(0)
4444
b = ty(0)
4545
c = ty(1)
@@ -78,12 +78,12 @@ class PyEqDefaultNe:
7878
def __init__(self, x: int) -> None:
7979
self.x = x
8080

81-
def __eq__(self, other: Self) -> bool:
82-
return self.x == other.x
81+
def __eq__(self, other: object) -> bool:
82+
return isinstance(other, self.__class__) and self.x == other.x
8383

8484

8585
@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python"))
86-
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
86+
def test_eq_default_ne(ty: Any):
8787
a = ty(0)
8888
b = ty(0)
8989
c = ty(1)
@@ -121,10 +121,14 @@ def __lt__(self, other: Self) -> bool:
121121
def __le__(self, other: Self) -> bool:
122122
return self.x <= other.x
123123

124-
def __eq__(self, other: Self) -> bool:
124+
def __eq__(self, other: object) -> bool:
125+
if not isinstance(other, self.__class__):
126+
return NotImplemented
125127
return self.x == other.x
126128

127-
def __ne__(self, other: Self) -> bool:
129+
def __ne__(self, other: object) -> bool:
130+
if not isinstance(other, self.__class__):
131+
return NotImplemented
128132
return self.x != other.x
129133

130134
def __gt__(self, other: Self) -> bool:
@@ -139,7 +143,7 @@ def __ge__(self, other: Self) -> bool:
139143
(Ordered, OrderedDerived, OrderedRichCmp, PyOrdered),
140144
ids=("rust", "rust-derived", "rust-richcmp", "python"),
141145
)
142-
def test_ordered(ty: Type[Union[Ordered, OrderedDerived, OrderedRichCmp, PyOrdered]]):
146+
def test_ordered(ty: Any):
143147
a = ty(0)
144148
b = ty(0)
145149
c = ty(1)
@@ -174,7 +178,9 @@ def __lt__(self, other: Self) -> bool:
174178
def __le__(self, other: Self) -> bool:
175179
return self.x <= other.x
176180

177-
def __eq__(self, other: Self) -> bool:
181+
def __eq__(self, other: object) -> bool:
182+
if not isinstance(other, self.__class__):
183+
return NotImplemented
178184
return self.x == other.x
179185

180186
def __gt__(self, other: Self) -> bool:
@@ -187,7 +193,7 @@ def __ge__(self, other: Self) -> bool:
187193
@pytest.mark.parametrize(
188194
"ty", (OrderedDefaultNe, PyOrderedDefaultNe), ids=("rust", "python")
189195
)
190-
def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]]):
196+
def test_ordered_default_ne(ty: Any):
191197
a = ty(0)
192198
b = ty(0)
193199
c = ty(1)

pytests/tests/test_enums.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -73,28 +73,20 @@ def test_complex_enum_field_getters():
7373
)
7474
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
7575
if isinstance(variant, enums.ComplexEnum.Int):
76-
x = variant.i
77-
assert x == 42
76+
assert variant.i == 42
7877
elif isinstance(variant, enums.ComplexEnum.Float):
79-
x = variant.f
80-
assert x == 3.14
78+
assert variant.f == 3.14
8179
elif isinstance(variant, enums.ComplexEnum.Str):
82-
x = variant.s
83-
assert x == "hello"
80+
assert variant.s == "hello"
8481
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
8582
assert True
8683
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
87-
x = variant.a
88-
y = variant.b
89-
z = variant.c
90-
assert x == 42
91-
assert y == 3.14
92-
assert z is True
84+
assert variant.a == 42
85+
assert variant.b == 3.14
86+
assert variant.c is True
9387
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
94-
x = variant.a
95-
y = variant.b
96-
assert x == 42
97-
assert y is None
88+
assert variant.a == 42
89+
assert variant.b is None
9890
else:
9991
assert False
10092

@@ -113,28 +105,20 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
113105
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
114106
variant = enums.do_complex_stuff(variant)
115107
if isinstance(variant, enums.ComplexEnum.Int):
116-
x = variant.i
117-
assert x == 5
108+
assert variant.i == 5
118109
elif isinstance(variant, enums.ComplexEnum.Float):
119-
x = variant.f
120-
assert x == 9.8596
110+
assert variant.f == 9.8596
121111
elif isinstance(variant, enums.ComplexEnum.Str):
122-
x = variant.s
123-
assert x == "42"
112+
assert variant.s == "42"
124113
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
125114
assert True
126115
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
127-
x = variant.a
128-
y = variant.b
129-
z = variant.c
130-
assert x == 42
131-
assert y == 3.14
132-
assert z is True
116+
assert variant.a == 42
117+
assert variant.b == 3.14
118+
assert variant.c is True
133119
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
134-
x = variant.a
135-
y = variant.b
136-
assert x == 84
137-
assert y == "HELLO"
120+
assert variant.a == 84
121+
assert variant.b == "HELLO"
138122
else:
139123
assert False
140124

0 commit comments

Comments
 (0)