@@ -48,6 +48,8 @@ def __init__(self, format_reader: RecordBatchReader, index_mapping: List[int], p
4848 self .first_row_id = first_row_id
4949 self .max_sequence_number = max_sequence_number
5050 self .system_fields = system_fields
51+ self .requested_field_names = [field .name for field in fields ] if fields else None
52+ self .fields = fields
5153
5254 def read_arrow_batch (self , start_idx = None , end_idx = None ) -> Optional [RecordBatch ]:
5355 if isinstance (self .format_reader , FormatBlobReader ):
@@ -57,11 +59,20 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch
5759 if record_batch is None :
5860 return None
5961
62+ num_rows = record_batch .num_rows
6063 if self .partition_info is None and self .index_mapping is None :
6164 if self .row_tracking_enabled and self .system_fields :
6265 record_batch = self ._assign_row_tracking (record_batch )
6366 return record_batch
6467
68+ if (self .partition_info is None and self .index_mapping is not None
69+ and not self .requested_field_names ):
70+ ncol = record_batch .num_columns
71+ if len (self .index_mapping ) == ncol and self .index_mapping == list (range (ncol )):
72+ if self .row_tracking_enabled and self .system_fields :
73+ record_batch = self ._assign_row_tracking (record_batch )
74+ return record_batch
75+
6576 inter_arrays = []
6677 inter_names = []
6778 num_rows = record_batch .num_rows
@@ -79,28 +90,101 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch
7990 inter_arrays .append (record_batch .column (real_index ))
8091 inter_names .append (record_batch .schema .field (real_index ).name )
8192 else :
82- inter_arrays = record_batch .columns
83- inter_names = record_batch .schema .names
84-
85- if self .index_mapping is not None :
93+ inter_arrays = list (record_batch .columns )
94+ inter_names = list (record_batch .schema .names )
95+
96+ if self .requested_field_names is not None :
97+ if (len (inter_names ) <= len (self .requested_field_names )
98+ and inter_names == self .requested_field_names [:len (inter_names )]):
99+ ordered_arrays = list (inter_arrays )
100+ ordered_names = list (inter_names )
101+ for name in self .requested_field_names [len (inter_names ):]:
102+ field = self .schema_map .get (name )
103+ ordered_arrays .append (
104+ pa .nulls (num_rows , type = field .type ) if field is not None else pa .nulls (num_rows )
105+ )
106+ ordered_names .append (name )
107+ inter_arrays = ordered_arrays
108+ inter_names = ordered_names
109+ else :
110+ ordered_arrays = []
111+ ordered_names = []
112+ for name in self .requested_field_names :
113+ if name in inter_names :
114+ ordered_arrays .append (inter_arrays [inter_names .index (name )])
115+ ordered_names .append (name )
116+ else :
117+ field = self .schema_map .get (name )
118+ ordered_arrays .append (
119+ pa .nulls (num_rows , type = field .type ) if field is not None else pa .nulls (num_rows )
120+ )
121+ ordered_names .append (name )
122+ inter_arrays = ordered_arrays
123+ inter_names = ordered_names
124+
125+ if self .index_mapping is not None and not (
126+ self .requested_field_names is not None and inter_names == self .requested_field_names ):
86127 mapped_arrays = []
87128 mapped_names = []
129+ partition_names = set ()
130+ if self .partition_info :
131+ for i in range (len (self .partition_info .partition_fields )):
132+ partition_names .add (self .partition_info .partition_fields [i ].name )
133+
134+ non_partition_indices = [idx for idx , name in enumerate (inter_names ) if name not in partition_names ]
88135 for i , real_index in enumerate (self .index_mapping ):
89- if 0 <= real_index < len (inter_arrays ):
90- mapped_arrays .append (inter_arrays [real_index ])
91- mapped_names .append (inter_names [real_index ])
136+ if 0 <= real_index < len (non_partition_indices ):
137+ actual_index = non_partition_indices [real_index ]
138+ mapped_arrays .append (inter_arrays [actual_index ])
139+ mapped_names .append (inter_names [actual_index ])
92140 else :
93141 null_array = pa .nulls (num_rows )
94142 mapped_arrays .append (null_array )
95143 mapped_names .append (f"null_col_{ i } " )
96144
145+ if self .partition_info :
146+ partition_names = set ()
147+ partition_arrays_map = {}
148+ for i in range (len (inter_names )):
149+ field_name = inter_names [i ]
150+ if field_name in partition_names or (self .partition_info and any (
151+ self .partition_info .partition_fields [j ].name == field_name
152+ for j in range (len (self .partition_info .partition_fields ))
153+ )):
154+ partition_names .add (field_name )
155+ partition_arrays_map [field_name ] = inter_arrays [i ]
156+
157+ if self .requested_field_names :
158+ final_arrays = []
159+ final_names = []
160+ mapped_name_to_array = {name : arr for name , arr in zip (mapped_names , mapped_arrays )}
161+
162+ for name in self .requested_field_names :
163+ if name in mapped_name_to_array :
164+ final_arrays .append (mapped_name_to_array [name ])
165+ final_names .append (name )
166+ elif name in partition_arrays_map :
167+ final_arrays .append (partition_arrays_map [name ])
168+ final_names .append (name )
169+
170+ inter_arrays = final_arrays
171+ inter_names = final_names
172+ else :
173+ mapped_name_set = set (mapped_names )
174+ for name , arr in partition_arrays_map .items ():
175+ if name not in mapped_name_set :
176+ mapped_arrays .append (arr )
177+ mapped_names .append (name )
178+ inter_arrays = mapped_arrays
179+ inter_names = mapped_names
180+ else :
181+ inter_arrays = mapped_arrays
182+ inter_names = mapped_names
183+
97184 if self .system_primary_key :
98185 for i in range (len (self .system_primary_key )):
99- if not mapped_names [i ].startswith ("_KEY_" ):
100- mapped_names [i ] = f"_KEY_{ mapped_names [i ]} "
101-
102- inter_arrays = mapped_arrays
103- inter_names = mapped_names
186+ if i < len (inter_names ) and not inter_names [i ].startswith ("_KEY_" ):
187+ inter_names [i ] = f"_KEY_{ inter_names [i ]} "
104188
105189 # to contains 'not null' property
106190 final_fields = []
@@ -109,6 +193,9 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch
109193 target_field = self .schema_map .get (name )
110194 if not target_field :
111195 target_field = pa .field (name , array .type )
196+ else :
197+ if name in (SpecialFields .ROW_ID .name , SpecialFields .SEQUENCE_NUMBER .name ):
198+ target_field = pa .field (name , target_field .type , nullable = False )
112199 final_fields .append (target_field )
113200 final_schema = pa .schema (final_fields )
114201 record_batch = pa .RecordBatch .from_arrays (inter_arrays , schema = final_schema )
@@ -122,20 +209,26 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch
def _assign_row_tracking(self, record_batch: RecordBatch) -> RecordBatch:
    """Assign row tracking meta fields (_ROW_ID and _SEQUENCE_NUMBER).

    For each of the two special fields registered in ``self.system_fields``
    (a mapping of field name -> column position), the column at that
    position is replaced with a computed one:

    * ``_ROW_ID``: consecutive int64 values starting at ``self.first_row_id``.
    * ``_SEQUENCE_NUMBER``: ``self.max_sequence_number`` repeated per row.

    A registered position that lies beyond the batch's last column is
    ignored, so a batch that does not carry the column is left alone.

    Args:
        record_batch: The batch whose system columns should be filled in.

    Returns:
        A new ``RecordBatch`` with the same column names and per-column
        nullability as ``record_batch``; column types are taken from the
        (possibly replaced) arrays.
    """
    row_count = record_batch.num_rows
    arrays = list(record_batch.columns)
    num_cols = len(arrays)

    if SpecialFields.ROW_ID.name in self.system_fields:
        idx = self.system_fields[SpecialFields.ROW_ID.name]
        if idx < num_cols:
            arrays[idx] = pa.array(
                range(self.first_row_id, self.first_row_id + row_count),
                type=pa.int64(),
            )

    if SpecialFields.SEQUENCE_NUMBER.name in self.system_fields:
        idx = self.system_fields[SpecialFields.SEQUENCE_NUMBER.name]
        if idx < num_cols:
            arrays[idx] = pa.repeat(self.max_sequence_number, row_count)

    input_schema = record_batch.schema
    if num_cols == 0:
        # Nothing to describe: keep the names-based construction for an
        # empty batch.
        return pa.RecordBatch.from_arrays(arrays, names=input_schema.names)

    # Look fields up by position, not by name: a by-name lookup raises
    # KeyError when two columns share a name, while a positional lookup is
    # always unambiguous.
    fields = []
    for position, array in enumerate(arrays):
        input_field = input_schema.field(position)
        fields.append(
            pa.field(input_field.name, array.type, nullable=input_field.nullable)
        )
    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
139232
def close(self) -> None:
    """Close the wrapped format reader by delegating to its ``close()``."""
    reader = self.format_reader
    reader.close()
0 commit comments