Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions include/svs/core/data/simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,12 @@ class GenericSerializer {
}

template <typename T, lib::LazyInvocable<size_t, size_t> F>
static lib::lazy_result_t<F, size_t, size_t>
load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) {
static lib::lazy_result_t<F, size_t, size_t> load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is,
const F& lazy
) {
auto datatype = lib::load_at<DataType>(table, "eltype");
if (datatype != datatype_v<T>) {
throw ANNEXCEPTION(
Expand All @@ -151,6 +155,10 @@ class GenericSerializer {
size_t num_vectors = lib::load_at<size_t>(table, "num_vectors");
size_t dims = lib::load_at<size_t>(table, "dims");

deserializer.read_name(is);
deserializer.read_size(is);
deserializer.read_binary<io::v1::Header>(is);

return io::load_dataset(is, lazy, num_vectors, dims);
}
};
Expand Down Expand Up @@ -474,13 +482,14 @@ class SimpleData {

static SimpleData load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is,
const allocator_type& allocator = {}
)
requires(!is_view)
{
return GenericSerializer::load<T>(
table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
table, deserializer, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
return SimpleData(n_elements, n_dimensions, allocator);
})
);
Expand Down Expand Up @@ -879,11 +888,15 @@ class SimpleData<T, Extent, Blocked<Alloc>> {

static SimpleData load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is,
const Blocked<Alloc>& allocator = {}
) {
return GenericSerializer::load<T>(
table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
table,
deserializer,
is,
lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
return SimpleData(n_elements, n_dimensions, allocator);
})
);
Expand Down
66 changes: 48 additions & 18 deletions include/svs/core/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,27 +324,36 @@ class IDTranslator {
"external_to_internal_translation";
static constexpr lib::Version save_version = lib::Version(0, 0, 0);

lib::SaveTable save(const lib::SaveContext& ctx) const {
auto filename = ctx.generate_name("id_translation", "binary");
// Save the translations to a file.
auto stream = lib::open_write(filename);
for (auto i = begin(), iend = end(); i != iend; ++i) {
// N.B.: Apparently `std::pair` of integers is not trivially copyable ...
lib::write_binary(stream, i->first);
lib::write_binary(stream, i->second);
}
lib::SaveTable save_table() const {
return lib::SaveTable(
serialization_schema,
save_version,
{{"kind", kind},
{"num_points", lib::save(size())},
{"external_id_type", lib::save(datatype_v<external_id_type>)},
{"internal_id_type", lib::save(datatype_v<internal_id_type>)},
{"filename", lib::save(filename.filename())}}
{"internal_id_type", lib::save(datatype_v<internal_id_type>)}}
);
}

static IDTranslator load(const lib::LoadTable& table) {
void save(std::ostream& os) const {
for (auto i = begin(), iend = end(); i != iend; ++i) {
// N.B.: Apparently `std::pair` of integers is not trivially copyable ...
lib::write_binary(os, i->first);
lib::write_binary(os, i->second);
}
}

lib::SaveTable save(const lib::SaveContext& ctx) const {
auto filename = ctx.generate_name("id_translation", "binary");
// Save the translations to a file.
auto os = lib::open_write(filename);
save(os);
auto table = save_table();
table.insert("filename", lib::save(filename.filename()));
return table;
}

static void validate(const lib::ContextFreeLoadTable& table) {
if (kind != lib::load_at<std::string>(table, "kind")) {
throw ANNEXCEPTION("Mismatched kind!");
}
Expand All @@ -357,21 +366,42 @@ class IDTranslator {
if (internal_id_name != lib::load_at<std::string>(table, "internal_id_type")) {
throw ANNEXCEPTION("Mismatched internal id types!");
}
}

// Now that we've more-or-less validated the metadata, time to start loading
// the points.
static IDTranslator load(const lib::ContextFreeLoadTable& table, std::istream& is) {
auto num_points = lib::load_at<size_t>(table, "num_points");

auto translator = IDTranslator{};
auto resolved = table.resolve_at("filename");
auto stream = lib::open_read(resolved);
for (size_t i = 0; i < num_points; ++i) {
auto external_id = lib::read_binary<external_id_type>(stream);
auto internal_id = lib::read_binary<internal_id_type>(stream);
auto external_id = lib::read_binary<external_id_type>(is);
auto internal_id = lib::read_binary<internal_id_type>(is);
translator.insert_translation(external_id, internal_id);
}
return translator;
}

static IDTranslator load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is
) {
IDTranslator::validate(table);
deserializer.read_name(is);
deserializer.read_size(is);

return IDTranslator::load(table, is);
}

static IDTranslator load(const lib::LoadTable& table) {
IDTranslator::validate(table);

// Now that we've more-or-less validated the metadata, time to start loading
// the points.
auto resolved = table.resolve_at("filename");
auto is = lib::open_read(resolved);
return IDTranslator::load(table, is);
}

private:
template <class Begin, class End, class Map, class Modifier = lib::identity>
void check(
Expand Down
68 changes: 68 additions & 0 deletions include/svs/index/flat/dynamic_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "svs/lib/invoke.h"
#include "svs/lib/misc.h"
#include "svs/lib/preprocessor.h"
#include "svs/lib/stream.h"
#include "svs/lib/threads.h"

namespace svs::index::flat {
Expand Down Expand Up @@ -403,6 +404,26 @@ template <typename Data, typename Dist> class DynamicFlatIndex {
// Save the dataset in the separate data directory
lib::save_to_disk(data_, data_directory);
}

void save(std::ostream& os) {
compact();

lib::begin_serialization(os);
// Save data structures and translation to config directory
lib::SaveTable save_table = lib::SaveTable(
"dynamic_flat_config",
save_version,
{
{"name", name()},
{"translation", lib::detail::exit_hook(translator_.save_table())},
}
);
lib::save_to_stream(save_table, os);
translator_.save(os);

lib::save_to_stream(data_, os);
}

constexpr std::string_view name() const { return "dynamic flat index"; }

///// Thread Pool Management
Expand Down Expand Up @@ -767,4 +788,51 @@ auto auto_dynamic_assemble(
);
}

template <typename LazyDataLoader, typename Distance, typename ThreadPoolProto>
auto auto_dynamic_assemble(
const lib::detail::Deserializer& deserializer,
std::istream& is,
LazyDataLoader&& data_loader,
Distance distance,
ThreadPoolProto threadpool_proto,
// Set this to `true` to use the identity map for ID translation.
// This allows us to read files generated by the static index construction routines
// to easily benchmark the static versus dynamic implementation.
//
// This is an internal API and should not be considered officially supported nor stable.
bool SVS_UNUSED(debug_load_from_static) = false,
svs::logging::logger_ptr logger = svs::logging::get()
) {
auto table = lib::detail::begin_deserialization(deserializer, is);
auto translator = IDTranslator::load(
table.template cast<toml::table>().at("translation").template cast<toml::table>(),
deserializer,
is
);

// // Load the dataset
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));

auto data = svs::detail::dispatch_load(data_loader(), threadpool);

// // Load the ID translator from the config directory
auto datasize = data.size();

// Validate the translator
auto translator_size = translator.size();
if (translator_size != datasize) {
throw ANNEXCEPTION(
"Translator has {} IDs but should have {}", translator_size, datasize
);
}

return DynamicFlatIndex(
std::move(data),
std::move(translator),
std::move(distance),
std::move(threadpool),
std::move(logger)
);
}

} // namespace svs::index::flat
5 changes: 4 additions & 1 deletion include/svs/index/flat/flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,10 @@ class FlatIndex {
lib::save_to_disk(data_, data_directory);
}

void save(std::ostream& os) const { lib::save_to_stream(data_, os); }
void save(std::ostream& os) const {
lib::begin_serialization(os);
lib::save_to_stream(data_, os);
}
};

///
Expand Down
92 changes: 63 additions & 29 deletions include/svs/lib/saveload/load.h
Original file line number Diff line number Diff line change
Expand Up @@ -833,39 +833,62 @@ inline SerializedObject begin_deserialization(const std::filesystem::path& fullp
std::move(table), lib::LoadContext{fullpath.parent_path(), version}};
}

inline ContextFreeSerializedObject begin_deserialization(std::istream& stream) {
lib::StreamArchiver::size_type magic = 0;
lib::StreamArchiver::read_size(stream, magic);
if (magic == lib::DirectoryArchiver::magic_number) {
// Backward compatibility mode for older versions:
// Previously, SVS serialized models using an intermediate file,
// so some dummy information was added to the stream.
lib::StreamArchiver::size_type num_files = 0;
lib::StreamArchiver::read_size(stream, num_files);

std::string file_name;
lib::StreamArchiver::read_name(stream, file_name);
} else if (magic != lib::StreamArchiver::magic_number) {
throw ANNEXCEPTION("Invalid magic number in stream deserialization!");
class Deserializer {
enum SerializationScheme { native, legacy };
SerializationScheme scheme_;

explicit Deserializer(const SerializationScheme& scheme)
: scheme_(scheme) {}

public:
static Deserializer build(std::istream& stream) {
lib::StreamArchiver::size_type magic = 0;
lib::StreamArchiver::read_size(stream, magic);
if (magic == lib::StreamArchiver::magic_number) {
return Deserializer(SerializationScheme::native);
} else if (magic == lib::DirectoryArchiver::magic_number) {
// Backward compatibility mode for older versions:
// Previously, SVS serialized models using an intermediate file,
// so some dummy information was added to the stream.
lib::StreamArchiver::size_type num_files = 0;
lib::StreamArchiver::read_size(stream, num_files);

return Deserializer(SerializationScheme::legacy);
} else {
throw ANNEXCEPTION("Invalid magic number in stream deserialization!");
}
}

void read_name(std::istream& stream) const {
if (scheme_ == SerializationScheme::legacy) {
std::string file_name;
lib::StreamArchiver::read_name(stream, file_name);
}
}

void read_size(std::istream& stream) const {
if (scheme_ == SerializationScheme::legacy) {
lib::StreamArchiver::size_type file_size = 0;
lib::StreamArchiver::read_size(stream, file_size);
}
}

template <typename T> void read_binary(std::istream& stream) const {
if (scheme_ == SerializationScheme::legacy) {
lib::read_binary<T>(stream);
}
}
};

inline ContextFreeSerializedObject
begin_deserialization(const Deserializer& deserializer, std::istream& stream) {
deserializer.read_name(stream);
if (!stream) {
throw ANNEXCEPTION("Error reading from stream!");
}

auto table = lib::StreamArchiver::read_table(stream);

if (magic == lib::DirectoryArchiver::magic_number) {
// Backward compatibility mode for older versions:
// Previously, SVS serialized models using an intermediate file,
// so some dummy information was added to the stream.
std::string file_name;
lib::StreamArchiver::read_name(stream, file_name);

lib::StreamArchiver::size_type file_size = 0;
lib::StreamArchiver::read_size(stream, file_size);
lib::read_binary<io::v1::Header>(stream);
}
return ContextFreeSerializedObject{std::move(table)};
}

Expand Down Expand Up @@ -920,17 +943,28 @@ T load_from_disk(const std::filesystem::path& path, Args&&... args) {

///// load_from_stream
template <typename T, typename... Args>
T load_from_stream(const Loader<T>& loader, std::istream& stream, Args&&... args) {
T load_from_stream(
const Loader<T>& loader,
const detail::Deserializer& deserializer,
std::istream& stream,
Args&&... args
) {
// At this point, we will try the saving/loading framework to load the object.
// Here we go!
return lib::load(
loader, detail::begin_deserialization(stream), stream, SVS_FWD(args)...
loader,
detail::begin_deserialization(deserializer, stream),
deserializer,
stream,
SVS_FWD(args)...
);
}

template <typename T, typename... Args>
T load_from_stream(std::istream& stream, Args&&... args) {
return lib::load_from_stream(Loader<T>(), stream, SVS_FWD(args)...);
T load_from_stream(
const detail::Deserializer& deserializer, std::istream& stream, Args&&... args
) {
return lib::load_from_stream(Loader<T>(), deserializer, stream, SVS_FWD(args)...);
}

///// load_from_file
Expand Down
Loading
Loading