first commit
parent 5c5d88c92f
commit eb4f62c56d
41 changed files with 3851 additions and 19 deletions
36
docker_svc/agent/app/libs/check_medical.py
Normal file
@@ -0,0 +1,36 @@
def is_medical_query(message: str) -> bool:
    """
    Check if the user message contains medical keywords. This function is case-insensitive.

    :param message: The user message or any string to check.
    :return: True if the message contains medical keywords, False otherwise.
    """
    medical_keywords = [
        "health",
        "doctor",
        "medicine",
        "disease",
        "symptom",
        "treatment",
        "salute",
        "medico",
        "malattia",
        "sintomo",
        "cura",
        "sanità",
        "santé",
        "médecin",
        "médicament",
        "maladie",
        "symptôme",
        "traitement",
        "gesundheit",
        "arzt",
        "medizin",
        "krankheit",
        "symptom",
        "behandlung",
    ]

    message_lower = message.lower()
    return any(keyword in message_lower for keyword in medical_keywords)
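A minimal usage sketch for this helper (hypothetical call site, not part of the commit); the import path assumes the same libs package layout used by the other modules in this change:

# Hypothetical example, not part of this commit.
from libs.check_medical import is_medical_query

msg = "Quels sont les symptômes de la grippe ?"
if is_medical_query(msg):  # matches the "symptôme" keyword, case-insensitively
    print("medical disclaimer required")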
43
docker_svc/agent/app/libs/log_prompts.py
Normal file
@@ -0,0 +1,43 @@
import os
from mysql.connector import connect, Error
import logging

# Configure logging
logger = logging.getLogger(__name__)

def log_prompt_to_db(userid: str | None, ip: str, prompt: str, answer: str):
    """
    Logs the user's prompt and the corresponding response to the database.

    Args:
        userid (str | None): User ID (optional, can be None).
        ip (str): Client's IP address.
        prompt (str): Full conversation history provided by the user.
        answer (str): Response generated by the AI.
    """
    try:
        # Connect to the database using environment variables
        connection = connect(
            host=os.getenv("DB_HOST"),
            port=int(os.getenv("DB_PORT", "3306")),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            database=os.getenv("DB_NAME")
        )
        cursor = connection.cursor()

        # SQL query to insert data
        query = """
            INSERT INTO user_prompts (userid, ip, prompt, answer)
            VALUES (%s, %s, %s, %s)
        """
        values = (userid, ip, prompt, answer)
        cursor.execute(query, values)

        # Commit the transaction and close resources
        connection.commit()
        cursor.close()
        connection.close()

    except Error as e:
        logger.error(f"Error logging prompt to database: {e}")
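The insert assumes a user_prompts table already exists and that the DB_* environment variables are set; the schema below is only an illustration of what the function expects, not DDL shipped in this commit:

# Hypothetical example, not part of this commit.
# Assumed table shape (illustrative only):
#   CREATE TABLE user_prompts (
#       id     INT AUTO_INCREMENT PRIMARY KEY,
#       userid VARCHAR(255) NULL,
#       ip     VARCHAR(45)  NOT NULL,
#       prompt TEXT         NOT NULL,
#       answer TEXT         NOT NULL
#   );
from libs.log_prompts import log_prompt_to_db

log_prompt_to_db(
    userid=None,  # anonymous users are allowed
    ip="203.0.113.7",
    prompt="user: What is a balanced diet?",
    answer="A balanced diet includes ...",
)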
36
docker_svc/agent/app/libs/manage_languages.py
Normal file
@@ -0,0 +1,36 @@
# libs/manage_languages.py

from langdetect import detect
from fastapi import HTTPException
import logging

logger = logging.getLogger(__name__)

def validate_language(language: str) -> None:
    """Validate the language parameter. Raises an HTTPException if the language is invalid."""
    valid_languages = {"french", "italian", "english", "german", "auto"}
    if language not in valid_languages:
        raise HTTPException(
            status_code=400,
            detail="Invalid language. Must be one of: french, italian, english, german, or auto"
        )

def detect_language(current_message: str) -> str:
    """Detect the language of the current message. Defaults to French if detection fails."""
    try:
        detected_lang = detect(current_message)
        if detected_lang == "fr":
            language = "french"
        elif detected_lang == "it":
            language = "italian"
        elif detected_lang == "en":
            language = "english"
        elif detected_lang == "de":
            language = "german"
        else:
            language = "french"
        logger.info(f"Detected language: {language}")
        return language
    except Exception as e:
        logger.error(f"Language detection failed: {str(e)}")
        return "french"
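A short sketch of how the two helpers combine in a request handler (hypothetical, not part of this commit):

# Hypothetical example, not part of this commit.
from libs.manage_languages import validate_language, detect_language

language = "auto"
validate_language(language)  # raises HTTPException(400) for unsupported values
if language == "auto":
    language = detect_language("Ho mal di testa da due giorni")  # expected: "italian"
print(language)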
14
docker_svc/agent/app/libs/models.py
Normal file
@@ -0,0 +1,14 @@
from typing import List, Optional, Literal
from pydantic import BaseModel, Field

class ChatMessage(BaseModel):
    role: Literal["user", "coach"]
    content: str

class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    language: str = "auto"
    temperature: float = 0.7
    reasoning: bool = False
    stream: bool = True
    personality: str = "supportive"
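A request built from these models, purely as an illustration (field values are hypothetical):

# Hypothetical example, not part of this commit.
from libs.models import ChatRequest, ChatMessage

req = ChatRequest(
    messages=[ChatMessage(role="user", content="J'ai mal à la tête, que faire ?")],
    language="auto",
    reasoning=False,
)
# model_dump() is pydantic v2; on pydantic v1 use req.dict() instead
print(req.model_dump())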
192
docker_svc/agent/app/libs/prompt_helper.py
Normal file
@@ -0,0 +1,192 @@
from llama_index.core.base.llms.types import ChatMessage as LlamaChatMessage
import logging
from libs.models import ChatMessage
from typing import List, Dict, Any, Optional, AsyncGenerator
import httpx
import json
import os
import asyncio

logger = logging.getLogger(__name__)

def format_system_prompt(system_prompt_template: str, language_prompts: dict, language: str,
                         retrieved_docs: str, is_medical: bool, personality: str = "supportive",
                         personality_prompts: dict = {}) -> str:
    """Format the system prompt with the language-specific content, the personality and the retrieved documents."""
    language_prompt = language_prompts[language]["prompt"]
    language_disclaimer = language_prompts[language]["disclaimer"]
    language_constraint = "" if language == "auto" else language_prompts[language]["constraint"]

    # Improve logging and personality handling
    if personality not in personality_prompts:
        logger.warning(f"Personality '{personality}' not found in prompts, using default empty prompt")
        personality_prompt = ""
    else:
        personality_prompt = personality_prompts[personality]["prompt"]
        logger.info(f"Using '{personality}' personality: {personality_prompts[personality]['description'][:50]}...")

    logger.info(f"Formatting system prompt with language {language}, personality {personality}")
    system_message_content = system_prompt_template.format(
        language_prompt=language_prompt,
        context=retrieved_docs,
        language_disclaimer=language_disclaimer if is_medical else "",
        personality_prompt=personality_prompt,
        language_constraint=language_constraint
    )
    logger.debug(f"System message content: {system_message_content[:200]}...")
    return system_message_content

async def perform_inference_streaming(
    llm,
    system_message: str,
    history: List[Dict],
    current_message: str
) -> AsyncGenerator[str, None]:
    """Stream inference results from the Ollama API."""
    base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")

    # Prepare messages for Ollama API
    messages = []

    # Add system message
    messages.append({
        "role": "system",
        "content": system_message
    })

    # Add history
    for msg in history:
        messages.append({
            "role": "user" if msg.role == "user" else "assistant",
            "content": msg.content
        })

    # Add current user message
    messages.append({
        "role": "user",
        "content": current_message
    })

    # Prepare request payload
    payload = {
        "model": llm.model,
        "messages": messages,
        "stream": True,
        "options": {
            "temperature": llm.temperature
        }
    }

    logger.debug(f"Sending streaming request to Ollama API: {base_url}/api/chat")

    try:
        async with httpx.AsyncClient() as client:
            async with client.stream("POST", f"{base_url}/api/chat", json=payload, timeout=60.0) as response:
                if response.status_code != 200:
                    error_detail = await response.aread()
                    logger.error(f"Error from Ollama API: {response.status_code}, {error_detail}")
                    yield f"Error: Failed to get response from language model (Status {response.status_code})"
                    return

                # Variable to accumulate the full response
                full_response = ""

                # Process the streaming response
                async for chunk in response.aiter_text():
                    if not chunk.strip():
                        continue

                    # Each chunk might contain one JSON object
                    try:
                        data = json.loads(chunk)
                        # Process message content if available
                        if 'message' in data and 'content' in data['message']:
                            content = data['message']['content']
                            full_response += content
                            yield content

                        # Check if this is the final message with done flag
                        if data.get('done', False):
                            logger.debug("Streaming response completed")
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse streaming response: {e}, chunk: {chunk}")

    except Exception as e:
        logger.error(f"Error during streaming inference: {str(e)}")
        yield f"Error: {str(e)}"

    # Return empty string at the end to signal completion
    yield ""

def perform_inference(
    llm,
    system_message: str,
    history: List[Dict],
    current_message: str,
    stream: bool = False
) -> str:
    """Perform inference with the given LLM."""
    if stream:
        # This will be handled by the streaming endpoint
        raise ValueError("Streaming not supported in synchronous inference")

    # Prepare messages for the API
    messages = []

    # Add system message
    messages.append({
        "role": "system",
        "content": system_message
    })

    # Add history
    for msg in history:
        messages.append({
            "role": "user" if msg.role == "user" else "assistant",
            "content": msg.content
        })

    # Add current user message
    messages.append({
        "role": "user",
        "content": current_message
    })

    # For non-streaming, we'll use the httpx client directly to call Ollama API
    base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")

    # Prepare request payload
    payload = {
        "model": llm.model,
        "messages": messages,
        "stream": False,
        "options": {
            "temperature": llm.temperature
        }
    }

    logger.debug(f"Sending non-streaming request to Ollama API: {base_url}/api/chat")

    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(f"{base_url}/api/chat", json=payload)

            if response.status_code != 200:
                logger.error(f"Error from Ollama API: {response.status_code}, {response.text}")
                return f"Error: Failed to get response from language model (Status {response.status_code})"

            data = response.json()
            if 'message' in data and 'content' in data['message']:
                return data['message']['content']
            else:
                logger.error(f"Unexpected response format: {data}")
                return "Error: Unexpected response format from language model"

    except Exception as e:
        logger.error(f"Error during non-streaming inference: {str(e)}")
        return f"Error: {str(e)}"

def select_llm(llm, llm_reasoning, reasoning: bool):
    """Select the LLM model based on the reasoning flag."""
    selected_llm = llm_reasoning if reasoning else llm
    return selected_llm
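One way the streaming helper could be wired into a FastAPI endpoint, sketched under the assumption that the llm object exposes .model and .temperature as used above; the endpoint path, app object and stub LLM are hypothetical and not part of this commit:

# Hypothetical example, not part of this commit.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from libs.prompt_helper import perform_inference_streaming

app = FastAPI()

class _StubLLM:
    # Stand-in exposing the two attributes the helper reads; a real client
    # pointed at Ollama would be configured at application startup.
    model = "llama3"
    temperature = 0.7

@app.post("/chat/stream")
async def chat_stream():
    generator = perform_inference_streaming(
        _StubLLM(),
        system_message="You are a supportive coach.",
        history=[],
        current_message="How can I sleep better?",
    )
    return StreamingResponse(generator, media_type="text/plain")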
79
docker_svc/agent/app/libs/qdrant_helper.py
Normal file
@@ -0,0 +1,79 @@
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
import logging

logger = logging.getLogger(__name__)

def ensure_collection_exists(qdrant_client: QdrantClient, collection_name: str, vector_size: int) -> None:
    """Verify that the Qdrant collection exists, and create it if it does not."""
    try:
        if not qdrant_client.collection_exists(collection_name):
            qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE)
            )
            logger.info(f"Created Qdrant collection '{collection_name}' with vector size {vector_size}")
        else:
            logger.info(f"Qdrant collection '{collection_name}' already exists")
    except Exception as e:
        logger.error(f"Failed to ensure Qdrant collection exists: {str(e)}")
        raise

def retrieve_documents(qdrant_client: QdrantClient, collection_name: str, embed_model, current_message: str) -> str:
    """Get the relevant documents from Qdrant based on the current message."""
    logger.info("Initializing Qdrant vector store")
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        embed_model=embed_model
    )
    logger.info("Building vector store index")
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=embed_model
    )
    logger.info("Retrieving documents")
    retriever = index.as_retriever()
    retrieved_nodes = retriever.retrieve(current_message)
    retrieved_docs = "\n\n".join([node.text for node in retrieved_nodes])
    logger.debug(f"Retrieved documents (first 200 chars): {retrieved_docs[:200]}...")
    return retrieved_docs

def index_documents(qdrant_client: QdrantClient, collection_name: str, embed_model, documents) -> None:
    """Index the provided documents into the Qdrant collection."""
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        embed_model=embed_model
    )
    logger.info(f"Indexing documents into Qdrant collection '{collection_name}'")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        embed_model=embed_model
    )
    logger.info("Successfully indexed documents")

def delete_all_documents(qdrant_client: QdrantClient, collection_name: str, vector_size: int) -> None:
    """Delete all vectors from the Qdrant collection by recreating it."""
    try:
        # Check if collection exists
        if qdrant_client.collection_exists(collection_name):
            # Delete the collection
            qdrant_client.delete_collection(collection_name=collection_name)
            logger.info(f"Deleted Qdrant collection '{collection_name}'")

            # Recreate the empty collection with the same parameters
            qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE)
            )
            logger.info(f"Recreated empty Qdrant collection '{collection_name}'")
        else:
            logger.warning(f"Qdrant collection '{collection_name}' does not exist, nothing to delete")
    except Exception as e:
        logger.error(f"Failed to delete Qdrant collection: {str(e)}")
        raise
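A retrieval round-trip using these helpers, sketched with a hypothetical host, collection name and embedding model, none of which are defined in this commit:

# Hypothetical example, not part of this commit.
from qdrant_client import QdrantClient
from llama_index.embeddings.huggingface import HuggingFaceEmbedding  # assumed embedding backend
from libs.qdrant_helper import ensure_collection_exists, retrieve_documents

client = QdrantClient(host="qdrant", port=6333)
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")  # 384-dim vectors

ensure_collection_exists(client, collection_name="docs", vector_size=384)
context = retrieve_documents(client, "docs", embed_model, "What are common flu symptoms?")
print(context[:200])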