-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhea_documentation_QA_bot.py
More file actions
136 lines (116 loc) · 5.16 KB
/
hea_documentation_QA_bot.py
File metadata and controls
136 lines (116 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import lancedb
from lancedb.schema import vector
from sentence_transformers import SentenceTransformer
import ssl
import certifi
import httpx
import openai
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
# Local LanceDB database directory; created by the ingestion script (not shown here).
LANCEDB_PATH = "data/hea_lancedb"
# Name of the LanceDB table holding the embedded document chunks.
TABLE_NAME = "hea"
def initialize_azure_client(division="ts", region="eastus2", api_version="2024-10-21"):
    """Build an AzureOpenAI client authenticated through Azure AD.

    Resolves the service endpoint from the (division, region) pair, obtains a
    bearer-token provider from the default Azure credential chain, and supplies
    an httpx client whose TLS verification honours REQUESTS_CA_BUNDLE when set,
    falling back to certifi's CA bundle otherwise.

    Raises:
        KeyError: if division or region has no configured endpoint.
    """
    endpoints_by_division = {
        'ts': {
            'eastus':'https://air-ts-eastus.openai.azure.com/',
            'eastus2':'https://air-ts-eastus2.openai.azure.com/',
            'northcentralus':'https://air-poc-northcentralus.openai.azure.com/',
        },
        'ps': {
            'eastus':'https://air-ps-eastus.openai.azure.com/',
            'eastus2':'https://air-ps-eastus2.openai.azure.com/',
            'northcentralus':'https://air-poc-northcentralus.openai.azure.com/'
        },
    }
    endpoint = endpoints_by_division[division][region]

    # Token provider lazily fetches/refreshes AAD tokens for the OpenAI scope.
    token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")

    # Corporate proxies often require a custom CA bundle via REQUESTS_CA_BUNDLE.
    ca_bundle = os.environ.get('REQUESTS_CA_BUNDLE', certifi.where())
    ssl_context = ssl.create_default_context(cafile=ca_bundle)

    return openai.AzureOpenAI(
        api_version=api_version,
        azure_endpoint=endpoint,
        azure_ad_token_provider=token_provider,
        http_client=httpx.Client(verify=ssl_context),
    )
# --- RAG Components ---
class QnAPipeline:
    """Retrieval-augmented Q&A pipeline over a LanceDB document index.

    Embeds user questions with a SentenceTransformer model, retrieves the
    most similar document chunks from LanceDB, and asks an Azure OpenAI
    chat model to answer using only the retrieved context.
    """

    def __init__(self):
        # Query-time embedding model; must match the model used at ingestion
        # time or vector similarity is meaningless — TODO confirm against the
        # ingestion script.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Connect to LanceDB and open the pre-built document chunks table.
        try:
            db = lancedb.connect(LANCEDB_PATH)
            self.table = db.open_table(TABLE_NAME)
        except Exception as e:
            # Re-raise as FileNotFoundError so the __main__ guard can print a
            # targeted "run the ingestion script first" message.
            raise FileNotFoundError(f"LanceDB table not found at {LANCEDB_PATH}/{TABLE_NAME}. Please run the ingestion script first. Error: {e}")
        self.client = initialize_azure_client()

    def search_knowledge_base(self, query: str, top_k: int = 5):
        """
        Embeds a query and searches the LanceDB table for the most relevant chunks.

        Returns a list of dicts (LanceDB ``to_list()`` output); each entry is
        expected to carry at least 'source_uri' and 'text' keys.
        """
        # Embed the user's query into the same vector space as the index.
        query_vector = self.embedding_model.encode(query).tolist()
        # Vector search, trimmed to the top_k nearest chunks.
        search_results = (
            self.table
            .search(query_vector)
            .limit(top_k)
            .to_list()
        )
        return search_results

    def generate_response(self, user_question: str, context: list):
        """
        Constructs a prompt with retrieved context and generates a response using Azure OpenAI.

        ``context`` is a list of dicts with 'source_uri' and 'text' keys, as
        returned by :meth:`search_knowledge_base`. Returns the model's answer
        as a string.
        """
        # Flatten the retrieved chunks into one labelled context string.
        context_str = "\n".join([f"Source: {c['source_uri']}\nContent: {c['text']}" for c in context])
        # System message constrains the model to the provided documents.
        system_message = (
            "You are a helpful assistant that answers questions based on the provided context. "
            "Only use the information from the documents provided. "
            "If the answer is not in the context, say 'I cannot answer this question based on the provided documents.' "
            "Please cite the source document(s) for your answer."
        )
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": f"Context: {context_str}\n\nQuestion: {user_question}"}
            ],
            temperature=0.7,
            max_tokens=500
        )
        return response.choices[0].message.content

    def run_qa_loop(self):
        """
        Runs the interactive Q&A loop until the user types 'quit' or closes stdin.
        """
        print("Welcome to the LIAS Q&A System! Type 'quit' to exit.")
        while True:
            try:
                # Strip whitespace so inputs like "  quit  " still exit.
                user_question = input("\nAsk a question: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Exit cleanly on Ctrl-D / Ctrl-C instead of a traceback.
                break
            if user_question.lower() == 'quit':
                break
            if not user_question:
                # Skip blank input rather than searching with an empty query.
                continue
            try:
                # 1. Retrieve relevant chunks
                relevant_chunks = self.search_knowledge_base(user_question)
                if not relevant_chunks:
                    print("I couldn't find any relevant information for that question.")
                    continue
                # 2. Generate a response with the retrieved context
                answer = self.generate_response(user_question, relevant_chunks)
                # 3. Print the final answer
                print(f"\nAI Answer: {answer}")
            except Exception as e:
                # Best-effort CLI: report the error and keep the loop alive.
                print(f"An error occurred: {e}")
if __name__ == "__main__":
    # Script entry point: build the pipeline and hand control to the REPL.
    try:
        QnAPipeline().run_qa_loop()
    except FileNotFoundError as e:
        # Raised by QnAPipeline.__init__ when the LanceDB table is missing.
        print(f"Error: {e}")
        print("Please run your document ingestion script first to create the LanceDB table.")
    except Exception as e:
        print(f"An unexpected error occurred during setup: {e}")