Skip to content

Commit 78f5313

Browse files
authored
Merge pull request #46 from epsilla-cloud/hybrid_search
Hybrid search
2 parents c91a9b4 + 1b71ead commit 78f5313

File tree

6 files changed

+259
-1
lines changed

6 files changed

+259
-1
lines changed

pyepsilla/cloud/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pprint
88
import socket
99
from typing import Optional, Union
10+
from ..utils.search_engine import SearchEngine
1011

1112
import requests
1213
import sentry_sdk
@@ -311,3 +312,6 @@ def get(
311312
body = res.json()
312313
res.close()
313314
return status_code, body
315+
316+
def as_search_engine(self):
    """Wrap this client in a SearchEngine for hybrid (multi-retriever) search."""
    engine = SearchEngine(self)
    return engine

pyepsilla/enterprise/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pprint
88
import socket
99
from typing import Optional, Union
10+
from ..utils.search_engine import SearchEngine
1011

1112
import requests
1213
import sentry_sdk
@@ -360,3 +361,6 @@ def get(
360361
body = res.json()
361362
res.close()
362363
return status_code, body
364+
365+
def as_search_engine(self):
    """Wrap this client in a SearchEngine for hybrid (multi-retriever) search."""
    engine = SearchEngine(self)
    return engine

pyepsilla/utils/__init__.py

Whitespace-only changes.

pyepsilla/utils/search_engine.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
#!/usr/bin/env python
2+
# -*- coding:utf-8 -*-
3+
from __future__ import annotations
4+
5+
import datetime
6+
import json
7+
import socket
8+
import time
9+
from typing import Optional, Union
10+
11+
class VectorRetriever:
    """Retrieve candidate records from one Epsilla table for a text query.

    Each returned record is tagged with an ``@id`` field copied from the
    configured primary-key field, so downstream rerankers can aggregate
    candidates that appear in several retrievers' result lists.
    """

    def __init__(
        self,
        db_client,
        table_name: str,
        primary_key_field: str,
        query_index: Optional[str] = None,
        query_field: Optional[str] = None,
        query_vector: Union[list, dict, None] = None,
        response_fields: Optional[list] = None,
        limit: int = 2,
        filter: str = "",
    ):
        # db_client: any client exposing a compatible ``query`` method
        # (vectordb / cloud / enterprise Client).
        self._db_client = db_client
        self._table_name = table_name
        self._primary_key_field = primary_key_field
        self._query_index = query_index
        self._query_field = query_field
        self._query_vector = query_vector
        self._response_fields = response_fields
        self._limit = limit
        self._filter = filter

    def retrieve(self, query: str) -> list[dict]:
        """Run the configured query for ``query`` and return matched records.

        Returns:
            The records from the table, each augmented with an ``@id`` key.

        Raises:
            Exception: if the query fails (non-200 status) or a record lacks
                the configured primary key field.
        """
        status_code, response = self._db_client.query(
            table_name=self._table_name,
            query_text=query,
            query_index=self._query_index,
            query_field=self._query_field,
            query_vector=self._query_vector,
            response_fields=self._response_fields,
            limit=self._limit,
            filter=self._filter,
            with_distance=True,  # rerankers depend on the "@distance" field
        )
        if status_code != 200:
            error_msg = response.get("message", "Unknown error")
            raise Exception(f"Failed to retrieve data from table {self._table_name}: {error_msg}")
        # Tag each record with @id (from the primary key) so rerankers can
        # deduplicate candidates across multiple result lists.
        for record in response["result"]:
            if self._primary_key_field not in record:
                raise Exception(f"Primary key field {self._primary_key_field} not found in the response from table {self._table_name}")
            record["@id"] = record[self._primary_key_field]
        return response["result"]
57+
58+
class Reranker:
    """Abstract strategy for fusing several retrievers' candidate lists.

    Subclasses combine the per-list rankings (using each candidate's
    "@id" and "@distance" fields) into a single ordered list.
    """

    def rerank(self, candidates: list[list[dict]]) -> list[dict]:
        """Fuse ``candidates`` (one list per retriever) into one ranking."""
        # Was a silent no-op (`pass`) returning None; fail loudly instead so
        # a missing override is caught immediately rather than downstream.
        raise NotImplementedError("Subclasses must implement rerank()")
61+
62+
class RRFReRanker(Reranker):
    """Reciprocal Rank Fusion (RRF) reranker.

    A candidate's fused score is ``sum(weight_i / (k + rank_i))`` over every
    list it appears in (rank is 1-based).  Higher fused score ranks first.
    """

    def __init__(self, weights: Optional[list[float]] = None, k: int = 50, limit: Optional[int] = None):
        # weights: per-retriever multipliers; defaults to equal weights.
        # k: RRF smoothing constant — larger k flattens rank differences.
        # limit: optional cap on the fused result size.
        self._weights = weights
        self._k = k
        self._limit = limit

    def rerank(self, candidates: list[list[dict]]) -> list[dict]:
        # Default to equal weights using a LOCAL variable.  The original
        # assigned the default back to self._weights, so a reused instance
        # kept a stale weight list sized for the previous call's number of
        # candidate lists (IndexError / wrong weights on the next call).
        weights = self._weights if self._weights else [1] * len(candidates)
        if len(weights) != len(candidates):
            raise Exception("The length of weights should be equal to the number of candidate lists")

        # Aggregate RRF scores; a candidate in several lists sums its scores.
        rrf_scores = {}
        for weight, candidate_list in zip(weights, candidates):
            for rank, candidate in enumerate(candidate_list, start=1):
                rrf_score = weight / (self._k + rank)
                if candidate["@id"] in rrf_scores:
                    rrf_scores[candidate["@id"]]["score"] += rrf_score
                else:
                    rrf_scores[candidate["@id"]] = {"candidate": candidate, "score": rrf_score}

        # Highest fused score first (sort is stable for ties).
        sorted_candidates = sorted(rrf_scores.values(), key=lambda x: x["score"], reverse=True)

        # Apply the limit to the final list if specified
        if self._limit is not None:
            sorted_candidates = sorted_candidates[:self._limit]

        # Return only the candidate information, discarding the scores
        return [item["candidate"] for item in sorted_candidates]
96+
97+
class RelativeScoreFusionReranker(Reranker):
    """Fuses candidate lists by min-max normalizing each list's distances.

    Within each list the closest candidate scores 1 and the farthest 0;
    scores are summed across lists and candidates are ordered by the
    aggregated score, descending.
    """

    def __init__(self, limit: int = None):
        self._limit = limit  # optional cap on the fused result count

    def normalize_distances(self, candidates: list[dict]) -> list[dict]:
        """Min-max normalize each "@distance" into a score in [0, 1]."""
        distances = [c["@distance"] for c in candidates]

        # Degenerate lists (fewer than two items, or all distances equal)
        # cannot be normalized: every candidate gets the maximum score.
        if len(distances) < 2 or max(distances) == min(distances):
            return [{'candidate': c, 'score': 1} for c in candidates]

        lo, hi = min(distances), max(distances)
        span = hi - lo

        # Invert the normalized distance so smaller distance -> higher score.
        return [
            {'candidate': c, 'score': 1 - (c["@distance"] - lo) / span}
            for c in candidates
        ]

    def rerank(self, candidates: list[list[dict]]) -> list[dict]:
        """Fuse the lists by summing each candidate's normalized scores."""
        fused = {}
        for candidate_list in candidates:
            for entry in self.normalize_distances(candidate_list):
                cid = entry['candidate']['@id']
                if cid in fused:
                    # Candidate seen in an earlier list: accumulate its score.
                    fused[cid]['score'] += entry['score']
                else:
                    fused[cid] = entry

        # Highest aggregated score first (sort is stable for ties).
        ranked = sorted(fused.values(), key=lambda e: e['score'], reverse=True)

        if self._limit is not None:
            ranked = ranked[:self._limit]

        # Strip the bookkeeping wrapper; return the raw candidate records.
        return [e['candidate'] for e in ranked]
140+
141+
class DistributionBasedScoreFusionReranker(Reranker):
    """Fuses candidate lists using fixed per-retriever distance scales.

    List ``i`` has a scale range ``[lo, hi]``; each candidate's "@distance"
    is clamped into that range and mapped linearly to a score in [0, 1]
    (lo -> 1, hi -> 0).  Scores are summed across lists.
    """

    def __init__(self, scale_ranges: Optional[list[list[float]]] = None, limit: Optional[int] = None):
        self._limit = limit
        # Avoid the original mutable default argument (scale_ranges=[]),
        # which is shared across instances.  None means "no scales set".
        self._scale_ranges = scale_ranges if scale_ranges is not None else []

    def normalize_distances(self, scale: list[float], candidates: list[dict]) -> list[dict]:
        """Map "@distance" into [0, 1] using the fixed [scale[0], scale[1]] range."""
        lo, hi = scale[0], scale[1]
        normalized_candidates = []
        for candidate in candidates:
            # Clamp distances below lo to score 1 and above hi to score 0,
            # then invert so a smaller distance yields a higher score.
            normalized_score = max(candidate["@distance"] - lo, 0) / (hi - lo)
            normalized_candidates.append({'candidate': candidate, 'score': 1 - min(1, normalized_score)})
        return normalized_candidates

    def rerank(self, candidates: list[list[dict]]) -> list[dict]:
        """Fuse the lists by summing each candidate's scale-normalized scores."""
        # Fail with a clear message instead of an opaque IndexError when the
        # configured scales don't cover every candidate list.
        if len(self._scale_ranges) < len(candidates):
            raise Exception("The length of scale_ranges should be equal to the number of candidate lists")

        normalized_lists = [
            self.normalize_distances(self._scale_ranges[i], candidate_list)
            for i, candidate_list in enumerate(candidates)
        ]

        # Aggregate normalized scores across lists.
        aggregated_scores = {}
        for candidate_list in normalized_lists:
            for item in candidate_list:
                candidate_id = item['candidate']['@id']
                if candidate_id in aggregated_scores:
                    aggregated_scores[candidate_id]['score'] += item['score']
                else:
                    aggregated_scores[candidate_id] = item

        # Highest aggregated score first (sort is stable for ties).
        sorted_candidates = sorted(aggregated_scores.values(), key=lambda x: x['score'], reverse=True)

        # Apply the limit to the final list if specified
        if self._limit is not None:
            sorted_candidates = sorted_candidates[:self._limit]

        # Return only the candidate information, discarding the scores
        return [item['candidate'] for item in sorted_candidates]
177+
178+
class SearchEngine:
    """Hybrid-search facade: fans a query out to retrievers, then fuses.

    Usage: chain ``add_retriever(...)`` calls (one per table/index), call
    ``set_reranker(...)`` when more than one retriever is registered, then
    call ``search(query)``.
    """

    def __init__(
        self,
        db_client,
    ):
        # db_client: shared Epsilla client handed to every retriever.
        self._db_client = db_client
        self._retrievers: list[VectorRetriever] = []
        self._reranker: Optional[Reranker] = None

    def add_retriever(
        self,
        table_name: str,
        primary_key_field: str = "ID",
        query_index: Optional[str] = None,
        query_field: Optional[str] = None,
        query_vector: Union[list, dict, None] = None,
        response_fields: Optional[list] = None,
        limit: int = 2,
        filter: str = ""
    ) -> SearchEngine:
        """Register a retriever for ``table_name``; returns self for chaining.

        Any previously configured reranker is cleared, because reranker
        validation depends on the current number of retrievers.
        """
        self._reranker = None
        self._retrievers.append(
            VectorRetriever(
                db_client=self._db_client,
                table_name=table_name,
                primary_key_field=primary_key_field,
                query_index=query_index,
                query_field=query_field,
                query_vector=query_vector,
                response_fields=response_fields,
                limit=limit,
                filter=filter
            )
        )
        return self

    def set_reranker(
        self,
        type: str = "rrf",
        weights: Optional[list[float]] = None,
        scale_ranges: Optional[list[list[int]]] = None,
        k: int = 50,
        limit: Optional[int] = None,
    ) -> SearchEngine:
        """Choose the fusion strategy; returns self for chaining.

        Supported types: "rrf"/"reciprocal_rank_fusion" (weights, k),
        "rsf"/"relative_score_fusion", and
        "dbsf"/"distribution_based_score_fusion" (scale_ranges).

        Raises:
            Exception: on an unknown type, or when weights/scale_ranges do
                not match the number of registered retrievers.
        """
        # Replace the original mutable default argument (scale_ranges=[]).
        if scale_ranges is None:
            scale_ranges = []
        if type in ("rrf", "reciprocal_rank_fusion"):
            if weights is not None and len(self._retrievers) != len(weights):
                raise Exception("The length of weights should be equal to the number of retrievers")
            self._reranker = RRFReRanker(weights=weights, k=k, limit=limit)
        elif type in ("rsf", "relative_score_fusion"):
            self._reranker = RelativeScoreFusionReranker(limit=limit)
        elif type in ("dbsf", "distribution_based_score_fusion"):
            if len(scale_ranges) != len(self._retrievers):
                raise Exception("The length of scale_ranges should be equal to the number of retrievers")
            self._reranker = DistributionBasedScoreFusionReranker(scale_ranges, limit=limit)
        else:
            raise Exception("Invalid reranker type: " + type)
        return self

    def search(self, query: str) -> list[dict]:
        """Run ``query`` through every retriever and fuse the results.

        With a single retriever and no reranker, the raw (wrapped) candidate
        lists are returned unfused.

        Raises:
            Exception: if no retriever is registered, or several are
                registered without a reranker.
        """
        if not self._retrievers:
            raise Exception("No retriever added to the search engine")
        # More than one retriever requires a fusion strategy.
        if len(self._retrievers) > 1 and not self._reranker:
            raise Exception("More than one retriever added to the search engine, but no reranker is set")

        # Retrievers run sequentially.  (The original comment claimed
        # ThreadPoolExecutor concurrency that was never implemented;
        # retrievals are independent and could be parallelized — TODO.)
        candidates = [retriever.retrieve(query) for retriever in self._retrievers]

        # Rerank candidates if a reranker is set.
        if self._reranker:
            candidates = self._reranker.rerank(candidates)

        return candidates

pyepsilla/vectordb/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import socket
88
import time
99
from typing import Optional, Union
10+
from ..utils.search_engine import SearchEngine
1011

1112
import requests
1213
import sentry_sdk
@@ -354,3 +355,6 @@ def drop_db(self, db_name: str):
354355
body = res.json()
355356
res.close()
356357
return status_code, body
358+
359+
def as_search_engine(self):
    """Wrap this client in a SearchEngine for hybrid (multi-retriever) search."""
    engine = SearchEngine(self)
    return engine

pyepsilla/vectordb/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.3"
1+
__version__ = "0.3.4"

0 commit comments

Comments
 (0)