Skip to content

Commit 143ff27

Browse files
committed
[PYTHON] Use isinstance() instead of type() == for type checks in pandas/types.py
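Replace all `type(dt) == SomeType` comparisons with `isinstance(dt, SomeType)` in `to_arrow_type`, `_check_arrow_array_timestamps_localize`, and `_to_corrected_pandas_type`. Unlike the exact `type(dt) ==` comparison, `isinstance` also matches subclasses of the given type, so a subclass of a Spark data type now takes the same branch as its parent type.

Co-authored-by: Isaac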
1 parent dd492dd commit 143ff27

1 file changed

Lines changed: 33 additions & 33 deletions

File tree

python/pyspark/sql/pandas/types.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -116,38 +116,38 @@ def to_arrow_type(
116116
"""
117117
import pyarrow as pa
118118

119-
if type(dt) == BooleanType:
119+
if isinstance(dt, BooleanType):
120120
arrow_type = pa.bool_()
121-
elif type(dt) == ByteType:
121+
elif isinstance(dt, ByteType):
122122
arrow_type = pa.int8()
123-
elif type(dt) == ShortType:
123+
elif isinstance(dt, ShortType):
124124
arrow_type = pa.int16()
125-
elif type(dt) == IntegerType:
125+
elif isinstance(dt, IntegerType):
126126
arrow_type = pa.int32()
127-
elif type(dt) == LongType:
127+
elif isinstance(dt, LongType):
128128
arrow_type = pa.int64()
129-
elif type(dt) == FloatType:
129+
elif isinstance(dt, FloatType):
130130
arrow_type = pa.float32()
131-
elif type(dt) == DoubleType:
131+
elif isinstance(dt, DoubleType):
132132
arrow_type = pa.float64()
133-
elif type(dt) == DecimalType:
133+
elif isinstance(dt, DecimalType):
134134
arrow_type = pa.decimal128(dt.precision, dt.scale)
135-
elif type(dt) == StringType:
135+
elif isinstance(dt, StringType):
136136
arrow_type = pa.large_string() if prefers_large_types else pa.string()
137-
elif type(dt) == BinaryType:
137+
elif isinstance(dt, BinaryType):
138138
arrow_type = pa.large_binary() if prefers_large_types else pa.binary()
139-
elif type(dt) == DateType:
139+
elif isinstance(dt, DateType):
140140
arrow_type = pa.date32()
141-
elif type(dt) == TimestampType:
141+
elif isinstance(dt, TimestampType):
142142
assert timezone is not None
143143
arrow_type = pa.timestamp("us", tz=timezone)
144-
elif type(dt) == TimestampNTZType:
144+
elif isinstance(dt, TimestampNTZType):
145145
arrow_type = pa.timestamp("us", tz=None)
146-
elif type(dt) == DayTimeIntervalType:
146+
elif isinstance(dt, DayTimeIntervalType):
147147
arrow_type = pa.duration("us")
148-
elif type(dt) == TimeType:
148+
elif isinstance(dt, TimeType):
149149
arrow_type = pa.time64("ns")
150-
elif type(dt) == ArrayType:
150+
elif isinstance(dt, ArrayType):
151151
field = pa.field(
152152
"element",
153153
to_arrow_type(
@@ -159,7 +159,7 @@ def to_arrow_type(
159159
nullable=dt.containsNull,
160160
)
161161
arrow_type = pa.list_(field)
162-
elif type(dt) == MapType:
162+
elif isinstance(dt, MapType):
163163
key_field = pa.field(
164164
"key",
165165
to_arrow_type(
@@ -181,7 +181,7 @@ def to_arrow_type(
181181
nullable=dt.valueContainsNull,
182182
)
183183
arrow_type = pa.map_(key_field, value_field)
184-
elif type(dt) == StructType:
184+
elif isinstance(dt, StructType):
185185
field_names = dt.names
186186
if error_on_duplicated_field_names_in_struct and len(set(field_names)) != len(field_names):
187187
raise UnsupportedOperationException(
@@ -203,7 +203,7 @@ def to_arrow_type(
203203
for field in dt
204204
]
205205
arrow_type = pa.struct(fields)
206-
elif type(dt) == NullType:
206+
elif isinstance(dt, NullType):
207207
arrow_type = pa.null()
208208
elif isinstance(dt, UserDefinedType):
209209
arrow_type = to_arrow_type(
@@ -212,15 +212,15 @@ def to_arrow_type(
212212
timezone=timezone,
213213
prefers_large_types=prefers_large_types,
214214
)
215-
elif type(dt) == VariantType:
215+
elif isinstance(dt, VariantType):
216216
fields = [
217217
pa.field("value", pa.binary(), nullable=False),
218218
# The metadata field is tagged so we can identify that the arrow struct actually
219219
# represents a variant.
220220
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
221221
]
222222
arrow_type = pa.struct(fields)
223-
elif type(dt) == GeometryType:
223+
elif isinstance(dt, GeometryType):
224224
fields = [
225225
pa.field("srid", pa.int32(), nullable=False),
226226
pa.field(
@@ -231,7 +231,7 @@ def to_arrow_type(
231231
),
232232
]
233233
arrow_type = pa.struct(fields)
234-
elif type(dt) == GeographyType:
234+
elif isinstance(dt, GeographyType):
235235
fields = [
236236
pa.field("srid", pa.int32(), nullable=False),
237237
pa.field(
@@ -545,7 +545,7 @@ def _check_arrow_array_timestamps_localize(
545545
if types.is_timestamp(a.type) and truncate and a.type.unit == "ns":
546546
a = pc.floor_temporal(a, unit="microsecond")
547547

548-
if types.is_timestamp(a.type) and a.type.tz is None and type(dt) == TimestampType:
548+
if types.is_timestamp(a.type) and a.type.tz is None and isinstance(dt, TimestampType):
549549
assert timezone is not None
550550

551551
# Only localize timestamps that will become Spark TimestampType columns.
@@ -867,31 +867,31 @@ def _to_corrected_pandas_type(dt: DataType) -> Optional[Any]:
867867
import numpy as np
868868
import pandas as pd
869869

870-
if type(dt) == ByteType:
870+
if isinstance(dt, ByteType):
871871
return np.int8
872-
elif type(dt) == ShortType:
872+
elif isinstance(dt, ShortType):
873873
return np.int16
874-
elif type(dt) == IntegerType:
874+
elif isinstance(dt, IntegerType):
875875
return np.int32
876-
elif type(dt) == LongType:
876+
elif isinstance(dt, LongType):
877877
return np.int64
878-
elif type(dt) == FloatType:
878+
elif isinstance(dt, FloatType):
879879
return np.float32
880-
elif type(dt) == DoubleType:
880+
elif isinstance(dt, DoubleType):
881881
return np.float64
882-
elif type(dt) == BooleanType:
882+
elif isinstance(dt, BooleanType):
883883
return bool
884-
elif type(dt) == TimestampType:
884+
elif isinstance(dt, TimestampType):
885885
if LooseVersion(pd.__version__) < "3.0.0":
886886
return np.dtype("datetime64[ns]")
887887
else:
888888
return np.dtype("datetime64[us]")
889-
elif type(dt) == TimestampNTZType:
889+
elif isinstance(dt, TimestampNTZType):
890890
if LooseVersion(pd.__version__) < "3.0.0":
891891
return np.dtype("datetime64[ns]")
892892
else:
893893
return np.dtype("datetime64[us]")
894-
elif type(dt) == DayTimeIntervalType:
894+
elif isinstance(dt, DayTimeIntervalType):
895895
if LooseVersion(pd.__version__) < "3.0.0":
896896
return np.dtype("timedelta64[ns]")
897897
else:

0 commit comments

Comments
 (0)