@@ -127,7 +127,8 @@ void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visi
127127 {
128128#pragma omp critical
129129 {
130- SSDThreadData<T> *data = new SSDThreadData<T>(this ->_aligned_dim , visited_reserve, max_degree, max_filters_per_query);
130+ SSDThreadData<T> *data =
131+ new SSDThreadData<T>(this ->_aligned_dim , visited_reserve, max_degree, max_filters_per_query);
131132 this ->reader ->register_thread ();
132133 data->ctx = this ->reader ->get_ctx ();
133134 this ->_thread_data .push (data);
@@ -571,7 +572,8 @@ void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels
571572}
572573
573574template <typename T, typename LabelT>
574- void PQFlashIndex<T, LabelT>::load_label_map (std::basic_istream<char > &map_reader, std::unordered_map<std::string, LabelT>& string_to_int_map)
575+ void PQFlashIndex<T, LabelT>::load_label_map (std::basic_istream<char > &map_reader,
576+ std::unordered_map<std::string, LabelT> &string_to_int_map)
575577{
576578 std::string line, token;
577579 LabelT token_as_num;
@@ -687,7 +689,6 @@ bool PQFlashIndex<T, LabelT>::point_has_any_label(uint32_t point_id, const std::
687689 return ret_val;
688690}
689691
690-
691692template <typename T, typename LabelT>
692693void PQFlashIndex<T, LabelT>::parse_label_file (std::basic_istream<char > &infile, size_t &num_points_labels)
693694{
@@ -778,7 +779,8 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_univers
778779}
779780
780781template <typename T, typename LabelT>
781- void PQFlashIndex<T, LabelT>::load_label_medoid_map (const std::string& labels_to_medoids_filepath, std::istream& medoid_stream)
782+ void PQFlashIndex<T, LabelT>::load_label_medoid_map (const std::string &labels_to_medoids_filepath,
783+ std::istream &medoid_stream)
782784{
783785 std::string line, token;
784786
@@ -840,11 +842,10 @@ void PQFlashIndex<T, LabelT>::load_dummy_map(const std::string &dummy_map_filepa
840842 }
841843 catch (std::system_error &e)
842844 {
843- throw FileException (dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__);
845+ throw FileException (dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__);
844846 }
845847}
846848
847-
848849template <typename T, typename LabelT>
849850#ifdef EXEC_ENV_OLS
850851bool PQFlashIndex<T, LabelT>::use_filter_support (MemoryMappedFiles &files)
@@ -980,7 +981,6 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_labels
980981 ss << " Note: Filter support is enabled but " << dummy_map_file << " file cannot be opened" << std::endl;
981982 diskann::cerr << ss.str ();
982983 }
983-
984984 }
985985 else
986986 {
@@ -1107,11 +1107,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
11071107 // bytes are needed to store the header and read in that many using our
11081108 // 'standard' aligned file reader approach.
11091109 reader->open (_disk_index_file);
1110- this ->setup_thread_data (
1111- num_threads,
1112- defaults::VISITED_RESERVE,
1113- defaults::MAX_GRAPH_DEGREE,
1114- (use_filter_support (files)? defaults::MAX_FILTERS_PER_QUERY : 0 ));
1110+ this ->setup_thread_data (num_threads, defaults::VISITED_RESERVE, defaults::MAX_GRAPH_DEGREE,
1111+ (use_filter_support (files) ? defaults::MAX_FILTERS_PER_QUERY : 0 ));
11151112 this ->_max_nthreads = num_threads;
11161113
11171114 char *bytes = getHeaderBytes ();
@@ -1145,6 +1142,77 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
11451142 READ_U64 (index_metadata, _nnodes_per_sector);
11461143 _max_degree = ((_max_node_len - _disk_bytes_per_point) / sizeof (uint32_t )) - 1 ;
11471144
1145+ // Early validation: read first node from disk and validate neighbor data.
1146+ // If data type is wrong, _disk_bytes_per_point will be incorrect, causing
1147+ // us to read garbage for neighbor count and neighbor IDs.
1148+ // Disk layout: [sector 0: metadata] [sector 1+: node data]
1149+ // Node layout: [vector data: _disk_bytes_per_point bytes] [neighbor count: 4 bytes] [neighbor IDs: 4 bytes each]
1150+ if (!_use_disk_index_pq && disk_nnodes > 0 )
1151+ {
1152+ #ifdef EXEC_ENV_OLS
1153+ // In OLS environment, index_metadata is a ContentBuf with only the header (sector 0).
1154+ // We must use the reader (which is already open) to read sector 1 containing the first node.
1155+ uint64_t num_sectors_for_node = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP (_max_node_len, defaults::SECTOR_LEN);
1156+ char *first_sector_buf = nullptr ;
1157+ alloc_aligned ((void **)&first_sector_buf, num_sectors_for_node * defaults::SECTOR_LEN, defaults::SECTOR_LEN);
1158+
1159+ std::vector<AlignedRead> read_reqs;
1160+ read_reqs.emplace_back (defaults::SECTOR_LEN, num_sectors_for_node * defaults::SECTOR_LEN, first_sector_buf);
1161+
1162+ // We need a temporary IOContext for this read
1163+ IOContext tmp_ctx = reader->get_ctx ();
1164+ reader->read (read_reqs, tmp_ctx);
1165+
1166+ // First node starts at offset 0 within sector 1
1167+ char *first_node_buf = first_sector_buf;
1168+ #else
1169+ // In non-OLS environment, we can seek and read directly from the file stream
1170+ uint64_t num_sectors_for_node = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP (_max_node_len, defaults::SECTOR_LEN);
1171+ std::vector<char > first_node_buf_vec (num_sectors_for_node * defaults::SECTOR_LEN);
1172+ index_metadata.seekg (defaults::SECTOR_LEN, std::ios::beg);
1173+ index_metadata.read (first_node_buf_vec.data (), num_sectors_for_node * defaults::SECTOR_LEN);
1174+ char *first_node_buf = first_node_buf_vec.data ();
1175+ #endif
1176+
1177+ // Get neighbor count (located after vector data)
1178+ uint32_t *nhood_ptr = reinterpret_cast <uint32_t *>(first_node_buf + _disk_bytes_per_point);
1179+ uint32_t num_neighbors = *nhood_ptr;
1180+
1181+ // Validate neighbor count is reasonable
1182+ if (num_neighbors > _max_degree)
1183+ {
1184+ #ifdef EXEC_ENV_OLS
1185+ aligned_free (first_sector_buf);
1186+ #endif
1187+ std::stringstream stream;
1188+ stream << " Data type mismatch detected: first node has neighbor count " << num_neighbors
1189+ << " which exceeds max_degree " << _max_degree << " . "
1190+ << " Please ensure --data_type matches the type used when building the index." ;
1191+ throw diskann::ANNException (stream.str (), -1 , __FUNCSIG__, __FILE__, __LINE__);
1192+ }
1193+
1194+ // Validate each neighbor ID is within valid range [0, disk_nnodes)
1195+ uint32_t *neighbors = nhood_ptr + 1 ;
1196+ for (uint32_t i = 0 ; i < num_neighbors; i++)
1197+ {
1198+ if (neighbors[i] >= disk_nnodes)
1199+ {
1200+ #ifdef EXEC_ENV_OLS
1201+ aligned_free (first_sector_buf);
1202+ #endif
1203+ std::stringstream stream;
1204+ stream << " Data type mismatch detected: first node has invalid neighbor ID " << neighbors[i]
1205+ << " (max valid ID is " << (disk_nnodes - 1 ) << " ). "
1206+ << " Please ensure --data_type matches the type used when building the index." ;
1207+ throw diskann::ANNException (stream.str (), -1 , __FUNCSIG__, __FILE__, __LINE__);
1208+ }
1209+ }
1210+
1211+ #ifdef EXEC_ENV_OLS
1212+ aligned_free (first_sector_buf);
1213+ #endif
1214+ }
1215+
11481216 if (_max_degree > defaults::MAX_GRAPH_DEGREE)
11491217 {
11501218 std::stringstream stream;
@@ -1180,11 +1248,11 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
11801248 READ_U64 (index_metadata, this ->_nvecs_per_sector );
11811249 }
11821250
1183- #ifdef EXEC_ENV_OLS
1184- load_labels (files, _disk_index_file);
1185- #else
1186- load_labels (_disk_index_file);
1187- #endif
1251+ #ifdef EXEC_ENV_OLS
1252+ load_labels (files, _disk_index_file);
1253+ #else
1254+ load_labels (_disk_index_file);
1255+ #endif
11881256
11891257 diskann::cout << " Disk-Index File Meta-data: " ;
11901258 diskann::cout << " # nodes per sector: " << _nnodes_per_sector;
@@ -1201,11 +1269,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
12011269 // open AlignedFileReader handle to index_file
12021270 std::string index_fname (_disk_index_file);
12031271 reader->open (index_fname);
1204- this ->setup_thread_data (
1205- num_threads,
1206- defaults::VISITED_RESERVE,
1207- defaults::MAX_GRAPH_DEGREE,
1208- (use_filter_support ()? defaults::MAX_FILTERS_PER_QUERY : 0 ));
1272+ this ->setup_thread_data (num_threads, defaults::VISITED_RESERVE, defaults::MAX_GRAPH_DEGREE,
1273+ (use_filter_support () ? defaults::MAX_FILTERS_PER_QUERY : 0 ));
12091274 this ->_max_nthreads = num_threads;
12101275
12111276#endif
@@ -1466,16 +1531,18 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
14661531 NeighborPriorityQueue &retset = query_scratch->retset ;
14671532 std::vector<Neighbor> &full_retset = query_scratch->full_retset ;
14681533 tsl::robin_set<location_t > full_retset_ids;
1469- if (use_filters) {
1534+ if (use_filters)
1535+ {
14701536 uint64_t size_to_reserve = std::max (l_search, (std::min ((uint64_t )filter_label_count, this ->_max_degree ) + 1 ));
14711537 retset.reserve (size_to_reserve);
1472- full_retset.reserve (4096 );
1538+ full_retset.reserve (4096 );
14731539 full_retset_ids.reserve (4096 );
1474- } else {
1540+ }
1541+ else
1542+ {
14751543 retset.reserve (l_search + 1 );
14761544 }
14771545
1478-
14791546 uint32_t best_medoid = 0 ;
14801547 uint32_t cur_list_size = 0 ;
14811548 float best_dist = (std::numeric_limits<float >::max)();
@@ -1495,7 +1562,9 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
14951562 retset.insert (Neighbor (best_medoid, dist_scratch[0 ]));
14961563 visited.insert (best_medoid);
14971564 cur_list_size = 1 ;
1498- } else {
1565+ }
1566+ else
1567+ {
14991568 std::vector<location_t > filter_specific_medoids;
15001569 filter_specific_medoids.reserve (filter_label_count);
15011570 location_t ctr = 0 ;
@@ -1513,12 +1582,12 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15131582 for (ctr = 0 ; ctr < filter_specific_medoids.size (); ctr++)
15141583 {
15151584 retset.insert (Neighbor (filter_specific_medoids[ctr], dist_scratch[ctr]));
1516- // retset[ctr].id = filter_specific_medoids[ctr];
1517- // retset[ctr].distance = dist_scratch[ctr];
1518- // retset[ctr].expanded = false;
1585+ // retset[ctr].id = filter_specific_medoids[ctr];
1586+ // retset[ctr].distance = dist_scratch[ctr];
1587+ // retset[ctr].expanded = false;
15191588 visited.insert (filter_specific_medoids[ctr]);
15201589 }
1521- cur_list_size = (uint32_t ) filter_specific_medoids.size ();
1590+ cur_list_size = (uint32_t )filter_specific_medoids.size ();
15221591 }
15231592
15241593 uint32_t cmps = 0 ;
@@ -1535,10 +1604,10 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15351604 std::vector<std::pair<uint32_t , std::pair<uint32_t , uint32_t *>>> cached_nhoods;
15361605 cached_nhoods.reserve (2 * beam_width);
15371606
1538- // if we are doing multi-filter search we don't want to restrict the number of IOs
1539- // at present. Must revisit this decision later.
1607+ // if we are doing multi-filter search we don't want to restrict the number of IOs
1608+ // at present. Must revisit this decision later.
15401609 uint32_t max_ios_for_query = use_filters || (io_limit == 0 ) ? std::numeric_limits<uint32_t >::max () : io_limit;
1541- const std::vector<LabelT>& label_ids = filter_labels; // avoid renaming.
1610+ const std::vector<LabelT> & label_ids = filter_labels; // avoid renaming.
15421611 std::vector<LabelT> lbl_vec;
15431612
15441613 retset.sort ();
@@ -1554,7 +1623,6 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15541623 // find new beam
15551624 uint32_t num_seen = 0 ;
15561625
1557-
15581626 for (const auto &lbl : label_ids)
15591627 { // assuming that number of OR labels is
15601628 // less than max frontier size allowed
@@ -1583,7 +1651,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15831651 retset[lbl_marker].expanded = true ;
15841652 if (this ->_count_visited_nodes )
15851653 {
1586- reinterpret_cast <std::atomic<uint32_t > &>(this ->_node_visit_counter [retset[lbl_marker].id ].second )
1654+ reinterpret_cast <std::atomic<uint32_t > &>(
1655+ this ->_node_visit_counter [retset[lbl_marker].id ].second )
15871656 .fetch_add (1 );
15881657 }
15891658 break ;
@@ -1686,7 +1755,6 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
16861755 full_retset.push_back (Neighbor ((unsigned )cached_nhood.first , cur_expanded_dist));
16871756 }
16881757
1689-
16901758 uint64_t nnbrs = cached_nhood.second .first ;
16911759 uint32_t *node_nbrs = cached_nhood.second .second ;
16921760
@@ -1768,7 +1836,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
17681836 {
17691837 full_retset.push_back (Neighbor (frontier_nhood.first , cur_expanded_dist));
17701838 }
1771-
1839+
17721840 uint32_t *node_nbrs = (node_buf + 1 );
17731841 // compute node_nbrs <-> query dist in PQ space
17741842 cpu_timer.reset ();
0 commit comments