diff --git a/be/src/olap/collection_statistics.cpp b/be/src/olap/collection_statistics.cpp index 94130a1e6a74f6..7a548ed0a14bca 100644 --- a/be/src/olap/collection_statistics.cpp +++ b/be/src/olap/collection_statistics.cpp @@ -19,6 +19,7 @@ #include #include +#include #include "common/exception.h" #include "olap/rowset/rowset.h" @@ -108,94 +109,15 @@ Status CollectionStatistics::collect( return Status::OK(); } -vectorized::VSlotRef* find_slot_ref(const vectorized::VExprSPtr& expr) { - if (!expr) return nullptr; - auto cur = vectorized::VExpr::expr_without_cast(expr); - if (cur->node_type() == TExprNodeType::SLOT_REF) { - return static_cast(cur.get()); - } - for (auto& ch : cur->children()) { - if (auto* s = find_slot_ref(ch)) return s; - } - return nullptr; -} - -Status handle_match_pred(RuntimeState* state, const TabletSchemaSPtr& tablet_schema, - const vectorized::VExprSPtr& expr, - std::unordered_map* collect_infos) { - auto* left_slot_ref = find_slot_ref(expr->children()[0]); - if (left_slot_ref == nullptr) { - return Status::Error( - "Index statistics collection failed: Cannot find slot reference in match predicate " - "left expression"); - } - auto* right_literal = static_cast(expr->children()[1].get()); - DCHECK(right_literal != nullptr); - - const auto* sd = state->desc_tbl().get_slot_descriptor(left_slot_ref->slot_id()); - if (sd == nullptr) { - return Status::Error( - "Index statistics collection failed: Cannot find slot descriptor for slot_id={}", - left_slot_ref->slot_id()); - } - int32_t col_idx = tablet_schema->field_index(left_slot_ref->column_name()); - if (col_idx == -1) { - return Status::Error( - "Index statistics collection failed: Cannot find column index for column={}", - left_slot_ref->column_name()); - } - - const auto& column = tablet_schema->column(col_idx); - auto index_metas = tablet_schema->inverted_indexs(sd->col_unique_id(), column.suffix_path()); -#ifndef BE_TEST - if (index_metas.empty()) { - return Status::Error( - "Index statistics collection failed: Score query is not supported without inverted " - "index for column={}", - left_slot_ref->column_name()); - } -#endif - - auto format_options = vectorized::DataTypeSerDe::get_default_format_options(); - format_options.timezone = &state->timezone_obj(); - for (const auto* index_meta : index_metas) { - if (!InvertedIndexAnalyzer::should_analyzer(index_meta->properties())) { - continue; - } - if (!segment_v2::IndexReaderHelper::is_need_similarity_score(expr->op(), index_meta)) { - continue; - } - - auto term_infos = InvertedIndexAnalyzer::get_analyse_result( - right_literal->value(format_options), index_meta->properties()); - if (term_infos.empty()) { - LOG(WARNING) << "Index statistics collection: no terms extracted from literal value, " - << "col_unique_id=" << index_meta->col_unique_ids()[0]; - continue; - } - - std::string field_name = std::to_string(index_meta->col_unique_ids()[0]); - if (!column.suffix_path().empty()) { - field_name += "." + column.suffix_path(); - } - std::wstring ws_field_name = StringHelper::to_wstring(field_name); - auto iter = collect_infos->find(ws_field_name); - if (iter == collect_infos->end()) { - CollectInfo collect_info; - collect_info.term_infos.insert(term_infos.begin(), term_infos.end()); - collect_info.index_meta = index_meta; - (*collect_infos)[ws_field_name] = std::move(collect_info); - } else { - iter->second.term_infos.insert(term_infos.begin(), term_infos.end()); - } - } - return Status::OK(); -} - Status CollectionStatistics::extract_collect_info( RuntimeState* state, const vectorized::VExprContextSPtrs& common_expr_ctxs_push_down, - const TabletSchemaSPtr& tablet_schema, - std::unordered_map* collect_infos) { + const TabletSchemaSPtr& tablet_schema, CollectInfoMap* collect_infos) { + DCHECK(collect_infos != nullptr); + + std::unordered_map collectors; + collectors[TExprNodeType::MATCH_PRED] = std::make_unique(); + collectors[TExprNodeType::SEARCH_EXPR] = std::make_unique(); + for (const auto& root_expr_ctx : common_expr_ctxs_push_down) { const auto& root_expr = root_expr_ctx->root(); if (root_expr == nullptr) { @@ -206,27 +128,35 @@ Status CollectionStatistics::extract_collect_info( stack.emplace(root_expr); while (!stack.empty()) { - const auto& expr = stack.top(); + auto expr = stack.top(); stack.pop(); - if (expr->node_type() == TExprNodeType::MATCH_PRED) { - RETURN_IF_ERROR(handle_match_pred(state, tablet_schema, expr, collect_infos)); + if (!expr) { + continue; + } + + auto collector_it = collectors.find(expr->node_type()); + if (collector_it != collectors.end()) { + RETURN_IF_ERROR( + collector_it->second->collect(state, tablet_schema, expr, collect_infos)); } const auto& children = expr->children(); - for (int32_t i = static_cast(children.size()) - 1; i >= 0; --i) { - if (!children[i]->children().empty()) { - stack.emplace(children[i]); - } + for (const auto& child : children) { + stack.push(child); } } } + + LOG(INFO) << "Extracted collect info for " << collect_infos->size() << " fields"; + return Status::OK(); } -Status CollectionStatistics::process_segment( - const RowsetSharedPtr& rowset, int32_t seg_id, const TabletSchema* tablet_schema, - const std::unordered_map& collect_infos, io::IOContext* io_ctx) { +Status CollectionStatistics::process_segment(const RowsetSharedPtr& rowset, int32_t seg_id, + const TabletSchema* tablet_schema, + const CollectInfoMap& collect_infos, + io::IOContext* io_ctx) { auto seg_path = DORIS_TRY(rowset->segment_path(seg_id)); auto rowset_meta = rowset->rowset_meta(); @@ -238,36 +168,42 @@ Status CollectionStatistics::process_segment( RETURN_IF_ERROR(idx_file_reader->init(config::inverted_index_read_buffer_size, io_ctx)); int32_t total_seg_num_docs = 0; + for (const auto& [ws_field_name, collect_info] : collect_infos) { + lucene::search::IndexSearcher* index_searcher = nullptr; + lucene::index::IndexReader* index_reader = nullptr; + #ifdef BE_TEST auto compound_reader = DORIS_TRY(idx_file_reader->open(collect_info.index_meta, io_ctx)); auto* reader = lucene::index::IndexReader::open(compound_reader.get()); - auto index_searcher = std::make_shared(reader, true); - - auto* index_reader = index_searcher->getReader(); + auto searcher_ptr = std::make_shared(reader, true); + index_searcher = searcher_ptr.get(); + index_reader = index_searcher->getReader(); #else InvertedIndexCacheHandle inverted_index_cache_handle; auto index_file_key = idx_file_reader->get_index_file_cache_key(collect_info.index_meta); InvertedIndexSearcherCache::CacheKey searcher_cache_key(index_file_key); + if (!InvertedIndexSearcherCache::instance()->lookup(searcher_cache_key, &inverted_index_cache_handle)) { auto compound_reader = DORIS_TRY(idx_file_reader->open(collect_info.index_meta, io_ctx)); auto* reader = lucene::index::IndexReader::open(compound_reader.get()); size_t reader_size = reader->getTermInfosRAMUsed(); - auto index_searcher = std::make_shared(reader, true); + auto searcher_ptr = std::make_shared(reader, true); auto* cache_value = new InvertedIndexSearcherCache::CacheValue( - std::move(index_searcher), reader_size, UnixMillis()); + std::move(searcher_ptr), reader_size, UnixMillis()); InvertedIndexSearcherCache::instance()->insert(searcher_cache_key, cache_value, &inverted_index_cache_handle); } auto searcher_variant = inverted_index_cache_handle.get_index_searcher(); - auto index_searcher = std::get(searcher_variant); - auto* index_reader = index_searcher->getReader(); + auto index_searcher_ptr = std::get(searcher_variant); + index_searcher = index_searcher_ptr.get(); + index_reader = index_searcher->getReader(); #endif - total_seg_num_docs = std::max(total_seg_num_docs, index_reader->maxDoc()); + _total_num_tokens[ws_field_name] += index_reader->sumTotalTermFreq(ws_field_name.c_str()).value_or(0); @@ -277,7 +213,9 @@ Status CollectionStatistics::process_segment( _term_doc_freqs[ws_field_name][iter->term()] += iter->doc_freq(); } } + _total_num_docs += total_seg_num_docs; + return Status::OK(); } diff --git a/be/src/olap/collection_statistics.h b/be/src/olap/collection_statistics.h index 9fdd3ddde30650..14a081cb534fb1 100644 --- a/be/src/olap/collection_statistics.h +++ b/be/src/olap/collection_statistics.h @@ -22,6 +22,7 @@ #include "common/be_mock_util.h" #include "olap/olap_common.h" +#include "olap/predicate_collector.h" #include "olap/rowset/segment_v2/inverted_index/query/query_info.h" #include "runtime/runtime_state.h" #include "vec/exprs/vexpr_fwd.h" @@ -44,18 +45,6 @@ class TabletIndex; class TabletSchema; using TabletSchemaSPtr = std::shared_ptr; -struct TermInfoComparer { - bool operator()(const segment_v2::TermInfo& lhs, const segment_v2::TermInfo& rhs) const { - return lhs.term < rhs.term; - } -}; - -class CollectInfo { -public: - std::set term_infos; - const TabletIndex* index_meta = nullptr; -}; - class CollectionStatistics { public: CollectionStatistics() = default; @@ -74,10 +63,9 @@ class CollectionStatistics { Status extract_collect_info(RuntimeState* state, const vectorized::VExprContextSPtrs& common_expr_ctxs_push_down, const TabletSchemaSPtr& tablet_schema, - std::unordered_map* collect_infos); + CollectInfoMap* collect_infos); Status process_segment(const RowsetSharedPtr& rowset, int32_t seg_id, - const TabletSchema* tablet_schema, - const std::unordered_map& collect_infos, + const TabletSchema* tablet_schema, const CollectInfoMap& collect_infos, io::IOContext* io_ctx); uint64_t get_term_doc_freq_by_col(const std::wstring& lucene_col_name, @@ -95,6 +83,7 @@ class CollectionStatistics { MOCK_DEFINE(friend class BM25SimilarityTest;) MOCK_DEFINE(friend class CollectionStatisticsTest;) MOCK_DEFINE(friend class BooleanQueryTest;) + MOCK_DEFINE(friend class OccurBooleanQueryTest;) }; using CollectionStatisticsPtr = std::shared_ptr; diff --git a/be/src/olap/predicate_collector.cpp b/be/src/olap/predicate_collector.cpp new file mode 100644 index 00000000000000..1cdcecf1716db9 --- /dev/null +++ b/be/src/olap/predicate_collector.cpp @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/predicate_collector.h" + +#include + +#include "gen_cpp/Exprs_types.h" +#include "olap/rowset/segment_v2/index_reader_helper.h" +#include "olap/rowset/segment_v2/inverted_index/analyzer/analyzer.h" +#include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" +#include "olap/tablet_schema.h" +#include "vec/exprs/vexpr.h" +#include "vec/exprs/vexpr_context.h" +#include "vec/exprs/vliteral.h" +#include "vec/exprs/vsearch.h" +#include "vec/exprs/vslot_ref.h" + +namespace doris { + +using namespace segment_v2; +using namespace vectorized; + +vectorized::VSlotRef* PredicateCollector::find_slot_ref(const vectorized::VExprSPtr& expr) const { + if (!expr) { + return nullptr; + } + + auto cur = VExpr::expr_without_cast(expr); + if (cur->node_type() == TExprNodeType::SLOT_REF) { + return static_cast(cur.get()); + } + + for (const auto& ch : cur->children()) { + if (auto* s = find_slot_ref(ch)) { + return s; + } + } + + return nullptr; +} + +std::string PredicateCollector::build_field_name(int32_t col_unique_id, + const std::string& suffix_path) const { + std::string field_name = std::to_string(col_unique_id); + if (!suffix_path.empty()) { + field_name += "." + suffix_path; + } + return field_name; +} + +Status MatchPredicateCollector::collect(RuntimeState* state, const TabletSchemaSPtr& tablet_schema, + const vectorized::VExprSPtr& expr, + CollectInfoMap* collect_infos) { + DCHECK(collect_infos != nullptr); + + auto* left_slot_ref = find_slot_ref(expr->children()[0]); + if (left_slot_ref == nullptr) { + return Status::Error( + "Index statistics collection failed: Cannot find slot reference in match predicate " + "left expression"); + } + + auto* right_literal = static_cast(expr->children()[1].get()); + DCHECK(right_literal != nullptr); + + const auto* sd = state->desc_tbl().get_slot_descriptor(left_slot_ref->slot_id()); + if (sd == nullptr) { + return Status::Error( + "Index statistics collection failed: Cannot find slot descriptor for slot_id={}", + left_slot_ref->slot_id()); + } + + int32_t col_idx = tablet_schema->field_index(left_slot_ref->column_name()); + if (col_idx == -1) { + return Status::Error( + "Index statistics collection failed: Cannot find column index for column={}", + left_slot_ref->column_name()); + } + + const auto& column = tablet_schema->column(col_idx); + auto index_metas = tablet_schema->inverted_indexs(sd->col_unique_id(), column.suffix_path()); + +#ifndef BE_TEST + if (index_metas.empty()) { + return Status::Error( + "Index statistics collection failed: Score query is not supported without inverted " + "index for column={}", + left_slot_ref->column_name()); + } +#endif + + for (const auto* index_meta : index_metas) { + if (!InvertedIndexAnalyzer::should_analyzer(index_meta->properties())) { + continue; + } + + if (!IndexReaderHelper::is_need_similarity_score(expr->op(), index_meta)) { + continue; + } + + auto options = DataTypeSerDe::get_default_format_options(); + options.timezone = &state->timezone_obj(); + auto term_infos = InvertedIndexAnalyzer::get_analyse_result(right_literal->value(options), + index_meta->properties()); + + std::string field_name = + build_field_name(index_meta->col_unique_ids()[0], column.suffix_path()); + std::wstring ws_field_name = StringHelper::to_wstring(field_name); + + auto iter = collect_infos->find(ws_field_name); + if (iter == collect_infos->end()) { + CollectInfo collect_info; + collect_info.term_infos.insert(term_infos.begin(), term_infos.end()); + collect_info.index_meta = index_meta; + (*collect_infos)[ws_field_name] = std::move(collect_info); + } else { + iter->second.term_infos.insert(term_infos.begin(), term_infos.end()); + } + } + + return Status::OK(); +} + +Status SearchPredicateCollector::collect(RuntimeState* state, const TabletSchemaSPtr& tablet_schema, + const vectorized::VExprSPtr& expr, + CollectInfoMap* collect_infos) { + DCHECK(collect_infos != nullptr); + + auto* search_expr = dynamic_cast(expr.get()); + if (search_expr == nullptr) { + return Status::InternalError("SearchPredicateCollector: expr is not VSearchExpr type"); + } + + const TSearchParam& search_param = search_expr->get_search_param(); + + RETURN_IF_ERROR(collect_from_clause(search_param.root, state, tablet_schema, collect_infos)); + + return Status::OK(); +} + +Status SearchPredicateCollector::collect_from_clause(const TSearchClause& clause, + RuntimeState* state, + const TabletSchemaSPtr& tablet_schema, + CollectInfoMap* collect_infos) { + const std::string& clause_type = clause.clause_type; + ClauseTypeCategory category = get_clause_type_category(clause_type); + + if (category == ClauseTypeCategory::COMPOUND) { + if (clause.__isset.children) { + for (const auto& child_clause : clause.children) { + RETURN_IF_ERROR( + collect_from_clause(child_clause, state, tablet_schema, collect_infos)); + } + } + return Status::OK(); + } + + return collect_from_leaf(clause, state, tablet_schema, collect_infos); +} + +Status SearchPredicateCollector::collect_from_leaf(const TSearchClause& clause, RuntimeState* state, + const TabletSchemaSPtr& tablet_schema, + CollectInfoMap* collect_infos) { + if (!clause.__isset.field_name || !clause.__isset.value) { + return Status::InvalidArgument("Search clause missing field_name or value"); + } + + const std::string& field_name = clause.field_name; + const std::string& value = clause.value; + const std::string& clause_type = clause.clause_type; + + if (!is_score_query_type(clause_type)) { + return Status::OK(); + } + + int32_t col_idx = tablet_schema->field_index(field_name); + if (col_idx == -1) { + return Status::OK(); + } + + const auto& column = tablet_schema->column(col_idx); + + auto index_metas = tablet_schema->inverted_indexs(column.unique_id(), column.suffix_path()); + if (index_metas.empty()) { + return Status::OK(); + } + + ClauseTypeCategory category = get_clause_type_category(clause_type); + for (const auto* index_meta : index_metas) { + std::set term_infos; + + if (category == ClauseTypeCategory::TOKENIZED) { + if (InvertedIndexAnalyzer::should_analyzer(index_meta->properties())) { + auto analyzed_terms = + InvertedIndexAnalyzer::get_analyse_result(value, index_meta->properties()); + term_infos.insert(analyzed_terms.begin(), analyzed_terms.end()); + } else { + term_infos.insert(TermInfo(value)); + } + } else if (category == ClauseTypeCategory::NON_TOKENIZED) { + if (clause_type == "TERM" && + InvertedIndexAnalyzer::should_analyzer(index_meta->properties())) { + auto analyzed_terms = + InvertedIndexAnalyzer::get_analyse_result(value, index_meta->properties()); + term_infos.insert(analyzed_terms.begin(), analyzed_terms.end()); + } else { + term_infos.insert(TermInfo(value)); + } + } + + std::string lucene_field_name = + build_field_name(index_meta->col_unique_ids()[0], column.suffix_path()); + std::wstring ws_field_name = StringHelper::to_wstring(lucene_field_name); + + auto iter = collect_infos->find(ws_field_name); + if (iter == collect_infos->end()) { + CollectInfo collect_info; + collect_info.term_infos = std::move(term_infos); + collect_info.index_meta = index_meta; + (*collect_infos)[ws_field_name] = std::move(collect_info); + } else { + iter->second.term_infos.insert(term_infos.begin(), term_infos.end()); + } + } + + return Status::OK(); +} + +bool SearchPredicateCollector::is_score_query_type(const std::string& clause_type) const { + return clause_type == "TERM" || clause_type == "EXACT" || clause_type == "PHRASE" || + clause_type == "MATCH" || clause_type == "ANY" || clause_type == "ALL"; +} + +SearchPredicateCollector::ClauseTypeCategory SearchPredicateCollector::get_clause_type_category( + const std::string& clause_type) const { + if (clause_type == "AND" || clause_type == "OR" || clause_type == "NOT" || + clause_type == "OCCUR_BOOLEAN") { + return ClauseTypeCategory::COMPOUND; + } else if (clause_type == "TERM" || clause_type == "EXACT") { + return ClauseTypeCategory::NON_TOKENIZED; + } else if (clause_type == "PHRASE" || clause_type == "MATCH" || clause_type == "ANY" || + clause_type == "ALL") { + return ClauseTypeCategory::TOKENIZED; + } else { + LOG(WARNING) << "Unknown clause type '" << clause_type + << "', defaulting to NON_TOKENIZED category"; + return ClauseTypeCategory::NON_TOKENIZED; + } +} + +} // namespace doris \ No newline at end of file diff --git a/be/src/olap/predicate_collector.h b/be/src/olap/predicate_collector.h new file mode 100644 index 00000000000000..8e0557eb7f394a --- /dev/null +++ b/be/src/olap/predicate_collector.h @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "common/status.h" +#include "gen_cpp/Exprs_types.h" +#include "olap/rowset/segment_v2/inverted_index/query/query_info.h" +#include "runtime/runtime_state.h" +#include "vec/exprs/vexpr_fwd.h" + +namespace doris { + +namespace vectorized { +class VSlotRef; +} // namespace vectorized + +class TabletIndex; +class TabletSchema; +using TabletSchemaSPtr = std::shared_ptr; + +struct TermInfoComparer { + bool operator()(const segment_v2::TermInfo& lhs, const segment_v2::TermInfo& rhs) const { + return lhs.term < rhs.term; + } +}; + +struct CollectInfo { + std::set term_infos; + const TabletIndex* index_meta = nullptr; +}; +using CollectInfoMap = std::unordered_map; + +class PredicateCollector { +public: + virtual ~PredicateCollector() = default; + + virtual Status collect(RuntimeState* state, const TabletSchemaSPtr& tablet_schema, + const vectorized::VExprSPtr& expr, CollectInfoMap* collect_infos) = 0; + +protected: + vectorized::VSlotRef* find_slot_ref(const vectorized::VExprSPtr& expr) const; + std::string build_field_name(int32_t col_unique_id, const std::string& suffix_path) const; +}; + +class MatchPredicateCollector : public PredicateCollector { +public: + Status collect(RuntimeState* state, const TabletSchemaSPtr& tablet_schema, + const vectorized::VExprSPtr& expr, CollectInfoMap* collect_infos) override; +}; + +class SearchPredicateCollector : public PredicateCollector { +public: + Status collect(RuntimeState* state, const TabletSchemaSPtr& tablet_schema, + const vectorized::VExprSPtr& expr, CollectInfoMap* collect_infos) override; + +private: + enum class ClauseTypeCategory { NON_TOKENIZED, TOKENIZED, COMPOUND }; + + Status collect_from_clause(const TSearchClause& clause, RuntimeState* state, + const TabletSchemaSPtr& tablet_schema, + CollectInfoMap* collect_infos); + Status collect_from_leaf(const TSearchClause& clause, RuntimeState* state, + const TabletSchemaSPtr& tablet_schema, CollectInfoMap* collect_infos); + bool is_score_query_type(const std::string& clause_type) const; + ClauseTypeCategory get_clause_type_category(const std::string& clause_type) const; +}; + +using PredicateCollectorPtr = std::unique_ptr; + +} // namespace doris \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/index_iterator.h b/be/src/olap/rowset/segment_v2/index_iterator.h index b97069f4089c57..c12c74c71d6c0c 100644 --- a/be/src/olap/rowset/segment_v2/index_iterator.h +++ b/be/src/olap/rowset/segment_v2/index_iterator.h @@ -57,6 +57,7 @@ class IndexIterator { virtual Result has_null() = 0; void set_context(const IndexQueryContextPtr& context) { _context = context; } + IndexQueryContextPtr get_context() const { return _context; } protected: IndexQueryContextPtr _context = nullptr; diff --git a/be/src/olap/rowset/segment_v2/index_query_context.h b/be/src/olap/rowset/segment_v2/index_query_context.h index fdd02504dcc337..811ddece4fa2b2 100644 --- a/be/src/olap/rowset/segment_v2/index_query_context.h +++ b/be/src/olap/rowset/segment_v2/index_query_context.h @@ -30,6 +30,9 @@ struct IndexQueryContext { CollectionStatisticsPtr collection_statistics; CollectionSimilarityPtr collection_similarity; + + size_t query_limit = 0; + bool is_asc = false; }; using IndexQueryContextPtr = std::shared_ptr; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp index d2318165671078..acbeb53075cc3a 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp @@ -45,6 +45,12 @@ OccurBooleanWeight::OccurBooleanWeight( template ScorerPtr OccurBooleanWeight::scorer(const QueryExecutionContext& context) { + return scorer(context, {}); +} + +template +ScorerPtr OccurBooleanWeight::scorer(const QueryExecutionContext& context, + const std::string& binding_key) { if (_sub_weights.empty()) { return std::make_shared(); } @@ -53,27 +59,28 @@ ScorerPtr OccurBooleanWeight::scorer(const QueryExecutionCont if (occur == Occur::MUST_NOT) { return std::make_shared(); } - return weight->scorer(context); + return weight->scorer(context, binding_key); } _max_doc = context.segment_num_rows; if (_enable_scoring) { - auto specialized = complex_scorer(context, _score_combiner); + auto specialized = complex_scorer(context, _score_combiner, binding_key); return into_box_scorer(std::move(specialized), _score_combiner); } else { auto combiner = std::make_shared(); - auto specialized = complex_scorer(context, combiner); + auto specialized = complex_scorer(context, combiner, binding_key); return into_box_scorer(std::move(specialized), combiner); } } template std::unordered_map> -OccurBooleanWeight::per_occur_scorers(const QueryExecutionContext& context) { +OccurBooleanWeight::per_occur_scorers(const QueryExecutionContext& context, + const std::string& binding_key) { std::unordered_map> result; for (size_t i = 0; i < _sub_weights.size(); ++i) { const auto& [occur, weight] = _sub_weights[i]; - const auto& binding_key = _binding_keys[i]; - auto sub_scorer = weight->scorer(context, binding_key); + const auto& key = _binding_keys[i].empty() ? binding_key : _binding_keys[i]; + auto sub_scorer = weight->scorer(context, key); if (sub_scorer) { result[occur].push_back(std::move(sub_scorer)); } @@ -217,8 +224,8 @@ SpecializedScorer OccurBooleanWeight::build_positive_opt( template template SpecializedScorer OccurBooleanWeight::complex_scorer( - const QueryExecutionContext& context, CombinerT combiner) { - auto scorers_by_occur = per_occur_scorers(context); + const QueryExecutionContext& context, CombinerT combiner, const std::string& binding_key) { + auto scorers_by_occur = per_occur_scorers(context, binding_key); auto must_scorers = std::move(scorers_by_occur[Occur::MUST]); auto should_scorers = std::move(scorers_by_occur[Occur::SHOULD]); auto must_not_scorers = std::move(scorers_by_occur[Occur::MUST_NOT]); diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h index 70c43f25a50fb5..acb2d67939625c 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h @@ -22,6 +22,7 @@ #include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/wand/block_wand.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" namespace doris::segment_v2::inverted_index::query_v2 { @@ -51,14 +52,21 @@ class OccurBooleanWeight : public Weight { ~OccurBooleanWeight() override = default; ScorerPtr scorer(const QueryExecutionContext& context) override; + ScorerPtr scorer(const QueryExecutionContext& context, const std::string& binding_key) override; + + void for_each_pruning(const QueryExecutionContext& context, float threshold, + PruningCallback callback) override; + void for_each_pruning(const QueryExecutionContext& context, const std::string& binding_key, + float threshold, PruningCallback callback) override; private: std::unordered_map> per_occur_scorers( - const QueryExecutionContext& context); + const QueryExecutionContext& context, const std::string& binding_key = {}); AllAndEmptyScorerCounts remove_and_count_all_and_empty_scorers(std::vector& scorers); template - SpecializedScorer complex_scorer(const QueryExecutionContext& context, CombinerT combiner); + SpecializedScorer complex_scorer(const QueryExecutionContext& context, CombinerT combiner, + const std::string& binding_key = {}); template std::optional build_should_opt(std::vector& must_scorers, @@ -100,4 +108,35 @@ class OccurBooleanWeight : public Weight { uint32_t _max_doc = 0; }; +template +void OccurBooleanWeight::for_each_pruning(const QueryExecutionContext& context, + float threshold, + PruningCallback callback) { + for_each_pruning(context, {}, threshold, std::move(callback)); +} + +template +void OccurBooleanWeight::for_each_pruning(const QueryExecutionContext& context, + const std::string& binding_key, + float threshold, + PruningCallback callback) { + if (_sub_weights.empty()) { + return; + } + + _max_doc = context.segment_num_rows; + auto specialized = complex_scorer(context, _score_combiner, binding_key); + + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + block_wand(std::move(arg), threshold, std::move(callback)); + } else { + for_each_pruning_scorer(std::move(arg), threshold, std::move(callback)); + } + }, + std::move(specialized)); +} + } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.cpp new file mode 100644 index 00000000000000..49ac22bca105a7 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.cpp @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.h" + +#include "olap/rowset/segment_v2/inverted_index/query_v2/collect/multi_segment_util.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +void collect_single_scorer(const WeightPtr& weight, const QueryExecutionContext& context, + const std::string& binding_key, + const std::shared_ptr& roaring, + const CollectionSimilarityPtr& similarity, bool enable_scoring) { + auto scorer = weight->scorer(context, binding_key); + if (!scorer) { + return; + } + + uint32_t doc = scorer->doc(); + while (doc != TERMINATED) { + roaring->add(doc); + if (enable_scoring && similarity) { + similarity->collect(doc, scorer->score()); + } + doc = scorer->advance(); + } +} + +void collect_multi_segment_doc_set(const WeightPtr& weight, const QueryExecutionContext& context, + const std::string& binding_key, + const std::shared_ptr& roaring, + const CollectionSimilarityPtr& similarity, bool enable_scoring) { + if (context.readers.empty()) { + collect_single_scorer(weight, context, binding_key, roaring, similarity, enable_scoring); + return; + } + + for_each_index_segment(context, binding_key, + [&](const QueryExecutionContext& seg_ctx, uint32_t doc_base) { + auto scorer = weight->scorer(seg_ctx, binding_key); + if (!scorer) { + return; + } + + uint32_t doc = scorer->doc(); + while (doc != TERMINATED) { + uint32_t global_doc = doc + doc_base; + roaring->add(global_doc); + if (enable_scoring && similarity) { + similarity->collect(global_doc, scorer->score()); + } + doc = scorer->advance(); + } + }); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.h new file mode 100644 index 00000000000000..06189f6ccb3a5e --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "olap/collection_similarity.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +void collect_multi_segment_doc_set(const WeightPtr& weight, const QueryExecutionContext& context, + const std::string& binding_key, + const std::shared_ptr& roaring, + const CollectionSimilarityPtr& similarity, bool enable_scoring); + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/multi_segment_util.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/multi_segment_util.h new file mode 100644 index 00000000000000..7eb0a9e10aadd6 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/multi_segment_util.h @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wshadow-field" +#pragma clang diagnostic ignored "-Woverloaded-virtual" +#pragma clang diagnostic ignored "-Winconsistent-missing-override" +#pragma clang diagnostic ignored "-Wreorder-ctor" +#pragma clang diagnostic ignored "-Wshorten-64-to-32" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +#endif +#include "CLucene.h" +#include "CLucene/index/_MultiSegmentReader.h" +#ifdef __clang__ +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +namespace doris::segment_v2::inverted_index::query_v2 { + +inline QueryExecutionContext create_segment_context(lucene::index::IndexReader* seg_reader, + const QueryExecutionContext& original_ctx, + const std::string& binding_key) { + QueryExecutionContext seg_ctx; + seg_ctx.segment_num_rows = seg_reader->numDocs(); + + auto reader_ptr = std::shared_ptr( + seg_reader, [](lucene::index::IndexReader*) {}); + seg_ctx.readers.push_back(reader_ptr); + + if (!binding_key.empty()) { + seg_ctx.reader_bindings[binding_key] = reader_ptr; + } + + seg_ctx.binding_fields = original_ctx.binding_fields; + seg_ctx.null_resolver = original_ctx.null_resolver; + + return seg_ctx; +} + +template +void for_each_index_segment(const QueryExecutionContext& context, const std::string& binding_key, + SegmentCallback&& callback) { + auto* reader = context.readers.empty() ? nullptr : context.readers.front().get(); + if (!reader) { + // No reader available (e.g., AllQuery/MatchAllDocsQuery which doesn't resolve fields). + // Fall back to using the original context directly, as AllScorer only needs segment_num_rows. + if (context.segment_num_rows > 0) { + callback(context, 0); + } + return; + } + + auto* multi_reader = dynamic_cast(reader); + if (multi_reader == nullptr) { + callback(context, 0); + return; + } + + const auto* sub_readers = multi_reader->getSubReaders(); + const auto* starts = multi_reader->getStarts(); + + if (!sub_readers || sub_readers->length == 0) { + return; + } + + for (size_t i = 0; i < sub_readers->length; ++i) { + auto* seg_reader = (*sub_readers)[i]; + auto seg_base = static_cast(starts[i]); + QueryExecutionContext seg_ctx = create_segment_context(seg_reader, context, binding_key); + callback(seg_ctx, seg_base); + } +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.cpp new file mode 100644 index 00000000000000..b7bb378f938266 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.cpp @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.h" + +#include "olap/rowset/segment_v2/inverted_index/query_v2/collect/multi_segment_util.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +void collect_multi_segment_top_k(const WeightPtr& weight, const QueryExecutionContext& context, + const std::string& binding_key, size_t k, + const std::shared_ptr& roaring, + const CollectionSimilarityPtr& similarity, bool use_wand) { + TopKCollector final_collector(k); + + for_each_index_segment( + context, binding_key, [&](const QueryExecutionContext& seg_ctx, uint32_t seg_base) { + float initial_threshold = final_collector.threshold(); + + TopKCollector seg_collector(k); + auto callback = [&seg_collector](uint32_t doc_id, float score) -> float { + return seg_collector.collect(doc_id, score); + }; + + if (use_wand) { + weight->for_each_pruning(seg_ctx, binding_key, initial_threshold, callback); + } else { + auto scorer = weight->scorer(seg_ctx, binding_key); + if (scorer) { + Weight::for_each_pruning_scorer(scorer, initial_threshold, callback); + } + } + + for (const auto& doc : seg_collector.into_sorted_vec()) { + final_collector.collect(doc.doc_id + seg_base, doc.score); + } + }); + + for (const auto& doc : final_collector.into_sorted_vec()) { + roaring->add(doc.doc_id); + if (similarity) { + similarity->collect(doc.doc_id, doc.score); + } + } +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.h new file mode 100644 index 00000000000000..bde9ec5d2f0be4 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.h @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "olap/collection_similarity.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +struct ScoredDoc { + ScoredDoc() = default; + ScoredDoc(uint32_t doc, float s) : doc_id(doc), score(s) {} + + uint32_t doc_id = 0; + float score = 0.0F; +}; + +struct ScoredDocByScoreDesc { + bool operator()(const ScoredDoc& a, const ScoredDoc& b) const { + return a.score > b.score || (a.score == b.score && a.doc_id < b.doc_id); + } +}; + +class TopKCollector { +public: + static constexpr size_t kMaxK = 10000; + + explicit TopKCollector(size_t k) : _k(std::clamp(k, size_t(1), kMaxK)) { + if (k > kMaxK) { + LOG(WARNING) << "TopKCollector: requested k=" << k << " exceeds maximum " << kMaxK + << ", truncated to " << kMaxK; + } + _buffer.reserve(_k * 2); + } + + float collect(uint32_t doc_id, float score) { + if (score < _threshold) { + return _threshold; + } + _buffer.emplace_back(doc_id, score); + if (_buffer.size() == _buffer.capacity()) { + _truncate(); + } else if (_buffer.size() == _k) { + _update_threshold_at_capacity(); + } + return _threshold; + } + + float threshold() const { return _threshold; } + size_t size() const { return std::min(_buffer.size(), _k); } + + [[nodiscard]] std::vector into_sorted_vec() { + if (_buffer.size() > _k) { + _truncate(); + } + std::ranges::sort(_buffer, ScoredDocByScoreDesc {}); + return std::move(_buffer); + } + +private: + void _truncate() { + std::ranges::nth_element(_buffer, _buffer.begin() + _k, ScoredDocByScoreDesc {}); + _buffer.resize(_k); + _update_threshold_at_capacity(); + } + + void _update_threshold_at_capacity() { + auto it = std::ranges::max_element(_buffer, ScoredDocByScoreDesc {}); + _threshold = it->score; + } + + size_t _k; + float _threshold = -std::numeric_limits::infinity(); + std::vector _buffer; +}; + +void collect_multi_segment_top_k(const WeightPtr& weight, const QueryExecutionContext& context, + const std::string& binding_key, size_t k, + const std::shared_ptr& roaring, + const CollectionSimilarityPtr& similarity, bool use_wand = true); + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/composite_reader.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/composite_reader.h index 543f8211ff3358..5b789df055504e 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/composite_reader.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/composite_reader.h @@ -17,8 +17,20 @@ #pragma once +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Woverloaded-virtual" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +#endif #include #include +#ifdef __clang__ +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif #include #include diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/match_all_docs_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/match_all_docs_scorer.h index 1260086acdc651..23f15b31624d71 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/match_all_docs_scorer.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/match_all_docs_scorer.h @@ -21,7 +21,19 @@ #include #include +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Woverloaded-virtual" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +#endif #include "CLucene.h" // IWYU pragma: keep +#ifdef __clang__ +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif #include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" namespace doris::segment_v2::inverted_index::query_v2 { diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_prefix_query/phrase_prefix_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_prefix_query/phrase_prefix_weight.h index c306cbc4e21e9b..7f8b42ba6ac3be 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_prefix_query/phrase_prefix_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_prefix_query/phrase_prefix_weight.h @@ -65,7 +65,7 @@ class PhrasePrefixWeight : public Weight { std::vector> all_postings; for (const auto& [offset, term] : _phrase_terms) { auto posting = create_position_posting(reader.get(), _field, term, _enable_scoring, - _context->io_ctx); + _similarity, _context->io_ctx); if (!posting) { return std::make_shared(); } @@ -81,7 +81,7 @@ class PhrasePrefixWeight : public Weight { std::vector suffix_postings; for (const auto& term : expanded_terms) { auto posting = create_position_posting(reader.get(), _field, term, _enable_scoring, - _context->io_ctx); + _similarity, _context->io_ctx); if (posting) { suffix_postings.emplace_back(std::move(posting)); } diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h index e75c59e3607799..9807831b39ef1a 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h @@ -66,7 +66,7 @@ class MultiPhraseWeight : public Weight { if (term_info.is_single_term()) { auto posting = create_position_posting(reader.get(), _field, term_info.get_single_term(), - _enable_scoring, _context->io_ctx); + _enable_scoring, _similarity, _context->io_ctx); if (posting) { if (posting->size_hint() > SPARSE_TERM_DOC_THRESHOLD) { auto loaded_posting = LoadedPostings::load(*posting); @@ -81,8 +81,9 @@ class MultiPhraseWeight : public Weight { const auto& terms = term_info.get_multi_terms(); std::vector postings; for (const auto& term : terms) { - auto posting = create_position_posting(reader.get(), _field, term, - _enable_scoring, _context->io_ctx); + auto posting = + create_position_posting(reader.get(), _field, term, _enable_scoring, + _similarity, _context->io_ctx); if (posting) { if (posting->size_hint() <= SPARSE_TERM_DOC_THRESHOLD) { postings.push_back(LoadedPostings::load(*posting)); diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h index 75457aafef451a..30964f51a96dfb 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h @@ -62,7 +62,7 @@ class PhraseWeight : public Weight { size_t offset = term_info.position; auto posting = create_position_posting(reader.get(), _field, term_info.get_single_term(), - _enable_scoring, _context->io_ctx); + _enable_scoring, _similarity, _context->io_ctx); if (posting) { term_postings_list.emplace_back(offset, std::move(posting)); } else { diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/prefix_query/prefix_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/prefix_query/prefix_weight.h index 7f24557cf28dd7..f0c3903087985a 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/prefix_query/prefix_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/prefix_query/prefix_weight.h @@ -17,9 +17,21 @@ #pragma once +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Woverloaded-virtual" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +#endif #include #include #include +#ifdef __clang__ +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif #include "olap/rowset/segment_v2/index_query_context.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h" @@ -135,7 +147,8 @@ class PrefixWeight : public Weight { auto term_wstr = StringHelper::to_wstring(term); auto t = make_term_ptr(_field.c_str(), term_wstr.c_str()); auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring, _context->io_ctx); - auto segment_postings = make_segment_postings(std::move(iter), _enable_scoring); + auto segment_postings = + make_segment_postings(std::move(iter), _enable_scoring, nullptr); uint32_t doc = segment_postings->doc(); while (doc != TERMINATED) { diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp index cb71e10daa463c..db6f3524584f96 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp @@ -17,8 +17,20 @@ #include "regexp_weight.h" +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Woverloaded-virtual" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +#endif #include #include +#ifdef __clang__ +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif #include #include @@ -41,7 +53,9 @@ RegexpWeight::RegexpWeight(IndexQueryContextPtr context, std::wstring field, std _pattern(std::move(pattern)), _enable_scoring(enable_scoring), _nullable(nullable) { - // _max_expansions = _context->runtime_state->query_options().inverted_index_max_expansions; + if (_context->runtime_state) { + _max_expansions = _context->runtime_state->query_options().inverted_index_max_expansions; + } } ScorerPtr RegexpWeight::scorer(const QueryExecutionContext& context, @@ -91,13 +105,11 @@ ScorerPtr RegexpWeight::regexp_scorer(const QueryExecutionContext& context, return std::make_shared(); } + auto reader = lookup_reader(_field, context, binding_key); auto doc_bitset = std::make_shared(); for (const auto& term : matching_terms) { - auto t = make_term_ptr(_field.c_str(), term.c_str()); - auto reader = lookup_reader(_field, context, binding_key); - auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring, _context->io_ctx); - auto segment_postings = make_segment_postings(std::move(iter), _enable_scoring); - + auto segment_postings = + create_term_posting(reader.get(), _field, term, false, nullptr, _context->io_ctx); uint32_t doc = segment_postings->doc(); while (doc != TERMINATED) { doc_bitset->add(doc); diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h index f9959ff0d8ce3c..70c8263969a569 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h @@ -46,7 +46,7 @@ class RegexpWeight : public Weight { std::wstring _field; std::string _pattern; - bool _enable_scoring = false; + [[maybe_unused]] bool _enable_scoring = false; bool _nullable = true; // Set to 0 to disable limit (ES has no default limit for prefix queries) // The limit prevents collecting too many terms, but can cause incorrect results diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/scorer.h index 794acde282a75e..14a76a884018a4 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/scorer.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/scorer.h @@ -71,5 +71,6 @@ class EmptyScorer : public Scorer { float score() override { return 0.0F; } }; +using EmptyScorerPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h index 455723ba28ffb4..5be919609e3059 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h @@ -17,14 +17,19 @@ #pragma once +#include #include #include "CLucene/index/DocRange.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" +#include "olap/rowset/segment_v2/inverted_index/similarity/similarity.h" #include "olap/rowset/segment_v2/inverted_index_common.h" namespace doris::segment_v2::inverted_index::query_v2 { +using doris::segment_v2::Similarity; +using doris::segment_v2::SimilarityPtr; + class Postings : public DocSet { public: Postings() = default; @@ -40,20 +45,25 @@ class Postings : public DocSet { using PostingsPtr = std::shared_ptr; -class SegmentPostings final : public Postings { +class SegmentPostings : public Postings { public: using IterVariant = std::variant; - explicit SegmentPostings(TermDocsPtr iter, bool enable_scoring = false) - : _iter(std::move(iter)), _enable_scoring(enable_scoring) { + explicit SegmentPostings(TermDocsPtr iter, bool enable_scoring, SimilarityPtr similarity) + : _iter(std::move(iter)), + _enable_scoring(enable_scoring), + _similarity(std::move(similarity)) { if (auto* p = std::get_if(&_iter)) { _raw_iter = p->get(); } _init_doc(); } - explicit SegmentPostings(TermPositionsPtr iter, bool enable_scoring = false) - : _iter(std::move(iter)), _enable_scoring(enable_scoring), _has_positions(true) { + explicit SegmentPostings(TermPositionsPtr iter, bool enable_scoring, SimilarityPtr similarity) + : _iter(std::move(iter)), + _enable_scoring(enable_scoring), + _has_positions(true), + _similarity(std::move(similarity)) { if (auto* p = std::get_if(&_iter)) { _raw_iter = p->get(); } @@ -155,14 +165,63 @@ class SegmentPostings final : public Postings { bool scoring_enabled() const { return _enable_scoring; } + int64_t block_id() const { return _block_id; } + + void seek_block(uint32_t target_doc) { + if (target_doc <= _doc) { + return; + } + if (_raw_iter->skipToBlock(target_doc)) { + _block_max_score_cache = -1.0F; + _cursor = 0; + _block.doc_many_size_ = 0; + } + } + + uint32_t last_doc_in_block() const { + int32_t last_doc = _raw_iter->getLastDocInBlock(); + if (last_doc == -1 || last_doc == 0x7FFFFFFFL) { + return TERMINATED; + } + return static_cast(last_doc); + } + + float block_max_score() { + if (!_enable_scoring || !_similarity) { + return std::numeric_limits::max(); + } + if (_block_max_score_cache >= 0.0F) { + return _block_max_score_cache; + } + int32_t max_block_freq = _raw_iter->getMaxBlockFreq(); + int32_t max_block_norm = _raw_iter->getMaxBlockNorm(); + if (max_block_freq >= 0 && max_block_norm >= 0) { + _block_max_score_cache = _similarity->score(static_cast(max_block_freq), + static_cast(max_block_norm)); + return _block_max_score_cache; + } + return _similarity->max_score(); + } + + float max_score() const { + if (!_enable_scoring || !_similarity) { + return std::numeric_limits::max(); + } + return _similarity->max_score(); + } + + int32_t max_block_freq() const { return _raw_iter->getMaxBlockFreq(); } + int32_t max_block_norm() const { return _raw_iter->getMaxBlockNorm(); } + private: bool _refill() { - _block.need_positions = _has_positions; - if (!_raw_iter->readRange(&_block)) { + if (!_raw_iter->readBlock(&_block)) { return false; } _cursor = 0; _prox_cursor = 0; + _block_max_score_cache = -1.0F; + _block_id++; return _block.doc_many_size_ > 0; } @@ -187,17 +246,20 @@ class SegmentPostings final : public Postings { DocRange _block; uint32_t _cursor = 0; uint32_t _prox_cursor = 0; + mutable float _block_max_score_cache = -1.0F; + mutable int64_t _block_id = 0; + SimilarityPtr _similarity; }; - using SegmentPostingsPtr = std::shared_ptr; -inline SegmentPostingsPtr make_segment_postings(TermDocsPtr iter, bool enable_scoring = false) { - return std::make_shared(std::move(iter), enable_scoring); +inline SegmentPostingsPtr make_segment_postings(TermDocsPtr iter, bool enable_scoring, + SimilarityPtr similarity) { + return std::make_shared(std::move(iter), enable_scoring, similarity); } -inline SegmentPostingsPtr make_segment_postings(TermPositionsPtr iter, - bool enable_scoring = false) { - return std::make_shared(std::move(iter), enable_scoring); +inline SegmentPostingsPtr make_segment_postings(TermPositionsPtr iter, bool enable_scoring, + SimilarityPtr similarity) { + return std::make_shared(std::move(iter), enable_scoring, similarity); } } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h index e67621c2ab6679..c63377f2683bd8 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h @@ -45,6 +45,11 @@ class TermScorer final : public Scorer { uint32_t freq() const override { return _segment_postings->freq(); } uint32_t norm() const override { return _segment_postings->norm(); } + void seek_block(uint32_t target) { _segment_postings->seek_block(target); } + uint32_t last_doc_in_block() const { return _segment_postings->last_doc_in_block(); } + float block_max_score() const { return _segment_postings->block_max_score(); } + float max_score() const { return _segment_postings->max_score(); } + float score() override { return _similarity->score(freq(), norm()); } bool has_null_bitmap(const NullBitmapResolver* resolver = nullptr) override { diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h index 893467f3845603..37a470de7fda78 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h @@ -17,15 +17,22 @@ #pragma once +#include + #include "olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/wand/block_wand.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" #include "olap/rowset/segment_v2/inverted_index/similarity/similarity.h" namespace doris::segment_v2::inverted_index::query_v2 { +using TermOrEmptyScorer = std::variant; + class TermWeight : public Weight { public: + using Weight::for_each_pruning; + TermWeight(IndexQueryContextPtr context, std::wstring field, std::wstring term, SimilarityPtr similarity, bool enable_scoring) : _context(std::move(context)), @@ -36,25 +43,43 @@ class TermWeight : public Weight { ~TermWeight() override = default; ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string& binding_key) override { + auto result = specialized_scorer(ctx, binding_key); + return std::visit([](auto&& sc) -> ScorerPtr { return sc; }, result); + } + + template + void for_each_pruning(const QueryExecutionContext& context, const std::string& binding_key, + float threshold, Callback&& callback) { + auto result = specialized_scorer(context, binding_key); + std::visit( + [&](auto&& sc) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + block_wand_single_scorer(std::move(sc), threshold, + std::forward(callback)); + } + }, + std::move(result)); + } + +private: + TermOrEmptyScorer specialized_scorer(const QueryExecutionContext& ctx, + const std::string& binding_key) { auto reader = lookup_reader(_field, ctx, binding_key); auto logical_field = logical_field_or_fallback(ctx, binding_key, _field); - if (!reader) { return std::make_shared(); } - auto t = make_term_ptr(_field.c_str(), _term.c_str()); - auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring, _context->io_ctx); - if (iter) { - return std::make_shared( - make_segment_postings(std::move(iter), _enable_scoring), _similarity, - logical_field); + SegmentPostingsPtr segment_postings; + segment_postings = create_term_posting(reader.get(), _field, _term, _enable_scoring, + _similarity, _context->io_ctx); + if (segment_postings) { + return std::make_shared(segment_postings, _similarity, logical_field); } - return std::make_shared(); } -private: IndexQueryContextPtr _context; std::wstring _field; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wand/block_wand.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wand/block_wand.h new file mode 100644 index 00000000000000..efa1a171a6852b --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wand/block_wand.h @@ -0,0 +1,286 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class BlockWand { +public: + template + static void execute(TermScorerPtr scorer, float threshold, Callback&& callback) { + uint32_t doc = scorer->doc(); + while (doc != TERMINATED) { + while (scorer->block_max_score() < threshold) { + uint32_t last_doc_in_block = scorer->last_doc_in_block(); + if (last_doc_in_block == TERMINATED) { + return; + } + doc = last_doc_in_block + 1; + scorer->seek_block(doc); + } + + doc = scorer->seek(doc); + if (doc == TERMINATED) { + break; + } + + while (true) { + float score = scorer->score(); + if (score > threshold) { + threshold = callback(doc, score); + } + if (doc == scorer->last_doc_in_block()) { + break; + } + doc = scorer->advance(); + if (doc == TERMINATED) { + return; + } + } + doc++; + scorer->seek_block(doc); + } + } + + template + static void execute(std::vector scorers, float threshold, Callback&& callback) { + if (scorers.empty()) { + return; + } + + if (scorers.size() == 1) { + execute(std::move(scorers[0]), threshold, std::forward(callback)); + return; + } + + std::vector wrappers; + wrappers.reserve(scorers.size()); + for (auto& s : scorers) { + if (s->doc() != TERMINATED) { + wrappers.emplace_back(std::move(s)); + } + } + + std::sort(wrappers.begin(), wrappers.end(), + [](const ScorerWrapper& a, const ScorerWrapper& b) { return a.doc() < b.doc(); }); + + while (true) { + auto result = find_pivot_doc(wrappers, threshold); + if (result.pivot_doc == TERMINATED) { + break; + } + auto [before_pivot_len, pivot_len, pivot_doc] = result; + + assert(std::ranges::is_sorted(wrappers, + [](const ScorerWrapper& a, const ScorerWrapper& b) { + return a.doc() < b.doc(); + })); + assert(pivot_doc != TERMINATED); + assert(before_pivot_len < pivot_len); + + float block_max_score_upperbound = 0.0F; + for (size_t i = 0; i < pivot_len; ++i) { + wrappers[i].seek_block(pivot_doc); + block_max_score_upperbound += wrappers[i].block_max_score(); + } + + if (block_max_score_upperbound <= threshold) { + block_max_was_too_low_advance_one_scorer(wrappers, pivot_len); + continue; + } + + if (!align_scorers(wrappers, pivot_doc, before_pivot_len)) { + continue; + } + + float score = 0.0F; + for (size_t i = 0; i < pivot_len; ++i) { + score += wrappers[i].score(); + } + + if (score > threshold) { + threshold = callback(pivot_doc, score); + } + + advance_all_scorers_on_pivot(wrappers, pivot_len); + } + } + +private: + class ScorerWrapper { + public: + explicit ScorerWrapper(TermScorerPtr scorer) + : _scorer(std::move(scorer)), _max_score(_scorer->max_score()) {} + + uint32_t doc() const { return _scorer->doc(); } + uint32_t advance() { return _scorer->advance(); } + uint32_t seek(uint32_t target) { return _scorer->seek(target); } + float score() { return _scorer->score(); } + + void seek_block(uint32_t target) { _scorer->seek_block(target); } + uint32_t last_doc_in_block() const { return _scorer->last_doc_in_block(); } + float block_max_score() const { return _scorer->block_max_score(); } + float max_score() const { return _max_score; } + + private: + TermScorerPtr _scorer; + float _max_score; + }; + + struct PivotResult { + size_t before_pivot_len; + size_t pivot_len; + uint32_t pivot_doc; + }; + + static PivotResult find_pivot_doc(std::vector& scorers, float threshold) { + float max_score = 0.0F; + size_t before_pivot_len = 0; + uint32_t pivot_doc = TERMINATED; + + while (before_pivot_len < scorers.size()) { + max_score += scorers[before_pivot_len].max_score(); + if (max_score > threshold) { + pivot_doc = scorers[before_pivot_len].doc(); + break; + } + before_pivot_len++; + } + + if (pivot_doc == TERMINATED) { + return PivotResult {.before_pivot_len = 0, .pivot_len = 0, .pivot_doc = TERMINATED}; + } + + size_t pivot_len = before_pivot_len + 1; + while (pivot_len < scorers.size() && scorers[pivot_len].doc() == pivot_doc) { + pivot_len++; + } + + return PivotResult {.before_pivot_len = before_pivot_len, + .pivot_len = pivot_len, + .pivot_doc = pivot_doc}; + } + + static void restore_ordering(std::vector& scorers, size_t ord) { + uint32_t doc = scorers[ord].doc(); + while (ord + 1 < scorers.size() && doc > scorers[ord + 1].doc()) { + std::swap(scorers[ord], scorers[ord + 1]); + ord++; + } + assert(std::ranges::is_sorted(scorers, [](const ScorerWrapper& a, const ScorerWrapper& b) { + return a.doc() < b.doc(); + })); + } + + static void block_max_was_too_low_advance_one_scorer(std::vector& scorers, + size_t pivot_len) { + assert(std::ranges::is_sorted(scorers, [](const ScorerWrapper& a, const ScorerWrapper& b) { + return a.doc() < b.doc(); + })); + + size_t scorer_to_seek = pivot_len - 1; + float global_max_score = scorers[scorer_to_seek].max_score(); + uint32_t doc_to_seek_after = scorers[scorer_to_seek].last_doc_in_block(); + for (size_t i = pivot_len - 1; i > 0; --i) { + size_t scorer_ord = i - 1; + const auto& scorer = scorers[scorer_ord]; + doc_to_seek_after = std::min(doc_to_seek_after, scorer.last_doc_in_block()); + if (scorer.max_score() > global_max_score) { + global_max_score = scorer.max_score(); + scorer_to_seek = scorer_ord; + } + } + if (doc_to_seek_after != TERMINATED) { + doc_to_seek_after++; + } + for (size_t i = pivot_len; i < scorers.size(); ++i) { + const auto& scorer = scorers[i]; + doc_to_seek_after = std::min(doc_to_seek_after, scorer.doc()); + } + scorers[scorer_to_seek].seek(doc_to_seek_after); + restore_ordering(scorers, scorer_to_seek); + + assert(std::ranges::is_sorted(scorers, [](const ScorerWrapper& a, const ScorerWrapper& b) { + return a.doc() < b.doc(); + })); + } + + static bool align_scorers(std::vector& scorers, uint32_t pivot_doc, + size_t before_pivot_len) { + for (size_t i = before_pivot_len; i > 0; --i) { + size_t idx = i - 1; + uint32_t new_doc = scorers[idx].seek(pivot_doc); + if (new_doc != pivot_doc) { + if (new_doc == TERMINATED) { + std::swap(scorers[idx], scorers.back()); + scorers.pop_back(); + if (scorers.empty()) { + return false; + } + } + // Full re-sort to guarantee invariant after swap-with-back, + // consistent with advance_all_scorers_on_pivot approach. + std::ranges::sort(scorers, [](const ScorerWrapper& a, const ScorerWrapper& b) { + return a.doc() < b.doc(); + }); + return false; + } + } + return true; + } + + static void advance_all_scorers_on_pivot(std::vector& scorers, + size_t pivot_len) { + for (size_t i = 0; i < pivot_len; ++i) { + scorers[i].advance(); + } + + size_t i = 0; + while (i < scorers.size()) { + if (scorers[i].doc() == TERMINATED) { + std::swap(scorers[i], scorers.back()); + scorers.pop_back(); + } else { + i++; + } + } + + std::ranges::sort(scorers, [](const ScorerWrapper& a, const ScorerWrapper& b) { + return a.doc() < b.doc(); + }); + } +}; + +template +inline void block_wand_single_scorer(TermScorerPtr scorer, float threshold, Callback&& callback) { + BlockWand::execute(std::move(scorer), threshold, std::forward(callback)); +} + +template +inline void block_wand(std::vector scorers, float threshold, Callback&& callback) { + BlockWand::execute(std::move(scorers), threshold, std::forward(callback)); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h index 2b53284dbf6de1..a8d511beb1f64c 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -54,6 +55,8 @@ struct QueryExecutionContext { class Weight { public: + using PruningCallback = std::function; + Weight() = default; virtual ~Weight() = default; @@ -63,6 +66,34 @@ class Weight { return scorer(context); } + virtual void for_each_pruning(const QueryExecutionContext& context, float threshold, + PruningCallback callback) { + auto sc = scorer(context); + if (!sc) { + return; + } + for_each_pruning_scorer(sc, threshold, std::move(callback)); + } + + virtual void for_each_pruning(const QueryExecutionContext& context, + const std::string& binding_key, float threshold, + PruningCallback callback) { + (void)binding_key; + for_each_pruning(context, threshold, std::move(callback)); + } + + static void for_each_pruning_scorer(const ScorerPtr& scorer, float threshold, + PruningCallback callback) { + int32_t doc = scorer->doc(); + while (doc != TERMINATED) { + float score = scorer->score(); + if (score > threshold) { + threshold = callback(doc, score); + } + doc = scorer->advance(); + } + } + protected: const FieldBindingContext* get_field_binding(const QueryExecutionContext& ctx, const std::string& binding_key) const { @@ -108,27 +139,36 @@ class Weight { SegmentPostingsPtr create_term_posting(lucene::index::IndexReader* reader, const std::wstring& field, const std::string& term, - bool enable_scoring, const io::IOContext* io_ctx) const { - auto term_wstr = StringHelper::to_wstring(term); - auto t = make_term_ptr(field.c_str(), term_wstr.c_str()); + bool enable_scoring, const SimilarityPtr& similarity, + const io::IOContext* io_ctx) const { + return create_term_posting(reader, field, StringHelper::to_wstring(term), enable_scoring, + similarity, io_ctx); + } + + SegmentPostingsPtr create_term_posting(lucene::index::IndexReader* reader, + const std::wstring& field, const std::wstring& term, + bool enable_scoring, const SimilarityPtr& similarity, + const io::IOContext* io_ctx) const { + auto t = make_term_ptr(field.c_str(), term.c_str()); auto iter = make_term_doc_ptr(reader, t.get(), enable_scoring, io_ctx); - if (iter) { - return make_segment_postings(std::move(iter), enable_scoring); - } - return nullptr; + return iter ? make_segment_postings(std::move(iter), enable_scoring, similarity) : nullptr; } SegmentPostingsPtr create_position_posting(lucene::index::IndexReader* reader, const std::wstring& field, const std::string& term, - bool enable_scoring, + bool enable_scoring, const SimilarityPtr& similarity, const io::IOContext* io_ctx) const { - auto term_wstr = StringHelper::to_wstring(term); - auto t = make_term_ptr(field.c_str(), term_wstr.c_str()); + return create_position_posting(reader, field, StringHelper::to_wstring(term), + enable_scoring, similarity, io_ctx); + } + + SegmentPostingsPtr create_position_posting(lucene::index::IndexReader* reader, + const std::wstring& field, const std::wstring& term, + bool enable_scoring, const SimilarityPtr& similarity, + const io::IOContext* io_ctx) const { + auto t = make_term_ptr(field.c_str(), term.c_str()); auto iter = make_term_positions_ptr(reader, t.get(), enable_scoring, io_ctx); - if (iter) { - return make_segment_postings(std::move(iter), enable_scoring); - } - return nullptr; + return iter ? make_segment_postings(std::move(iter), enable_scoring, similarity) : nullptr; } }; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h index 22de6e3b16dd6f..b22f99cb28171c 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h @@ -41,7 +41,7 @@ class WildcardWeight : public Weight { ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string& binding_key) override { std::string regex_pattern = wildcard_to_regex(_pattern); auto regexp_weight = std::make_shared( - _context, std::move(_field), std::move(regex_pattern), _enable_scoring, _nullable); + _context, _field, std::move(regex_pattern), _enable_scoring, _nullable); return regexp_weight->scorer(ctx, binding_key); } diff --git a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp index 88865e140a625d..d518ffbf8475da 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp @@ -17,6 +17,8 @@ #include "bm25_similarity.h" +#include + namespace doris::segment_v2 { #include "common/compile_check_begin.h" @@ -83,6 +85,13 @@ float BM25Similarity::score(float freq, int64_t encoded_norm) { return _weight - _weight / (1.0F + freq * norm_inverse); } +float BM25Similarity::max_score() { + // 2013265944 = byte4_to_int(int_to_byte4(MAX_INT32)) from Lucene's SmallFloat encoding, + // representing the maximum possible term frequency. Combined with norm=255 (shortest + // document length), this yields the theoretical upper-bound BM25 score for this term. + return score(static_cast(2013265944), 255); +} + int32_t BM25Similarity::number_of_leading_zeros(uint64_t value) { if (value == 0) { return 64; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h index 68a1f3a90db00a..00abfa144671da 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h @@ -42,6 +42,7 @@ class BM25Similarity : public Similarity { const std::vector& terms) override; float score(float freq, int64_t encoded_norm) override; + float max_score() override; static uint8_t int_to_byte4(int32_t i); static int32_t byte4_to_int(uint8_t b); diff --git a/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h b/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h index 744254d410b71e..59b676e8409ffd 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h @@ -36,6 +36,7 @@ class Similarity { const std::vector& terms) = 0; virtual float score(float freq, int64_t encoded_norm) = 0; + virtual float max_score() = 0; }; using SimilarityPtr = std::shared_ptr; diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index fb6e5fa171a458..3b4b5047365617 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -1398,6 +1398,8 @@ Status SegmentIterator::_init_index_iterators() { if (_score_runtime) { _index_query_context->collection_statistics = _opts.collection_statistics; _index_query_context->collection_similarity = std::make_shared(); + _index_query_context->query_limit = _score_runtime->get_limit(); + _index_query_context->is_asc = _score_runtime->is_asc(); } // Inverted index iterators @@ -2889,6 +2891,7 @@ Status SegmentIterator::_construct_compound_expr_context() { auto inverted_index_context = std::make_shared( _schema->column_ids(), _index_iterators, _storage_name_and_type, _common_expr_index_exec_status, _score_runtime, _segment.get(), iter_opts); + inverted_index_context->set_index_query_context(_index_query_context); for (const auto& expr_ctx : _opts.common_expr_ctxs_push_down) { vectorized::VExprContextSPtr context; // _ann_range_search_runtime will do deep copy. diff --git a/be/src/vec/exprs/vexpr_context.h b/be/src/vec/exprs/vexpr_context.h index e7c7c5ebdee9f1..39943c121f6be7 100644 --- a/be/src/vec/exprs/vexpr_context.h +++ b/be/src/vec/exprs/vexpr_context.h @@ -49,6 +49,7 @@ class RuntimeState; namespace doris::segment_v2 { class Segment; class ColumnIterator; +class Segment; } // namespace doris::segment_v2 namespace doris::vectorized { @@ -64,8 +65,8 @@ class IndexExecContext { const std::vector& storage_name_and_type_vec, std::unordered_map>& common_expr_index_status, - ScoreRuntimeSPtr score_runtime, segment_v2::Segment* segment, - const segment_v2::ColumnIteratorOptions& column_iter_opts) + ScoreRuntimeSPtr score_runtime, segment_v2::Segment* segment = nullptr, + const segment_v2::ColumnIteratorOptions& column_iter_opts = {}) : _col_ids(col_ids), _index_iterators(index_iterators), _storage_name_and_type(storage_name_and_type_vec), @@ -203,6 +204,20 @@ class IndexExecContext { return iter->second.get(); } + void set_index_query_context(segment_v2::IndexQueryContextPtr index_query_context) { + _index_query_context = index_query_context; + } + + const segment_v2::IndexQueryContextPtr& get_index_query_context() const { + return _index_query_context; + } + + segment_v2::Segment* get_segment() const { return _segment; } + + const segment_v2::ColumnIteratorOptions& get_column_iter_opts() const { + return _column_iter_opts; + } + private: // A reference to a vector of column IDs for the current expression's output columns. const std::vector& _col_ids; @@ -231,6 +246,7 @@ class IndexExecContext { segment_v2::Segment* _segment = nullptr; // Ref segment_v2::ColumnIteratorOptions _column_iter_opts; + segment_v2::IndexQueryContextPtr _index_query_context; }; class VExprContext { diff --git a/be/src/vec/exprs/vsearch.cpp b/be/src/vec/exprs/vsearch.cpp index c76e5c647611aa..7f4b17c885663c 100644 --- a/be/src/vec/exprs/vsearch.cpp +++ b/be/src/vec/exprs/vsearch.cpp @@ -258,11 +258,14 @@ Status VSearchExpr::evaluate_inverted_index(VExprContext* context, uint32_t segm return Status::OK(); } + auto index_query_context = index_context->get_index_query_context(); + auto function = std::make_shared(); auto result_bitmap = InvertedIndexResultBitmap(); auto status = function->evaluate_inverted_index_with_search_param( _search_param, bundle.field_types, bundle.iterators, segment_num_rows, result_bitmap, - _enable_cache, index_context.get(), bundle.field_name_to_column_id); + _enable_cache, index_context.get(), bundle.field_name_to_column_id, + index_query_context); if (!status.ok()) { LOG(WARNING) << "VSearchExpr: Function evaluation failed: " << status.to_string(); diff --git a/be/src/vec/functions/function_search.cpp b/be/src/vec/functions/function_search.cpp index 2afa1a7c30927e..c0f82150df2659 100644 --- a/be/src/vec/functions/function_search.cpp +++ b/be/src/vec/functions/function_search.cpp @@ -41,6 +41,8 @@ #include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/collect/doc_set_collector.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_query.h" @@ -379,7 +381,8 @@ Status FunctionSearch::evaluate_inverted_index_with_search_param( std::unordered_map iterators, uint32_t num_rows, InvertedIndexResultBitmap& bitmap_result, bool enable_cache, const IndexExecContext* index_exec_ctx, - const std::unordered_map& field_name_to_column_id) const { + const std::unordered_map& field_name_to_column_id, + const std::shared_ptr& index_query_context) const { const bool is_nested_query = search_param.root.clause_type == "NESTED"; if (is_nested_query && !is_nested_group_search_supported()) { return Status::NotSupported( @@ -433,9 +436,14 @@ Status FunctionSearch::evaluate_inverted_index_with_search_param( } } - auto context = std::make_shared(); - context->collection_statistics = std::make_shared(); - context->collection_similarity = std::make_shared(); + std::shared_ptr context; + if (index_query_context) { + context = index_query_context; + } else { + context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + } // NESTED() queries evaluate predicates on the flattened "element space" of a nested group. // For VARIANT nested groups, the indexed lucene field (stored_field_name) uses: @@ -554,43 +562,52 @@ Status FunctionSearch::evaluate_inverted_index_with_search_param( query_v2::QueryExecutionContext exec_ctx = build_query_execution_context(num_rows, resolver, &null_resolver); - auto weight = root_query->weight(false); - if (!weight) { - LOG(WARNING) << "search: Failed to build query weight"; - bitmap_result = InvertedIndexResultBitmap(std::make_shared(), - std::make_shared()); - return Status::OK(); + bool enable_scoring = false; + bool is_asc = false; + size_t top_k = 0; + if (index_query_context) { + enable_scoring = index_query_context->collection_similarity != nullptr; + is_asc = index_query_context->is_asc; + top_k = index_query_context->query_limit; } - auto scorer = weight->scorer(exec_ctx, root_binding_key); - if (!scorer) { - LOG(WARNING) << "search: Failed to build scorer"; + auto weight = root_query->weight(enable_scoring); + if (!weight) { + LOG(WARNING) << "search: Failed to build query weight"; bitmap_result = InvertedIndexResultBitmap(std::make_shared(), std::make_shared()); return Status::OK(); } std::shared_ptr roaring = std::make_shared(); - uint32_t doc = scorer->doc(); - uint32_t matched_docs = 0; - while (doc != query_v2::TERMINATED) { - roaring->add(doc); - ++matched_docs; - doc = scorer->advance(); + if (enable_scoring && !is_asc && top_k > 0) { + bool use_wand = index_query_context->runtime_state != nullptr && + index_query_context->runtime_state->query_options() + .enable_inverted_index_wand_query; + query_v2::collect_multi_segment_top_k(weight, exec_ctx, root_binding_key, top_k, roaring, + index_query_context->collection_similarity, use_wand); + } else { + query_v2::collect_multi_segment_doc_set( + weight, exec_ctx, root_binding_key, roaring, + index_query_context ? index_query_context->collection_similarity : nullptr, + enable_scoring); } - VLOG_DEBUG << "search: Query completed, matched " << matched_docs << " documents"; + VLOG_DEBUG << "search: Query completed, matched " << roaring->cardinality() << " documents"; // Extract NULL bitmap from three-valued logic scorer // The scorer correctly computes which documents evaluate to NULL based on query logic // For example: TRUE OR NULL = TRUE (not NULL), FALSE OR NULL = NULL std::shared_ptr null_bitmap = std::make_shared(); - if (scorer->has_null_bitmap(exec_ctx.null_resolver)) { - const auto* bitmap = scorer->get_null_bitmap(exec_ctx.null_resolver); - if (bitmap != nullptr) { - *null_bitmap = *bitmap; - VLOG_TRACE << "search: Extracted NULL bitmap with " << null_bitmap->cardinality() - << " NULL documents"; + if (exec_ctx.null_resolver) { + auto scorer = weight->scorer(exec_ctx, root_binding_key); + if (scorer && scorer->has_null_bitmap(exec_ctx.null_resolver)) { + const auto* bitmap = scorer->get_null_bitmap(exec_ctx.null_resolver); + if (bitmap != nullptr) { + *null_bitmap = *bitmap; + VLOG_TRACE << "search: Extracted NULL bitmap with " << null_bitmap->cardinality() + << " NULL documents"; + } } } diff --git a/be/src/vec/functions/function_search.h b/be/src/vec/functions/function_search.h index 8ac521a4c29c92..a05bff0a6739f0 100644 --- a/be/src/vec/functions/function_search.h +++ b/be/src/vec/functions/function_search.h @@ -178,7 +178,8 @@ class FunctionSearch : public IFunction { std::unordered_map iterators, uint32_t num_rows, InvertedIndexResultBitmap& bitmap_result, bool enable_cache, const IndexExecContext* index_exec_ctx, - const std::unordered_map& field_name_to_column_id) const; + const std::unordered_map& field_name_to_column_id, + const std::shared_ptr& index_query_context = nullptr) const; Status evaluate_nested_query( const TSearchParam& search_param, const TSearchClause& nested_clause, diff --git a/be/test/olap/collection_statistics_test.cpp b/be/test/olap/collection_statistics_test.cpp index f6f4f85a01c946..c1c61b0c3c602e 100644 --- a/be/test/olap/collection_statistics_test.cpp +++ b/be/test/olap/collection_statistics_test.cpp @@ -619,36 +619,6 @@ TEST_F(CollectionStatisticsTest, CollectWithDoubleCastWrappedSlotRef) { EXPECT_TRUE(status.ok()) << status.msg(); } -TEST_F(CollectionStatisticsTest, FindSlotRefHandlesNullDirectCastAndNested) { - // null - vectorized::VExprSPtr null_expr; - EXPECT_EQ(find_slot_ref(null_expr), nullptr); - - // direct SLOT_REF - auto slot_ref_direct = - std::make_shared("content", SlotId(1)); - EXPECT_EQ(find_slot_ref(slot_ref_direct), - static_cast(slot_ref_direct.get())); - - // CAST(SLOT_REF) - auto slot_ref_cast = - std::make_shared("content", SlotId(1)); - auto cast_expr = std::make_shared(TExprNodeType::CAST_EXPR); - cast_expr->_children.push_back(slot_ref_cast); - EXPECT_EQ(find_slot_ref(cast_expr), static_cast(slot_ref_cast.get())); - - // BINARY_PRED(CAST(SLOT_REF), literal) - auto slot_ref_nested = - std::make_shared("content", SlotId(1)); - auto inner_cast = std::make_shared(TExprNodeType::CAST_EXPR); - inner_cast->_children.push_back(slot_ref_nested); - auto lit = std::make_shared("x"); - auto bin = std::make_shared(TExprNodeType::BINARY_PRED); - bin->_children.push_back(inner_cast); - bin->_children.push_back(lit); - EXPECT_EQ(find_slot_ref(bin), static_cast(slot_ref_nested.get())); -} - TEST(TermInfoComparerTest, OrdersByTermAndDedups) { using doris::TermInfoComparer; using doris::segment_v2::TermInfo; diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp index 6e17de9eacb8a6..e2962178558133 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp @@ -46,6 +46,8 @@ class MockSimilarity : public doris::segment_v2::Similarity { MOCK_FUNCTION float score(float freq, int64_t encoded_norm) override { return _score_value; } + MOCK_FUNCTION float max_score() override { return std::numeric_limits::max(); } + private: float _score_value; }; diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp index 69322ffa5e10e9..7c74de7fb5715a 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp @@ -25,12 +25,15 @@ #include #include +#include "olap/rowset/segment_v2/inverted_index/analyzer/custom_analyzer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/all_query/all_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" +#include "olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h" namespace doris::segment_v2::inverted_index::query_v2 { namespace { diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp index 3febf6ec106583..5f5f6c9e66568f 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp @@ -46,7 +46,9 @@ class MockTermDocs : public lucene::index::TermDocs { int32_t read(int32_t*, int32_t*, int32_t) override { return 0; } int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; } - bool readRange(DocRange* docRange) override { + bool readRange(DocRange* docRange) override { return _fillDocRange(docRange); } + bool readBlock(DocRange* docRange) override { return _fillDocRange(docRange); } + bool _fillDocRange(DocRange* docRange) { if (_read_done || _docs.empty()) { return false; } @@ -62,7 +64,7 @@ class MockTermDocs : public lucene::index::TermDocs { } bool skipTo(const int32_t target) override { return false; } - void skipToBlock(const int32_t target) override {} + bool skipToBlock(const int32_t target) override { return false; } void close() override {} lucene::index::TermPositions* __asTermPositions() override { return nullptr; } @@ -105,7 +107,9 @@ class MockTermPositions : public lucene::index::TermPositions { int32_t read(int32_t*, int32_t*, int32_t) override { return 0; } int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; } - bool readRange(DocRange* docRange) override { + bool readRange(DocRange* docRange) override { return _fillDocRange(docRange); } + bool readBlock(DocRange* docRange) override { return _fillDocRange(docRange); } + bool _fillDocRange(DocRange* docRange) { if (_read_done || _docs.empty()) { return false; } @@ -121,7 +125,7 @@ class MockTermPositions : public lucene::index::TermPositions { } bool skipTo(const int32_t target) override { return false; } - void skipToBlock(const int32_t target) override {} + bool skipToBlock(const int32_t target) override { return false; } void close() override {} lucene::index::TermPositions* __asTermPositions() override { return this; } @@ -173,7 +177,7 @@ TEST_F(SegmentPostingsTest, test_postings_positions_with_offset) { TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_true) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostings base(std::move(ptr), true); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.doc(), 1); EXPECT_EQ(base.size_hint(), 3); @@ -183,21 +187,21 @@ TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_true) { TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_false) { TermDocsPtr ptr(new MockTermDocs({}, {}, {}, 0)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.doc(), TERMINATED); } TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_doc_terminate) { TermDocsPtr ptr(new MockTermDocs({TERMINATED}, {1}, {1}, 1)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.doc(), TERMINATED); } TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_success) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.doc(), 1); EXPECT_EQ(base.advance(), 3); @@ -206,14 +210,14 @@ TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_success) { TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_end) { TermDocsPtr ptr(new MockTermDocs({1}, {2}, {1}, 1)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.advance(), TERMINATED); } TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_target_le_doc) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.seek(0), 1); EXPECT_EQ(base.seek(1), 1); @@ -221,21 +225,21 @@ TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_target_le_doc) { TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_in_block_success) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5, 7}, {2, 4, 6, 8}, {1, 1, 1, 1}, 4)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.seek(5), 5); } TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_fail) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); EXPECT_EQ(base.seek(10), TERMINATED); } TEST_F(SegmentPostingsTest, test_segment_postings_base_append_positions_exception) { TermDocsPtr ptr(new MockTermDocs({1}, {2}, {1}, 1)); - SegmentPostings base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true, nullptr); std::vector output; EXPECT_THROW(base.append_positions_with_offset(0, output), Exception); @@ -243,7 +247,7 @@ TEST_F(SegmentPostingsTest, test_segment_postings_base_append_positions_exceptio TEST_F(SegmentPostingsTest, test_segment_postings_termdocs) { TermDocsPtr ptr(new MockTermDocs({1, 3}, {2, 4}, {1, 1}, 2)); - SegmentPostings postings(std::move(ptr)); + SegmentPostings postings(std::move(ptr), true, nullptr); EXPECT_EQ(postings.doc(), 1); EXPECT_EQ(postings.size_hint(), 2); @@ -252,16 +256,14 @@ TEST_F(SegmentPostingsTest, test_segment_postings_termdocs) { TEST_F(SegmentPostingsTest, test_segment_postings_termpositions) { TermPositionsPtr ptr( new MockTermPositions({1, 3}, {2, 3}, {1, 1}, {{10, 20}, {30, 40, 50}}, 2)); - SegmentPostings postings(std::move(ptr), true); - - EXPECT_EQ(postings.doc(), 1); + SegmentPostings postings(std::move(ptr), true, nullptr); EXPECT_EQ(postings.freq(), 2); } TEST_F(SegmentPostingsTest, test_segment_postings_termpositions_append_positions) { TermPositionsPtr ptr( new MockTermPositions({1, 3}, {2, 3}, {1, 1}, {{10, 20}, {30, 40, 50}}, 2)); - SegmentPostings postings(std::move(ptr), true); + SegmentPostings postings(std::move(ptr), true, nullptr); std::vector output = {999}; postings.append_positions_with_offset(100, output); @@ -274,7 +276,7 @@ TEST_F(SegmentPostingsTest, test_segment_postings_termpositions_append_positions TEST_F(SegmentPostingsTest, test_no_score_segment_posting) { TermDocsPtr ptr(new MockTermDocs({1, 3}, {5, 7}, {10, 20}, 2)); - SegmentPostings posting(std::move(ptr)); + SegmentPostings posting(std::move(ptr), false, nullptr); EXPECT_EQ(posting.doc(), 1); EXPECT_EQ(posting.freq(), 1); diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/top_k_collector_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/top_k_collector_test.cpp new file mode 100644 index 00000000000000..a153b2c1bced5a --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/top_k_collector_test.cpp @@ -0,0 +1,490 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/collect/top_k_collector.h" + +#include + +#include +#include +#include + +namespace doris::segment_v2::inverted_index::query_v2 { + +TEST(TopKCollectorTest, TestTieBreaking) { + { + TopKCollector collector(1); + + collector.collect(100, 5.0); + ASSERT_EQ(collector.size(), 1); + ASSERT_EQ(collector.threshold(), 5.0); + + collector.collect(99, 5.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 1); + EXPECT_EQ(result[0].doc_id, 99); + EXPECT_EQ(result[0].score, 5.0); + } + + { + TopKCollector collector(2); + + collector.collect(100, 5.0); + collector.collect(101, 5.0); + + collector.collect(99, 5.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 2); + EXPECT_EQ(result[0].doc_id, 99); + EXPECT_EQ(result[1].doc_id, 100); + } +} + +TEST(TopKCollectorTest, TestBasicCollection) { + TopKCollector collector(3); + + collector.collect(1, 1.0); + collector.collect(2, 2.0); + collector.collect(3, 3.0); + collector.collect(4, 4.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 3); + + EXPECT_EQ(result[0].doc_id, 4); + EXPECT_EQ(result[0].score, 4.0); + + EXPECT_EQ(result[1].doc_id, 3); + EXPECT_EQ(result[1].score, 3.0); + + EXPECT_EQ(result[2].doc_id, 2); + EXPECT_EQ(result[2].score, 2.0); +} + +TEST(TopKCollectorTest, TestThresholdPruning) { + TopKCollector collector(2); + + collector.collect(1, 5.0); + collector.collect(2, 6.0); + EXPECT_EQ(collector.threshold(), 5.0); + + float new_threshold = collector.collect(3, 4.0); + EXPECT_EQ(new_threshold, 5.0); + + new_threshold = collector.collect(4, 7.0); + EXPECT_EQ(new_threshold, 5.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 2); + EXPECT_EQ(result[0].doc_id, 4); + EXPECT_EQ(result[1].doc_id, 2); +} + +TEST(TopKCollectorTest, TestK1) { + TopKCollector collector(1); + + collector.collect(10, 1.0); + EXPECT_EQ(collector.threshold(), 1.0); + + collector.collect(20, 0.5); + collector.collect(30, 2.0); + EXPECT_EQ(collector.threshold(), 2.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 1); + EXPECT_EQ(result[0].doc_id, 30); +} + +TEST(TopKCollectorTest, TestLargeK) { + TopKCollector collector(100); + + for (uint32_t i = 0; i < 50; i++) { + collector.collect(i, static_cast(i)); + } + + EXPECT_EQ(collector.size(), 50); + EXPECT_EQ(collector.threshold(), -std::numeric_limits::infinity()); + + for (uint32_t i = 50; i < 100; i++) { + collector.collect(i, static_cast(i)); + } + EXPECT_EQ(collector.threshold(), 0.0); + + for (uint32_t i = 100; i < 150; i++) { + collector.collect(i, static_cast(i)); + } + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 100); + EXPECT_EQ(result[0].doc_id, 149); + EXPECT_EQ(result[99].doc_id, 50); +} + +TEST(TopKCollectorTest, TestBufferTruncation) { + TopKCollector collector(3); + + collector.collect(1, 1.0); + collector.collect(2, 2.0); + collector.collect(3, 3.0); + collector.collect(4, 4.0); + collector.collect(5, 5.0); + collector.collect(6, 6.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 3); + EXPECT_EQ(result[0].score, 6.0); + EXPECT_EQ(result[1].score, 5.0); + EXPECT_EQ(result[2].score, 4.0); +} + +TEST(TopKCollectorTest, TestEmptyCollector) { + TopKCollector collector(5); + + auto result = collector.into_sorted_vec(); + EXPECT_TRUE(result.empty()); +} + +TEST(TopKCollectorTest, TestFewerThanK) { + TopKCollector collector(10); + + collector.collect(1, 3.0); + collector.collect(2, 1.0); + collector.collect(3, 2.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 3); + EXPECT_EQ(result[0].doc_id, 1); + EXPECT_EQ(result[1].doc_id, 3); + EXPECT_EQ(result[2].doc_id, 2); +} + +TEST(TopKCollectorTest, TestNegativeScores) { + TopKCollector collector(3); + + collector.collect(1, -1.0); + collector.collect(2, -2.0); + collector.collect(3, -0.5); + collector.collect(4, -3.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 3); + EXPECT_EQ(result[0].doc_id, 3); + EXPECT_EQ(result[1].doc_id, 1); + EXPECT_EQ(result[2].doc_id, 2); +} + +TEST(TopKCollectorTest, TestAllSameScore) { + TopKCollector collector(3); + + collector.collect(5, 1.0); + collector.collect(3, 1.0); + collector.collect(7, 1.0); + collector.collect(1, 1.0); + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), 3); + EXPECT_EQ(result[0].doc_id, 1); + EXPECT_EQ(result[1].doc_id, 3); + EXPECT_EQ(result[2].doc_id, 5); +} + +std::vector compute_expected_topk(std::vector& docs, size_t k) { + std::sort(docs.begin(), docs.end(), ScoredDocByScoreDesc {}); + docs.resize(std::min(docs.size(), k)); + return docs; +} + +TEST(TopKCollectorTest, StressRandomScores1M) { + constexpr size_t N = 1000000; + constexpr size_t K = 100; + + std::mt19937 rng(42); + std::uniform_real_distribution dist(0.0f, 1000.0f); + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float score = dist(rng); + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i; + EXPECT_FLOAT_EQ(result[i].score, expected[i].score) << "Mismatch at position " << i; + } +} + +TEST(TopKCollectorTest, StressAscendingOrder500K) { + constexpr size_t N = 500000; + constexpr size_t K = 1000; + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float score = static_cast(i); + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + EXPECT_EQ(result[0].doc_id, N - 1); + EXPECT_EQ(result[K - 1].doc_id, N - K); + + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id); + } +} + +TEST(TopKCollectorTest, StressDescendingOrder500K) { + constexpr size_t N = 500000; + constexpr size_t K = 1000; + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float score = static_cast(N - i); + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + EXPECT_EQ(result[0].doc_id, 0); + EXPECT_EQ(result[K - 1].doc_id, K - 1); + + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id); + } +} + +TEST(TopKCollectorTest, StressManyDuplicateScores) { + constexpr size_t N = 100000; + constexpr size_t K = 500; + constexpr int NUM_DISTINCT_SCORES = 100; + + std::mt19937 rng(123); + std::uniform_int_distribution score_dist(0, NUM_DISTINCT_SCORES - 1); + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float score = static_cast(score_dist(rng)); + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i; + EXPECT_FLOAT_EQ(result[i].score, expected[i].score); + } +} + +TEST(TopKCollectorTest, StressAllSameScore) { + constexpr size_t N = 50000; + constexpr size_t K = 1000; + constexpr float SCORE = 42.0f; + + std::mt19937 rng(456); + std::vector doc_ids(N); + std::iota(doc_ids.begin(), doc_ids.end(), 0); + std::shuffle(doc_ids.begin(), doc_ids.end(), rng); + + TopKCollector collector(K); + for (uint32_t doc_id : doc_ids) { + collector.collect(doc_id, SCORE); + } + + auto result = collector.into_sorted_vec(); + ASSERT_EQ(result.size(), K); + + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, i) << "Expected doc_id " << i << " at position " << i; + EXPECT_FLOAT_EQ(result[i].score, SCORE); + } +} + +TEST(TopKCollectorTest, StressMultipleTruncations) { + constexpr size_t K = 100; + constexpr size_t N = K * 50; + + std::mt19937 rng(789); + std::uniform_real_distribution dist(0.0f, 10000.0f); + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float score = dist(rng); + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id); + EXPECT_FLOAT_EQ(result[i].score, expected[i].score); + } +} + +TEST(TopKCollectorTest, StressZipfDistribution) { + constexpr size_t N = 500000; + constexpr size_t K = 100; + + std::mt19937 rng(999); + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float base_score = 1.0f / (static_cast(i % 10000) + 1.0f); + float noise = static_cast(rng() % 1000) / 1000000.0f; + float score = base_score + noise; + + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i; + } +} + +TEST(TopKCollectorTest, StressSmallKLargeN) { + constexpr size_t N = 1000000; + constexpr size_t K = 10; + + std::mt19937 rng(111); + std::uniform_real_distribution dist(0.0f, 1.0f); + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float score = dist(rng); + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id); + EXPECT_FLOAT_EQ(result[i].score, expected[i].score); + } +} + +TEST(TopKCollectorTest, StressBimodalDistribution) { + constexpr size_t N = 200000; + constexpr size_t K = 500; + + std::mt19937 rng(222); + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < N; i++) { + float score; + if (i % 2 == 0) { + score = static_cast(rng() % 1000) / 100.0f; + } else { + score = 90.0f + static_cast(rng() % 1000) / 100.0f; + } + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id); + } + + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id % 2, 1) << "Expected odd doc_id at position " << i; + } +} + +TEST(TopKCollectorTest, StressThresholdBoundary) { + constexpr size_t K = 100; + constexpr size_t N = 10000; + constexpr float BASE_SCORE = 50.0f; + + TopKCollector collector(K); + std::vector all_docs; + all_docs.reserve(N); + + for (uint32_t i = 0; i < K; i++) { + collector.collect(i, BASE_SCORE); + all_docs.emplace_back(i, BASE_SCORE); + } + + for (uint32_t i = K; i < N; i++) { + float score = (i % 2 == 0) ? BASE_SCORE : BASE_SCORE + 0.001f; + collector.collect(i, score); + all_docs.emplace_back(i, score); + } + + auto result = collector.into_sorted_vec(); + auto expected = compute_expected_topk(all_docs, K); + + ASSERT_EQ(result.size(), K); + for (size_t i = 0; i < K; i++) { + EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i; + } +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/union_postings_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/union_postings_test.cpp index 73b384e1ad04bf..12c46712d5c50a 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/union_postings_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/union_postings_test.cpp @@ -56,7 +56,9 @@ class MockTermPositionsForUnion : public lucene::index::TermPositions { int32_t read(int32_t*, int32_t*, int32_t) override { return 0; } int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; } - bool readRange(DocRange* docRange) override { + bool readRange(DocRange* docRange) override { return _fillDocRange(docRange); } + bool readBlock(DocRange* docRange) override { return _fillDocRange(docRange); } + bool _fillDocRange(DocRange* docRange) { if (_read_done || _docs.empty()) { return false; } @@ -72,7 +74,7 @@ class MockTermPositionsForUnion : public lucene::index::TermPositions { } bool skipTo(const int32_t target) override { return false; } - void skipToBlock(const int32_t target) override {} + bool skipToBlock(const int32_t target) override { return false; } void close() override {} lucene::index::TermPositions* __asTermPositions() override { return this; } lucene::index::TermDocs* __asTermDocs() override { return this; } @@ -105,7 +107,7 @@ static SegmentPostingsPtr make_pos_postings(std::vector docs, std::vec int32_t df = static_cast(docs.size()); TermPositionsPtr ptr(new MockTermPositionsForUnion(std::move(docs), std::move(freqs), std::move(norms), std::move(positions), df)); - return std::make_shared(std::move(ptr), true); + return std::make_shared(std::move(ptr), true, nullptr); } class UnionPostingsTest : public testing::Test {}; diff --git a/contrib/clucene b/contrib/clucene index 8b57674e9d7876..c51b5cc9adc638 160000 --- a/contrib/clucene +++ b/contrib/clucene @@ -1 +1 @@ -Subproject commit 8b57674e9d78769b10aa0c1441cd12671a394745 +Subproject commit c51b5cc9adc63817ad8322f617c75737ece7288d diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java index 7073febac4dfc6..d24a6018438810 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Match; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SearchExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.scalar.Score; import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; @@ -116,12 +117,13 @@ private Plan pushDown( return null; } - // 2. Requirement: WHERE clause must contain a MATCH function. - boolean hasMatchPredicate = filter.getConjuncts().stream() - .anyMatch(conjunct -> !conjunct.collect(e -> e instanceof Match).isEmpty()); - if (!hasMatchPredicate) { + // 2. Requirement: WHERE clause must contain a MATCH or SEARCH function. + boolean hasMatchOrSearchPredicate = filter.getConjuncts().stream() + .anyMatch(conjunct -> !conjunct.collect( + e -> e instanceof Match || e instanceof SearchExpression).isEmpty()); + if (!hasMatchOrSearchPredicate) { throw new AnalysisException( - "WHERE clause must contain at least one MATCH function" + "WHERE clause must contain at least one MATCH or SEARCH function" + " for score() push down optimization"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 240a8d1acb49fb..083bad941f1b34 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -616,6 +616,7 @@ public class SessionVariable implements Serializable, Writable { // used for cross-platform (x86/arm) inverted index compatibility // may removed in the future public static final String INVERTED_INDEX_COMPATIBLE_READ = "inverted_index_compatible_read"; + public static final String ENABLE_INVERTED_INDEX_WAND_QUERY = "enable_inverted_index_wand_query"; public static final String AUTO_ANALYZE_START_TIME = "auto_analyze_start_time"; @@ -857,6 +858,10 @@ public class SessionVariable implements Serializable, Writable { + "proportion as hot values, up to HOT_VALUE_COLLECT_COUNT."}) public int hotValueCollectCount = 10; // Select the values that account for at least 10% of the column + @VariableMgr.VarAttr(name = ENABLE_INVERTED_INDEX_WAND_QUERY, + description = {"是否开启倒排索引WAND查询优化", "Whether to enable inverted index WAND query optimization"}) + public boolean enableInvertedIndexWandQuery = true; + public void setHotValueCollectCount(int count) { this.hotValueCollectCount = count; @@ -5157,6 +5162,7 @@ public TQueryOptions toThrift() { tResult.setInvertedIndexSkipThreshold(invertedIndexSkipThreshold); tResult.setInvertedIndexCompatibleRead(invertedIndexCompatibleRead); + tResult.setEnableInvertedIndexWandQuery(enableInvertedIndexWandQuery); tResult.setEnableParallelScan(enableParallelScan); tResult.setEnableLeftSemiDirectReturnOpt(enableLeftSemiDirectReturnOpt); diff --git a/gensrc/thrift/PaloInternalService.thrift b/gensrc/thrift/PaloInternalService.thrift index 3f1c5feedb978c..6dfcd8d6502b2e 100644 --- a/gensrc/thrift/PaloInternalService.thrift +++ b/gensrc/thrift/PaloInternalService.thrift @@ -438,6 +438,8 @@ struct TQueryOptions { 185: optional bool enable_parquet_file_page_cache = true; + 203: optional bool enable_inverted_index_wand_query = true; + // For cloud, to control if the content would be written into file cache // In write path, to control if the content would be written into file cache. // In read path, read from file cache or remote storage when execute query. diff --git a/regression-test/suites/inverted_index_p0/test_bm25_score.groovy b/regression-test/suites/inverted_index_p0/test_bm25_score.groovy index cdbec2579229d3..2686011e89e3b2 100644 --- a/regression-test/suites/inverted_index_p0/test_bm25_score.groovy +++ b/regression-test/suites/inverted_index_p0/test_bm25_score.groovy @@ -139,7 +139,7 @@ suite("test_bm25_score", "p0") { test { sql """ select score() as score from test_bm25_score where request = 'button.03.gif' order by score() limit 10; """ - exception "WHERE clause must contain at least one MATCH function for score() push down optimization" + exception "WHERE clause must contain at least one MATCH or SEARCH function for score() push down optimization" } test {