@@ -116,38 +116,38 @@ def to_arrow_type(
116116 """
117117 import pyarrow as pa
118118
119- if type (dt ) == BooleanType :
119+ if isinstance (dt , BooleanType ) :
120120 arrow_type = pa .bool_ ()
121- elif type (dt ) == ByteType :
121+ elif isinstance (dt , ByteType ) :
122122 arrow_type = pa .int8 ()
123- elif type (dt ) == ShortType :
123+ elif isinstance (dt , ShortType ) :
124124 arrow_type = pa .int16 ()
125- elif type (dt ) == IntegerType :
125+ elif isinstance (dt , IntegerType ) :
126126 arrow_type = pa .int32 ()
127- elif type (dt ) == LongType :
127+ elif isinstance (dt , LongType ) :
128128 arrow_type = pa .int64 ()
129- elif type (dt ) == FloatType :
129+ elif isinstance (dt , FloatType ) :
130130 arrow_type = pa .float32 ()
131- elif type (dt ) == DoubleType :
131+ elif isinstance (dt , DoubleType ) :
132132 arrow_type = pa .float64 ()
133- elif type (dt ) == DecimalType :
133+ elif isinstance (dt , DecimalType ) :
134134 arrow_type = pa .decimal128 (dt .precision , dt .scale )
135- elif type (dt ) == StringType :
135+ elif isinstance (dt , StringType ) :
136136 arrow_type = pa .large_string () if prefers_large_types else pa .string ()
137- elif type (dt ) == BinaryType :
137+ elif isinstance (dt , BinaryType ) :
138138 arrow_type = pa .large_binary () if prefers_large_types else pa .binary ()
139- elif type (dt ) == DateType :
139+ elif isinstance (dt , DateType ) :
140140 arrow_type = pa .date32 ()
141- elif type (dt ) == TimestampType :
141+ elif isinstance (dt , TimestampType ) :
142142 assert timezone is not None
143143 arrow_type = pa .timestamp ("us" , tz = timezone )
144- elif type (dt ) == TimestampNTZType :
144+ elif isinstance (dt , TimestampNTZType ) :
145145 arrow_type = pa .timestamp ("us" , tz = None )
146- elif type (dt ) == DayTimeIntervalType :
146+ elif isinstance (dt , DayTimeIntervalType ) :
147147 arrow_type = pa .duration ("us" )
148- elif type (dt ) == TimeType :
148+ elif isinstance (dt , TimeType ) :
149149 arrow_type = pa .time64 ("ns" )
150- elif type (dt ) == ArrayType :
150+ elif isinstance (dt , ArrayType ) :
151151 field = pa .field (
152152 "element" ,
153153 to_arrow_type (
@@ -159,7 +159,7 @@ def to_arrow_type(
159159 nullable = dt .containsNull ,
160160 )
161161 arrow_type = pa .list_ (field )
162- elif type (dt ) == MapType :
162+ elif isinstance (dt , MapType ) :
163163 key_field = pa .field (
164164 "key" ,
165165 to_arrow_type (
@@ -181,7 +181,7 @@ def to_arrow_type(
181181 nullable = dt .valueContainsNull ,
182182 )
183183 arrow_type = pa .map_ (key_field , value_field )
184- elif type (dt ) == StructType :
184+ elif isinstance (dt , StructType ) :
185185 field_names = dt .names
186186 if error_on_duplicated_field_names_in_struct and len (set (field_names )) != len (field_names ):
187187 raise UnsupportedOperationException (
@@ -203,7 +203,7 @@ def to_arrow_type(
203203 for field in dt
204204 ]
205205 arrow_type = pa .struct (fields )
206- elif type (dt ) == NullType :
206+ elif isinstance (dt , NullType ) :
207207 arrow_type = pa .null ()
208208 elif isinstance (dt , UserDefinedType ):
209209 arrow_type = to_arrow_type (
@@ -212,15 +212,15 @@ def to_arrow_type(
212212 timezone = timezone ,
213213 prefers_large_types = prefers_large_types ,
214214 )
215- elif type (dt ) == VariantType :
215+ elif isinstance (dt , VariantType ) :
216216 fields = [
217217 pa .field ("value" , pa .binary (), nullable = False ),
218218 # The metadata field is tagged so we can identify that the arrow struct actually
219219 # represents a variant.
220220 pa .field ("metadata" , pa .binary (), nullable = False , metadata = {b"variant" : b"true" }),
221221 ]
222222 arrow_type = pa .struct (fields )
223- elif type (dt ) == GeometryType :
223+ elif isinstance (dt , GeometryType ) :
224224 fields = [
225225 pa .field ("srid" , pa .int32 (), nullable = False ),
226226 pa .field (
@@ -231,7 +231,7 @@ def to_arrow_type(
231231 ),
232232 ]
233233 arrow_type = pa .struct (fields )
234- elif type (dt ) == GeographyType :
234+ elif isinstance (dt , GeographyType ) :
235235 fields = [
236236 pa .field ("srid" , pa .int32 (), nullable = False ),
237237 pa .field (
@@ -545,7 +545,7 @@ def _check_arrow_array_timestamps_localize(
545545 if types .is_timestamp (a .type ) and truncate and a .type .unit == "ns" :
546546 a = pc .floor_temporal (a , unit = "microsecond" )
547547
548- if types .is_timestamp (a .type ) and a .type .tz is None and type (dt ) == TimestampType :
548+ if types .is_timestamp (a .type ) and a .type .tz is None and isinstance (dt , TimestampType ) :
549549 assert timezone is not None
550550
551551 # Only localize timestamps that will become Spark TimestampType columns.
@@ -867,31 +867,31 @@ def _to_corrected_pandas_type(dt: DataType) -> Optional[Any]:
867867 import numpy as np
868868 import pandas as pd
869869
870- if type (dt ) == ByteType :
870+ if isinstance (dt , ByteType ) :
871871 return np .int8
872- elif type (dt ) == ShortType :
872+ elif isinstance (dt , ShortType ) :
873873 return np .int16
874- elif type (dt ) == IntegerType :
874+ elif isinstance (dt , IntegerType ) :
875875 return np .int32
876- elif type (dt ) == LongType :
876+ elif isinstance (dt , LongType ) :
877877 return np .int64
878- elif type (dt ) == FloatType :
878+ elif isinstance (dt , FloatType ) :
879879 return np .float32
880- elif type (dt ) == DoubleType :
880+ elif isinstance (dt , DoubleType ) :
881881 return np .float64
882- elif type (dt ) == BooleanType :
882+ elif isinstance (dt , BooleanType ) :
883883 return bool
884- elif type (dt ) == TimestampType :
884+ elif isinstance (dt , TimestampType ) :
885885 if LooseVersion (pd .__version__ ) < "3.0.0" :
886886 return np .dtype ("datetime64[ns]" )
887887 else :
888888 return np .dtype ("datetime64[us]" )
889- elif type (dt ) == TimestampNTZType :
889+ elif isinstance (dt , TimestampNTZType ) :
890890 if LooseVersion (pd .__version__ ) < "3.0.0" :
891891 return np .dtype ("datetime64[ns]" )
892892 else :
893893 return np .dtype ("datetime64[us]" )
894- elif type (dt ) == DayTimeIntervalType :
894+ elif isinstance (dt , DayTimeIntervalType ) :
895895 if LooseVersion (pd .__version__ ) < "3.0.0" :
896896 return np .dtype ("timedelta64[ns]" )
897897 else :
0 commit comments