diff --git a/AnyBuildLogs/latest.txt b/AnyBuildLogs/latest.txt index 3e11e6ab6..04de5bf1c 100644 --- a/AnyBuildLogs/latest.txt +++ b/AnyBuildLogs/latest.txt @@ -1 +1 @@ -20260116-140907-2cca51ce \ No newline at end of file +20260129-192840-e7c5749c \ No newline at end of file diff --git a/include/defaults.h b/include/defaults.h index eb646a494..946bc3c7b 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -37,6 +37,8 @@ const bool NUM_DIVERSE_BUILD = 1; const bool REORDER_INDEX = false; const uint32_t REORDER_DIM = 0; +const bool ATTRIBUTE_DIVERSITY = false; +const float ATTR_DIST_THRESHOLD = 0.2f; } // namespace defaults } // namespace diskann diff --git a/include/disk_utils.h b/include/disk_utils.h index def83a8a5..2c168ffd2 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -82,7 +82,12 @@ DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann:: uint32_t num_threads, bool use_filters = false, bool use_integer_labels = false, const std::string &label_file = std::string(""), const std::string &labels_to_medoids_file = std::string(""), - const std::string &universal_label = "", const uint32_t Lf = 0); + const std::string &universal_label = "", const uint32_t Lf = 0, + uint32_t universal_label_num = 0, + const char* seller_file_path = nullptr, + uint32_t num_diverse_build = 1, + const char* attribute_file_path = nullptr, + float attr_dist_threshold = 0.2f); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth(std::unique_ptr> &_pFlashIndex, @@ -101,7 +106,9 @@ DISKANN_DLLEXPORT int build_disk_index( const uint32_t Lf = 0, const char* reorderDataFilePath = nullptr, const char* sellerFilePath = nullptr, - uint32_t num_diverse_build = 1); // default is empty string for no universal label + uint32_t num_diverse_build = 1, + const char* attributeFilePath = nullptr, + float attr_dist_threshold = 0.2f); // default is empty string for no universal label template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, diff --git a/include/index.h b/include/index.h index 25255db1e..32bb31209 100644 --- a/include/index.h +++ b/include/index.h @@ -267,8 +267,13 @@ template clas // determines navigating node of the graph by calculating medoid of datafopt uint32_t calculate_entry_point(); - void parse_label_file(const std::string &label_file, size_t &num_pts_labels, size_t& total_labels); - void parse_seller_file(const std::string& label_file, size_t& num_pts_labels); + template + void parse_integer_string_file(const std::string &file_path, size_t &num_points, size_t& total_values, + std::vector>& location_to_values, + tsl::robin_set* unique_values = nullptr, + bool sort_values = true); + void parse_seller_file(const std::string& label_file, size_t& num_pts_labels, + std::vector& location_to_seller, uint32_t& num_unique_sellers); void convert_pts_label_to_bitmask(std::vector>& pts_to_labels, simple_bitmask_buf& bitmask_buf, size_t num_labels); @@ -357,6 +362,8 @@ template clas void initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t bitmask_size = 0); + double attribute_distance(const std::vector &a, const std::vector &b); + // Do not call without acquiring appropriate locks // call public member functions save and load to invoke these. DISKANN_DLLEXPORT size_t save_graph(std::string filename); @@ -426,6 +433,10 @@ template clas std::vector _location_to_seller; uint32_t _num_unique_sellers = 0; std::string _seller_file; + bool _attribute_diversity = false; + float _attr_dist_threshold = 0.2f; + std::string _attribute_file; + std::vector> _location_to_attributes; bool _use_universal_label = false; LabelT _universal_label = 0; diff --git a/include/parameters.h b/include/parameters.h index ae13621d8..0d53cf5f7 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -32,13 +32,18 @@ class IndexWriteParameters const bool diverse_index; const std::string seller_file; const uint32_t num_diverse_build; + const bool attribute_diversity; + const std::string attribute_file; + const float attr_dist_threshold; IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph, const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads, - const uint32_t filter_list_size, bool diverse_index, const std::string& seller_file, uint32_t num_diverse_build) + const uint32_t filter_list_size, bool diverse_index, const std::string& seller_file, uint32_t num_diverse_build, + bool attribute_diversity, const std::string& attribute_file, float attr_dist_threshold) : search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph), max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads), - filter_list_size(filter_list_size), diverse_index(diverse_index), seller_file(seller_file), num_diverse_build(num_diverse_build) + filter_list_size(filter_list_size), diverse_index(diverse_index), seller_file(seller_file), num_diverse_build(num_diverse_build), + attribute_diversity(attribute_diversity), attribute_file(attribute_file), attr_dist_threshold(attr_dist_threshold) { } @@ -100,6 +105,24 @@ class IndexWriteParametersBuilder return *this; } + IndexWriteParametersBuilder& with_attribute_diversity(const bool attribute_diversity) + { + _attribute_diversity = attribute_diversity; + return *this; + } + + IndexWriteParametersBuilder& with_attr_dist_threshold(const float attr_dist_threshold) + { + _attr_dist_threshold = attr_dist_threshold; + return *this; + } + + IndexWriteParametersBuilder& with_attribute_file(const std::string attribute_file) + { + _attribute_file = attribute_file; + return *this; + } + IndexWriteParametersBuilder &with_alpha(const float alpha) { _alpha = alpha; @@ -121,13 +144,16 @@ class IndexWriteParametersBuilder IndexWriteParameters build() const { return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha, - _num_threads, _filter_list_size, _diverse_index, _seller_file, _num_diverse_build); + _num_threads, _filter_list_size, _diverse_index, _seller_file, _num_diverse_build, + _attribute_diversity, _attribute_file, _attr_dist_threshold); } IndexWriteParametersBuilder(const IndexWriteParameters &wp) : _search_list_size(wp.search_list_size), _max_degree(wp.max_degree), _max_occlusion_size(wp.max_occlusion_size), _saturate_graph(wp.saturate_graph), _alpha(wp.alpha), - _filter_list_size(wp.filter_list_size) + _num_threads(wp.num_threads), _filter_list_size(wp.filter_list_size), _diverse_index(wp.diverse_index), + _seller_file(wp.seller_file), _num_diverse_build(wp.num_diverse_build), _attribute_diversity(wp.attribute_diversity), + _attribute_file(wp.attribute_file), _attr_dist_threshold(wp.attr_dist_threshold) { } IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete; @@ -143,7 +169,10 @@ class IndexWriteParametersBuilder uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE}; bool _diverse_index{ defaults::DIVERSE_INDEX }; std::string _seller_file{ defaults::EMPTY_STRING }; - uint32_t _num_diverse_build{ defaults::NUM_DIVERSE_BUILD }; + uint32_t _num_diverse_build{ defaults::NUM_DIVERSE_BUILD }; + bool _attribute_diversity{ defaults::ATTRIBUTE_DIVERSITY }; + std::string _attribute_file{ defaults::EMPTY_STRING }; + float _attr_dist_threshold{ defaults::ATTR_DIST_THRESHOLD }; }; struct IndexLoadParams diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 2fe2d2bb8..797c699a0 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -630,9 +630,11 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, - const uint32_t Lf, uint32_t universal_label_num = 0, - const char* seller_file_path = nullptr, - uint32_t num_diverse_build = 1) + const uint32_t Lf, uint32_t universal_label_num, + const char* seller_file_path, + uint32_t num_diverse_build, + const char* attribute_file_path, + float attr_dist_threshold) { size_t base_num, base_dim; diskann::get_bin_metadata(base_file, base_num, base_dim); @@ -650,6 +652,13 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr { is_diverse_index = true; } + + bool is_attribute_diverse = false; + if (attribute_file_path != nullptr && !std::string(attribute_file_path).empty()) + { + is_attribute_diverse = true; + } + diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R) .with_filter_list_size(Lf) .with_saturate_graph(!use_filters) @@ -657,6 +666,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr .with_diverse_index(is_diverse_index) .with_seller_file(seller_file_path) .with_num_diverse_build(num_diverse_build) + .with_attribute_diversity(is_attribute_diverse) + .with_attribute_file(attribute_file_path) + .with_attr_dist_threshold(attr_dist_threshold) .build(); using TagT = uint32_t; diskann::Index _index(compareMetric, base_dim, base_num, @@ -1122,7 +1134,9 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, - uint32_t num_diverse_build) + uint32_t num_diverse_build, + const char* attributeFilePath, + float attr_dist_threshold) { std::stringstream parser; parser << std::string(indexBuildParameters); @@ -1352,8 +1366,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, build_pq_bytes, use_opq, num_threads, use_filters, use_integer_labels, labels_file_to_use, - labels_to_medoids_path, universal_label, Lf, universal_label_id, - sellerFilePath, num_diverse_build); + labels_to_medoids_path, universal_label, Lf, universal_label_id, + sellerFilePath, num_diverse_build, attributeFilePath, attr_dist_threshold); diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl; timer.reset(); @@ -1492,7 +1506,9 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, - uint32_t num_diverse_build); + uint32_t num_diverse_build, + const char* attributeFilePath, + float attr_dist_threshold); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, @@ -1501,7 +1517,9 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *d const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, - uint32_t num_diverse_build); + uint32_t num_diverse_build, + const char* attributeFilePath, + float attr_dist_threshold); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, @@ -1510,7 +1528,9 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *dat const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, - uint32_t num_diverse_build); + uint32_t num_diverse_build, + const char* attributeFilePath, + float attr_dist_threshold); // LabelT = uint16 template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, @@ -1520,7 +1540,9 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, - uint32_t num_diverse_build); + uint32_t num_diverse_build, + const char* attributeFilePath, + float attr_dist_threshold); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, @@ -1529,7 +1551,9 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *d const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, - uint32_t num_diverse_build); + uint32_t num_diverse_build, + const char* attributeFilePath, + float attr_dist_threshold); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, @@ -1538,37 +1562,51 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *dat const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, - uint32_t num_diverse_build); + uint32_t num_diverse_build, + const char* attributeFilePath, + float attr_dist_threshold); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + uint32_t universal_label_num, const char* seller_file_path, uint32_t num_diverse_build, + const char* attribute_file_path, float attr_dist_threshold); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + uint32_t universal_label_num, const char* seller_file_path, uint32_t num_diverse_build, + const char* attribute_file_path, float attr_dist_threshold); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + uint32_t universal_label_num, const char* seller_file_path, uint32_t num_diverse_build, + const char* attribute_file_path, float attr_dist_threshold); // Label=16_t template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + uint32_t universal_label_num, const char* seller_file_path, uint32_t num_diverse_build, + const char* attribute_file_path, float attr_dist_threshold); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + uint32_t universal_label_num, const char* seller_file_path, uint32_t num_diverse_build, + const char* attribute_file_path, float attr_dist_threshold); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + uint32_t universal_label_num, const char* seller_file_path, uint32_t num_diverse_build, + const char* attribute_file_path, float attr_dist_threshold); }; // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index e2fc0cacd..9cc50791f 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -113,6 +113,10 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrseller_file; _num_diverse_build = index_config.index_write_params->num_diverse_build; + _attribute_diversity = index_config.index_write_params->attribute_diversity; + _attribute_file = index_config.index_write_params->attribute_file; + _attr_dist_threshold = index_config.index_write_params->attr_dist_threshold; + if (index_config.index_search_params != nullptr) { std::uint32_t default_queue_size = (std::max)(_indexingQueueSize, _filterIndexingQueueSize); @@ -687,7 +691,7 @@ void Index::load(const IndexLoadParams & load_params) else if (file_exists(old_index_seller_file)) { uint64_t nrows_seller_file; - parse_seller_file(old_index_seller_file, nrows_seller_file); + parse_seller_file(old_index_seller_file, nrows_seller_file, _location_to_seller, _num_unique_sellers); if (nrows_seller_file != data_file_num_pts) { std::stringstream stream; @@ -1294,16 +1298,32 @@ void Index::occlude_list(const uint32_t location, std::vector cur_alpha) { - if (!_diverse_index - || blockers[cur_index].size() >= _num_diverse_build) + if (!_diverse_index && !_attribute_diversity) { continue; } - auto iter_seller = _location_to_seller[iter->id]; - if (blockers[cur_index].find(iter_seller) != blockers[cur_index].end()) + + if (blockers[cur_index].size() >= _num_diverse_build) { continue; } + + if (_diverse_index) + { + auto iter_seller = _location_to_seller[iter->id]; + if (blockers[cur_index].find(iter_seller) != blockers[cur_index].end()) + { + continue; + } + } + + if (_attribute_diversity) + { + if (blockers[cur_index].find(std::numeric_limits::max()) != blockers[cur_index].end()) + { + continue; + } + } } // Set the entry to float::max so that is not considered again @@ -1326,17 +1346,24 @@ void Index::occlude_list(const uint32_t location, std::vector alpha) { - if (!_diverse_index - || blockers[t].size() >= _num_diverse_build) + if(!_diverse_index && !_attribute_diversity) { continue; } - auto iter2_seller = _location_to_seller[iter2->id]; - if (blockers[t].find(iter2_seller) != blockers[t].end()) + if (blockers[t].size() >= _num_diverse_build) { continue; } + + if (_diverse_index) + { + auto iter2_seller = _location_to_seller[iter2->id]; + if (blockers[t].find(iter2_seller) != blockers[t].end()) + { + continue; + } + } } bool prune_allowed = true; @@ -1368,10 +1395,28 @@ void Index::occlude_list(const uint32_t location, std::vector::max() : std::max(occlude_factor[t], iter2->distance / djk); - if (_diverse_index && (iter2->distance / djk) > cur_alpha) + if (djk == 0 || (iter2->distance / djk) > cur_alpha) { - auto iter_seller = _location_to_seller[iter->id]; - blockers[t].insert(iter_seller); + if (_diverse_index) + { + auto iter_seller = _location_to_seller[iter->id]; + blockers[t].insert(iter_seller); + } + + if (_attribute_diversity) + { + double attr_dist = attribute_distance(_location_to_attributes[iter2->id], _location_to_attributes[iter->id]); + + //std::cout << "Attribute distance between " << iter2->id << " and " << iter->id << " is " << attr_dist << std::endl; + if (attr_dist > _attr_dist_threshold) + { // attribute distance threshold + blockers[t].insert(iter->id); + } + else + { + blockers[t].insert(std::numeric_limits::max()); + } + } } } else if (_dist_metric == diskann::Metric::INNER_PRODUCT) @@ -1751,6 +1796,37 @@ void Index::set_start_points_at_random(T radius, uint32_t rando set_start_points(points_data.data(), points_data.size()); } +template +double Index::attribute_distance(const std::vector &a, const std::vector &b) +{ + // Use the smaller size to avoid out-of-bounds issues + size_t n = std::min(a.size(), b.size()); + if(n == 0) + return 0.0; + int matches = 0, counts = 0; + + for(size_t i = 0; i < n; ++i){ + + + if(a[i] == 0 || b[i] == 0) + continue; + counts++; + if(a[i] == b[i]){ + + matches++; + + } + } + + if (counts == 0) + { + return 0.0; + } + + double similarity = static_cast(matches) / counts; + return 1.0 - similarity; +} + template void Index::build_with_data_populated(const std::vector &tags) { @@ -1778,10 +1854,19 @@ void Index::build_with_data_populated(const std::vector & if (_diverse_index) { uint64_t nrows; - parse_seller_file(_seller_file, nrows); + parse_seller_file(_seller_file, nrows, _location_to_seller, _num_unique_sellers); std::cout << "Parsed seller file with " << nrows << " rows" << std::endl; } + if( _attribute_diversity) { + uint64_t nrows = 0, total_attributes = 0; + + parse_integer_string_file(_attribute_file, + nrows, total_attributes, _location_to_attributes, nullptr, false); + + std::cout << "Parsed attribute file with " << nrows << " rows" << std::endl; + } + uint32_t index_R = _indexingRange; uint32_t num_threads_index = _indexingThreads; uint32_t index_L = _indexingQueueSize; @@ -2116,14 +2201,18 @@ bool Index::is_label_valid(const std::string& raw_label) const } template -void Index::parse_label_file(const std::string &label_file, size_t &num_points, size_t& total_labels) +template +void Index::parse_integer_string_file(const std::string &file_path, size_t &num_points, size_t& total_values, + std::vector>& location_to_values, + tsl::robin_set* unique_values, + bool sort_values) { - // Format of Label txt file: filters with comma separators + // Format of file: integer values with comma separators per line - std::ifstream infile(label_file); + std::ifstream infile(file_path); if (infile.fail()) { - throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + throw diskann::ANNException(std::string("Failed to open file ") + file_path, -1); } std::string line, token; @@ -2135,41 +2224,50 @@ void Index::parse_label_file(const std::string &label_file, siz } if (_dynamic_index) { - _location_to_labels.resize(_max_points, std::vector()); + location_to_values.resize(_max_points, std::vector()); } else { - _location_to_labels.resize(line_cnt, std::vector()); + location_to_values.resize(line_cnt, std::vector()); } infile.clear(); infile.seekg(0, std::ios::beg); line_cnt = 0; - total_labels = 0; + total_values = 0; while (std::getline(infile, line)) { std::istringstream iss(line); - std::vector lbls(0); + std::vector vals(0); getline(iss, token, '\t'); std::istringstream new_iss(token); while (getline(new_iss, token, ',')) { token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); - lbls.push_back(token_as_num); - _labels.insert(token_as_num); + ValueT token_as_num = (ValueT)std::stoul(token); + vals.push_back(token_as_num); + if (unique_values != nullptr) + { + unique_values->insert(token_as_num); + } } - std::sort(lbls.begin(), lbls.end()); - _location_to_labels[line_cnt] = lbls; + if (sort_values) + { + std::sort(vals.begin(), vals.end()); + } + location_to_values[line_cnt] = vals; line_cnt++; - total_labels += lbls.size(); + total_values += vals.size(); } num_points = (size_t)line_cnt; - diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl; + if (unique_values != nullptr) + { + diskann::cout << "Identified " << unique_values->size() << " distinct value(s)" << std::endl; + } } template @@ -2214,7 +2312,8 @@ void Index::set_universal_label(const LabelT &label) } template -void Index::parse_seller_file(const std::string& label_file, size_t& num_points) +void Index::parse_seller_file(const std::string& label_file, size_t& num_points, + std::vector& location_to_seller, uint32_t& num_unique_sellers) { // Format of Label txt file: filters with comma separators @@ -2231,7 +2330,7 @@ void Index::parse_seller_file(const std::string& label_file, si { line_cnt++; } - _location_to_seller.resize(line_cnt); + location_to_seller.resize(line_cnt); infile.clear(); infile.seekg(0, std::ios::beg); @@ -2252,10 +2351,10 @@ void Index::parse_seller_file(const std::string& label_file, si sellers.insert(seller); } - _location_to_seller[line_cnt] = seller; + location_to_seller[line_cnt] = seller; line_cnt++; } - _num_unique_sellers = static_cast(sellers.size()); + num_unique_sellers = static_cast(sellers.size()); num_points = (size_t)line_cnt; diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points << " points." << std::endl; } @@ -2335,8 +2434,8 @@ void Index::build_filtered_index(const char *filename, const st size_t num_points_labels = 0; size_t total_labels = 0; - parse_label_file(label_file, - num_points_labels, total_labels); // determines medoid for each label and identifies + parse_integer_string_file(label_file, + num_points_labels, total_labels, _location_to_labels, &_labels); // determines medoid for each label and identifies // the points to label mapping if (!_use_integer_labels)