Skip to content

Commit 8cd3d57

Browse files
authored
load validation (#702)
* load validation * update to shi's idea * update menghao's change
1 parent 76a92c2 commit 8cd3d57

1 file changed

Lines changed: 106 additions & 38 deletions

File tree

src/pq_flash_index.cpp

Lines changed: 106 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visi
127127
{
128128
#pragma omp critical
129129
{
130-
SSDThreadData<T> *data = new SSDThreadData<T>(this->_aligned_dim, visited_reserve, max_degree, max_filters_per_query);
130+
SSDThreadData<T> *data =
131+
new SSDThreadData<T>(this->_aligned_dim, visited_reserve, max_degree, max_filters_per_query);
131132
this->reader->register_thread();
132133
data->ctx = this->reader->get_ctx();
133134
this->_thread_data.push(data);
@@ -571,7 +572,8 @@ void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels
571572
}
572573

573574
template <typename T, typename LabelT>
574-
void PQFlashIndex<T, LabelT>::load_label_map(std::basic_istream<char> &map_reader, std::unordered_map<std::string, LabelT>& string_to_int_map)
575+
void PQFlashIndex<T, LabelT>::load_label_map(std::basic_istream<char> &map_reader,
576+
std::unordered_map<std::string, LabelT> &string_to_int_map)
575577
{
576578
std::string line, token;
577579
LabelT token_as_num;
@@ -687,7 +689,6 @@ bool PQFlashIndex<T, LabelT>::point_has_any_label(uint32_t point_id, const std::
687689
return ret_val;
688690
}
689691

690-
691692
template <typename T, typename LabelT>
692693
void PQFlashIndex<T, LabelT>::parse_label_file(std::basic_istream<char> &infile, size_t &num_points_labels)
693694
{
@@ -778,7 +779,8 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_univers
778779
}
779780

780781
template <typename T, typename LabelT>
781-
void PQFlashIndex<T, LabelT>::load_label_medoid_map(const std::string& labels_to_medoids_filepath, std::istream& medoid_stream)
782+
void PQFlashIndex<T, LabelT>::load_label_medoid_map(const std::string &labels_to_medoids_filepath,
783+
std::istream &medoid_stream)
782784
{
783785
std::string line, token;
784786

@@ -840,11 +842,10 @@ void PQFlashIndex<T, LabelT>::load_dummy_map(const std::string &dummy_map_filepa
840842
}
841843
catch (std::system_error &e)
842844
{
843-
throw FileException (dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__);
845+
throw FileException(dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__);
844846
}
845847
}
846848

847-
848849
template <typename T, typename LabelT>
849850
#ifdef EXEC_ENV_OLS
850851
bool PQFlashIndex<T, LabelT>::use_filter_support(MemoryMappedFiles &files)
@@ -980,7 +981,6 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_labels
980981
ss << "Note: Filter support is enabled but " << dummy_map_file << " file cannot be opened" << std::endl;
981982
diskann::cerr << ss.str();
982983
}
983-
984984
}
985985
else
986986
{
@@ -1107,11 +1107,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
11071107
// bytes are needed to store the header and read in that many using our
11081108
// 'standard' aligned file reader approach.
11091109
reader->open(_disk_index_file);
1110-
this->setup_thread_data(
1111-
num_threads,
1112-
defaults::VISITED_RESERVE,
1113-
defaults::MAX_GRAPH_DEGREE,
1114-
(use_filter_support(files)? defaults::MAX_FILTERS_PER_QUERY : 0));
1110+
this->setup_thread_data(num_threads, defaults::VISITED_RESERVE, defaults::MAX_GRAPH_DEGREE,
1111+
(use_filter_support(files) ? defaults::MAX_FILTERS_PER_QUERY : 0));
11151112
this->_max_nthreads = num_threads;
11161113

11171114
char *bytes = getHeaderBytes();
@@ -1145,6 +1142,77 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
11451142
READ_U64(index_metadata, _nnodes_per_sector);
11461143
_max_degree = ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1;
11471144

1145+
// Early validation: read first node from disk and validate neighbor data.
1146+
// If data type is wrong, _disk_bytes_per_point will be incorrect, causing
1147+
// us to read garbage for neighbor count and neighbor IDs.
1148+
// Disk layout: [sector 0: metadata] [sector 1+: node data]
1149+
// Node layout: [vector data: _disk_bytes_per_point bytes] [neighbor count: 4 bytes] [neighbor IDs: 4 bytes each]
1150+
if (!_use_disk_index_pq && disk_nnodes > 0)
1151+
{
1152+
#ifdef EXEC_ENV_OLS
1153+
// In OLS environment, index_metadata is a ContentBuf with only the header (sector 0).
1154+
// We must use the reader (which is already open) to read sector 1 containing the first node.
1155+
uint64_t num_sectors_for_node = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
1156+
char *first_sector_buf = nullptr;
1157+
alloc_aligned((void **)&first_sector_buf, num_sectors_for_node * defaults::SECTOR_LEN, defaults::SECTOR_LEN);
1158+
1159+
std::vector<AlignedRead> read_reqs;
1160+
read_reqs.emplace_back(defaults::SECTOR_LEN, num_sectors_for_node * defaults::SECTOR_LEN, first_sector_buf);
1161+
1162+
// We need a temporary IOContext for this read
1163+
IOContext tmp_ctx = reader->get_ctx();
1164+
reader->read(read_reqs, tmp_ctx);
1165+
1166+
// First node starts at offset 0 within sector 1
1167+
char *first_node_buf = first_sector_buf;
1168+
#else
1169+
// In non-OLS environment, we can seek and read directly from the file stream
1170+
uint64_t num_sectors_for_node = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
1171+
std::vector<char> first_node_buf_vec(num_sectors_for_node * defaults::SECTOR_LEN);
1172+
index_metadata.seekg(defaults::SECTOR_LEN, std::ios::beg);
1173+
index_metadata.read(first_node_buf_vec.data(), num_sectors_for_node * defaults::SECTOR_LEN);
1174+
char *first_node_buf = first_node_buf_vec.data();
1175+
#endif
1176+
1177+
// Get neighbor count (located after vector data)
1178+
uint32_t *nhood_ptr = reinterpret_cast<uint32_t *>(first_node_buf + _disk_bytes_per_point);
1179+
uint32_t num_neighbors = *nhood_ptr;
1180+
1181+
// Validate neighbor count is reasonable
1182+
if (num_neighbors > _max_degree)
1183+
{
1184+
#ifdef EXEC_ENV_OLS
1185+
aligned_free(first_sector_buf);
1186+
#endif
1187+
std::stringstream stream;
1188+
stream << "Data type mismatch detected: first node has neighbor count " << num_neighbors
1189+
<< " which exceeds max_degree " << _max_degree << ". "
1190+
<< "Please ensure --data_type matches the type used when building the index.";
1191+
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
1192+
}
1193+
1194+
// Validate each neighbor ID is within valid range [0, disk_nnodes)
1195+
uint32_t *neighbors = nhood_ptr + 1;
1196+
for (uint32_t i = 0; i < num_neighbors; i++)
1197+
{
1198+
if (neighbors[i] >= disk_nnodes)
1199+
{
1200+
#ifdef EXEC_ENV_OLS
1201+
aligned_free(first_sector_buf);
1202+
#endif
1203+
std::stringstream stream;
1204+
stream << "Data type mismatch detected: first node has invalid neighbor ID " << neighbors[i]
1205+
<< " (max valid ID is " << (disk_nnodes - 1) << "). "
1206+
<< "Please ensure --data_type matches the type used when building the index.";
1207+
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
1208+
}
1209+
}
1210+
1211+
#ifdef EXEC_ENV_OLS
1212+
aligned_free(first_sector_buf);
1213+
#endif
1214+
}
1215+
11481216
if (_max_degree > defaults::MAX_GRAPH_DEGREE)
11491217
{
11501218
std::stringstream stream;
@@ -1180,11 +1248,11 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
11801248
READ_U64(index_metadata, this->_nvecs_per_sector);
11811249
}
11821250

1183-
#ifdef EXEC_ENV_OLS
1184-
load_labels(files, _disk_index_file);
1185-
#else
1186-
load_labels(_disk_index_file);
1187-
#endif
1251+
#ifdef EXEC_ENV_OLS
1252+
load_labels(files, _disk_index_file);
1253+
#else
1254+
load_labels(_disk_index_file);
1255+
#endif
11881256

11891257
diskann::cout << "Disk-Index File Meta-data: ";
11901258
diskann::cout << "# nodes per sector: " << _nnodes_per_sector;
@@ -1201,11 +1269,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
12011269
// open AlignedFileReader handle to index_file
12021270
std::string index_fname(_disk_index_file);
12031271
reader->open(index_fname);
1204-
this->setup_thread_data(
1205-
num_threads,
1206-
defaults::VISITED_RESERVE,
1207-
defaults::MAX_GRAPH_DEGREE,
1208-
(use_filter_support()? defaults::MAX_FILTERS_PER_QUERY : 0));
1272+
this->setup_thread_data(num_threads, defaults::VISITED_RESERVE, defaults::MAX_GRAPH_DEGREE,
1273+
(use_filter_support() ? defaults::MAX_FILTERS_PER_QUERY : 0));
12091274
this->_max_nthreads = num_threads;
12101275

12111276
#endif
@@ -1466,16 +1531,18 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
14661531
NeighborPriorityQueue &retset = query_scratch->retset;
14671532
std::vector<Neighbor> &full_retset = query_scratch->full_retset;
14681533
tsl::robin_set<location_t> full_retset_ids;
1469-
if (use_filters) {
1534+
if (use_filters)
1535+
{
14701536
uint64_t size_to_reserve = std::max(l_search, (std::min((uint64_t)filter_label_count, this->_max_degree) + 1));
14711537
retset.reserve(size_to_reserve);
1472-
full_retset.reserve(4096);
1538+
full_retset.reserve(4096);
14731539
full_retset_ids.reserve(4096);
1474-
} else {
1540+
}
1541+
else
1542+
{
14751543
retset.reserve(l_search + 1);
14761544
}
14771545

1478-
14791546
uint32_t best_medoid = 0;
14801547
uint32_t cur_list_size = 0;
14811548
float best_dist = (std::numeric_limits<float>::max)();
@@ -1495,7 +1562,9 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
14951562
retset.insert(Neighbor(best_medoid, dist_scratch[0]));
14961563
visited.insert(best_medoid);
14971564
cur_list_size = 1;
1498-
} else {
1565+
}
1566+
else
1567+
{
14991568
std::vector<location_t> filter_specific_medoids;
15001569
filter_specific_medoids.reserve(filter_label_count);
15011570
location_t ctr = 0;
@@ -1513,12 +1582,12 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15131582
for (ctr = 0; ctr < filter_specific_medoids.size(); ctr++)
15141583
{
15151584
retset.insert(Neighbor(filter_specific_medoids[ctr], dist_scratch[ctr]));
1516-
//retset[ctr].id = filter_specific_medoids[ctr];
1517-
//retset[ctr].distance = dist_scratch[ctr];
1518-
//retset[ctr].expanded = false;
1585+
// retset[ctr].id = filter_specific_medoids[ctr];
1586+
// retset[ctr].distance = dist_scratch[ctr];
1587+
// retset[ctr].expanded = false;
15191588
visited.insert(filter_specific_medoids[ctr]);
15201589
}
1521-
cur_list_size = (uint32_t) filter_specific_medoids.size();
1590+
cur_list_size = (uint32_t)filter_specific_medoids.size();
15221591
}
15231592

15241593
uint32_t cmps = 0;
@@ -1535,10 +1604,10 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15351604
std::vector<std::pair<uint32_t, std::pair<uint32_t, uint32_t *>>> cached_nhoods;
15361605
cached_nhoods.reserve(2 * beam_width);
15371606

1538-
//if we are doing multi-filter search we don't want to restrict the number of IOs
1539-
//at present. Must revisit this decision later.
1607+
// if we are doing multi-filter search we don't want to restrict the number of IOs
1608+
// at present. Must revisit this decision later.
15401609
uint32_t max_ios_for_query = use_filters || (io_limit == 0) ? std::numeric_limits<uint32_t>::max() : io_limit;
1541-
const std::vector<LabelT>& label_ids = filter_labels; //avoid renaming.
1610+
const std::vector<LabelT> &label_ids = filter_labels; // avoid renaming.
15421611
std::vector<LabelT> lbl_vec;
15431612

15441613
retset.sort();
@@ -1554,7 +1623,6 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15541623
// find new beam
15551624
uint32_t num_seen = 0;
15561625

1557-
15581626
for (const auto &lbl : label_ids)
15591627
{ // assuming that number of OR labels is
15601628
// less than max frontier size allowed
@@ -1583,7 +1651,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
15831651
retset[lbl_marker].expanded = true;
15841652
if (this->_count_visited_nodes)
15851653
{
1586-
reinterpret_cast<std::atomic<uint32_t> &>(this->_node_visit_counter[retset[lbl_marker].id].second)
1654+
reinterpret_cast<std::atomic<uint32_t> &>(
1655+
this->_node_visit_counter[retset[lbl_marker].id].second)
15871656
.fetch_add(1);
15881657
}
15891658
break;
@@ -1686,7 +1755,6 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
16861755
full_retset.push_back(Neighbor((unsigned)cached_nhood.first, cur_expanded_dist));
16871756
}
16881757

1689-
16901758
uint64_t nnbrs = cached_nhood.second.first;
16911759
uint32_t *node_nbrs = cached_nhood.second.second;
16921760

@@ -1768,7 +1836,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
17681836
{
17691837
full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist));
17701838
}
1771-
1839+
17721840
uint32_t *node_nbrs = (node_buf + 1);
17731841
// compute node_nbrs <-> query dist in PQ space
17741842
cpu_timer.reset();

0 commit comments

Comments
 (0)