medichaid/docker_svc/agent/app/main.py
2025-07-10 01:43:01 +02:00

304 lines
No EOL
12 KiB
Python

from fastapi import FastAPI, File, UploadFile, HTTPException, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from typing import Literal, List
from pydantic import BaseModel
from langdetect import DetectorFactory
from qdrant_client import QdrantClient
import os
from typing import List
import uuid
import yaml
from dotenv import load_dotenv
import logging
import asyncio
import json
from libs.check_medical import is_medical_query
import libs.manage_languages as manage_languages
import libs.qdrant_helper as qdrant_helper
from libs.models import ChatMessage, ChatRequest
import libs.prompt_helper as prompt_helper
from libs.log_prompts import log_prompt_to_db
# Load environment variables from .env first, so every setting read below
# (including LOG_LEVEL and CORS_ORIGINS) can come from the .env file.
load_dotenv()

# Seed langdetect so language detection is deterministic across runs.
DetectorFactory.seed = 0

# Configure logging. The level defaults to DEBUG; note that at DEBUG the chat
# endpoint logs complete user conversations, so a higher LOG_LEVEL (e.g. INFO)
# may be preferable in production.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "DEBUG").upper(),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    docs_url="/docs",
    redoc_url="/redoc",
    # NOTE(review): FastAPI has no `max_request_body_size` option; unknown keyword
    # arguments appear to be stored in `app.extra` only, so this 100MB limit is
    # most likely NOT enforced. Confirm, and enforce it at the ASGI server or
    # reverse proxy instead.
    max_request_body_size=100 * 1024 * 1024,  # 100MB
)

# Allowed CORS origins: comma-separated CORS_ORIGINS, else the local dev frontend.
# Entries are stripped and empty ones dropped, so "http://a, http://b" or a
# trailing comma still yield exact origin strings that browsers' Origin headers match.
cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(",")
    if origin.strip()
]

# Add CORS middleware with proper configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "X-Content-Type-Options"],
    max_age=600,  # 10 minutes for preflight cache
)
# Load custom OpenAPI schema
def load_custom_openapi():
    """Return the app's OpenAPI schema, enriched with hand-written documentation.

    The default schema is generated from the registered routes. The ``info``
    section from ``openapi.json`` then replaces the generated one, and the
    file's ``paths`` are merged in. A file path overrides a generated path
    with the same key.

    The result is cached on ``app.openapi_schema``, as FastAPI's default
    ``app.openapi`` does. The file is therefore read and the schema built once
    per process; edits to ``openapi.json`` need a restart.

    Returns:
        dict: the merged OpenAPI schema.

    Raises:
        OSError: if ``openapi.json`` cannot be opened.
        yaml.YAMLError: if the file cannot be parsed.
    """
    if app.openapi_schema:
        return app.openapi_schema
    with open("openapi.json", "r", encoding="utf-8") as f:
        # yaml.safe_load parses (most) JSON documents too; `or {}` guards an
        # empty file, for which safe_load returns None.
        custom_openapi = yaml.safe_load(f) or {}
    default_openapi = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes,
    )
    default_openapi["info"] = custom_openapi.get("info", default_openapi["info"])
    default_openapi.setdefault("paths", {}).update(custom_openapi.get("paths", {}))
    app.openapi_schema = default_openapi
    return default_openapi


app.openapi = load_custom_openapi
# Prompt templates. Read explicitly as UTF-8: the file may contain non-ASCII,
# language-specific prompts, and the platform default encoding is not always UTF-8.
with open("prompts.yaml", "r", encoding="utf-8") as f:
    prompts = yaml.safe_load(f)
SYSTEM_PROMPT_TEMPLATE = prompts["system_prompt"]
LANGUAGE_PROMPTS = prompts["languages"]
PERSONALITY_PROMPTS = prompts["personalities"]


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* as a float, or *default* if unset/blank.

    os.getenv returns a string whenever the variable is set, so numeric
    settings are converted here rather than passed through as text.

    Raises:
        ValueError: if the variable is set to a non-numeric value.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return float(raw)


# Configuration of models and services using .env variables
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "https://ollama.kube-ext.isc.heia-fr.ch")
logger.info(f"Starting application with OLLAMA_BASE_URL: {OLLAMA_BASE_URL}")

# Default sampling temperature shared by both chat models (requests may override it).
DEFAULT_TEMPERATURE = _env_float("TEMPERATURE", 0.7)

# Embedding model using Ollama
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "mxbai-embed-large")
embed_model = OllamaEmbedding(
    model_name=EMBED_MODEL_NAME,
    base_url=OLLAMA_BASE_URL,
    request_timeout=_env_float("TIMEOUT_REQUEST_EMBED", 20.0),
)
logger.info("OllamaEmbedding initialized with model: " + EMBED_MODEL_NAME)

# Direct inference model
llm = Ollama(
    model=os.getenv("LLM_MODEL_NAME", "llama3"),
    base_url=OLLAMA_BASE_URL,
    temperature=DEFAULT_TEMPERATURE,
    request_timeout=_env_float("TIMEOUT_REQUEST_CHAT_DIRECT", 30.0),
)
logger.info(f"Ollama LLM initialized with model: {llm.model} "
            f"with temperature: {llm.temperature}")

# Reasoning model
llm_reasoning = Ollama(
    model=os.getenv("LLM_MODEL_NAME_THINKING", "deepseek-r1:14b"),
    base_url=OLLAMA_BASE_URL,
    temperature=DEFAULT_TEMPERATURE,
    request_timeout=_env_float("TIMEOUT_REQUEST_CHAT_REASON", 60.0),
)
logger.info(f"Ollama reasoning LLM initialized with model: {llm_reasoning.model} "
            f"with temperature: {llm_reasoning.temperature}")

# Qdrant configuration. The host is read once so the log line reports the value
# actually used (including the "localhost" default).
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
qdrant_client = QdrantClient(
    host=QDRANT_HOST,
    port=int(os.getenv("QDRANT_PORT", "6333")),
)
collection_name = os.getenv("COLLECTION_NAME", "default_collection")
vector_size = int(os.getenv("VECTOR_SIZE", "1024"))
logger.info(f"Qdrant client initialized with host: {QDRANT_HOST} and collection: {collection_name}")

# Ensure Qdrant collection exists
qdrant_helper.ensure_collection_exists(qdrant_client, collection_name, vector_size)
# Endpoint to upload PDFs
@app.post("/upload")
async def upload_pdfs(files: List[UploadFile] = File(...)):
    """Persist uploaded files under ./pdfs and index their contents in Qdrant.

    Each file is written as ``./pdfs/<uuid4>.pdf``. The ``.pdf`` suffix is
    applied regardless of the uploaded file's name. The file is then loaded
    with ``SimpleDirectoryReader`` and passed to
    ``qdrant_helper.index_documents``.

    Returns:
        dict: a message with the number of files processed.

    Raises:
        HTTPException: 500 if saving, loading or indexing any file fails.
            Files handled before the failure remain saved and indexed.
    """
    logger.info("Received upload request")
    upload_dir = "./pdfs"
    try:
        uploaded_files_count = len(files)
        logger.debug(f"Number of files to upload: {uploaded_files_count}")
        # open() below fails if the target directory is missing (e.g. a fresh
        # container without a mounted volume), so create it on demand.
        os.makedirs(upload_dir, exist_ok=True)
        for file in files:
            file_path = f"{upload_dir}/{uuid.uuid4()}.pdf"
            logger.debug(f"Processing file: {file.filename}, saving as {file_path}")
            with open(file_path, "wb") as f:
                f.write(await file.read())
            logger.debug(f"File {file.filename} saved successfully")
            documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
            logger.debug(f"Loaded {len(documents)} documents from {file.filename}")
            qdrant_helper.index_documents(qdrant_client, collection_name, embed_model, documents)
        return {"message": f"{uploaded_files_count} files processed and indexed successfully"}
    except Exception as e:
        logger.error(f"Error in upload endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}") from e
# Chat endpoint with language, temperature, and reasoning support
def _client_host(request: Request):
    """Return the client's host/IP from *request*, or None if it is unavailable."""
    # Starlette types Request.client as optional, so guard before reading .host.
    client = request.client
    return client.host if client is not None else None


@app.post("/chat")
async def chat_inference(chat_request: ChatRequest, http_request: Request, background_tasks: BackgroundTasks):
    """Answer the conversation in *chat_request* using retrieval-augmented generation.

    The request is processed in order:

    1. Validate the request.
    2. Resolve the language; it is auto-detected from the last message when
       ``language == "auto"``.
    3. Check whether the last message is medical.
    4. Pick the direct or reasoning LLM.
    5. Retrieve context documents from Qdrant for the last message.
    6. Build the system prompt and run inference.

    When ``stream`` is set, the reply is sent as Server-Sent Events:

    - one ``{"content": chunk, "full": text_so_far}`` event per non-empty chunk;
    - an ``{"error": ...}`` event if inference fails mid-stream;
    - a final ``{"done": true}`` event.

    Otherwise the handler returns ``{"response": text}``. A successful exchange
    is logged with ``log_prompt_to_db`` as a background task.

    Raises:
        HTTPException: 400 for an empty message list or a temperature outside
            (0, 1). Any HTTPException raised while handling the request (possibly
            by ``manage_languages.validate_language`` -- confirm) is re-raised
            unchanged. Every other error becomes a 500.
    """
    import copy  # only needed for the per-request LLM copy below

    logger.info("Received chat request")
    try:
        if not chat_request.messages:
            logger.warning("No messages provided in the request")
            raise HTTPException(status_code=400, detail="No messages provided")
        # Log the complete request object to inspect its contents
        logger.debug(f"Complete chat request object: {chat_request.dict()}")
        logger.debug(f"Request messages: {chat_request.messages}")
        logger.debug(f"Requested language: {chat_request.language}")
        logger.debug(f"Requested temperature: {chat_request.temperature}")
        logger.debug(f"Requested reasoning: {chat_request.reasoning}")
        logger.debug(f"Requested streaming: {chat_request.stream}")
        logger.debug(f"Requested personality: {chat_request.personality}")

        # Validate language
        manage_languages.validate_language(chat_request.language)

        # Log the requested personality before validating it
        logger.info(f"Processing request with personality: {chat_request.personality}")
        # Validate personality, falling back to "supportive" for unknown values
        if chat_request.personality not in ["cool", "cynical", "supportive"]:
            logger.warning(f"Invalid personality: {chat_request.personality}, using 'supportive' as default")
            chat_request.personality = "supportive"

        # Validate temperature
        if not (0 < chat_request.temperature < 1):
            raise HTTPException(status_code=400, detail="Temperature must be between 0 and 1 (exclusive)")

        # Prepare message data. The lowercased copy drives language detection,
        # the medical check and retrieval; the original text goes to the LLM.
        user_message = chat_request.messages[-1].content
        current_message = user_message.lower()
        history = chat_request.messages[:-1]
        logger.debug(f"Current user message: {current_message}")
        logger.debug(f"Message history: {history}")

        # Full conversation as "role: content" lines, used for prompt logging
        conversation_history = "\n".join(f"{msg.role}: {msg.content}" for msg in chat_request.messages)
        logger.debug(f"Full conversation history: {conversation_history}")

        # Detect language if "auto"
        if chat_request.language == "auto":
            chat_request.language = manage_languages.detect_language(current_message)
            logger.info(f"Detected language using inference: {chat_request.language}")

        # Check if the query is medical-related
        is_medical = is_medical_query(current_message)
        logger.debug(f"Is medical-related query? {is_medical}")

        # Select the LLM and set this request's temperature on a shallow copy:
        # mutating the shared module-level model would leak the value into
        # concurrent requests (a streaming generator runs after this returns).
        selected_llm = copy.copy(prompt_helper.select_llm(llm, llm_reasoning, chat_request.reasoning))
        selected_llm.temperature = chat_request.temperature
        logger.info(f"Using LLM model: {selected_llm.model} with temperature: {selected_llm.temperature}")

        # Retrieve documents from Qdrant
        retrieved_docs = qdrant_helper.retrieve_documents(qdrant_client, collection_name, embed_model, current_message)

        # Format the system prompt with language, context, medical flag and personality
        system_message_content = prompt_helper.format_system_prompt(
            SYSTEM_PROMPT_TEMPLATE,
            LANGUAGE_PROMPTS,
            chat_request.language,
            retrieved_docs,
            is_medical,
            chat_request.personality,
            PERSONALITY_PROMPTS,
        )

        client_host = _client_host(http_request)

        # Decide between a streaming and a synchronous response
        if chat_request.stream:
            logger.info("Using streaming response")

            async def generate():
                full_response = ""
                try:
                    async for content in prompt_helper.perform_inference_streaming(
                        selected_llm,
                        system_message_content,
                        history,
                        user_message,
                    ):
                        if content:
                            full_response += content
                            # Standard SSE framing: "data: ..." terminated by a blank line
                            yield f"data: {json.dumps({'content': content, 'full': full_response})}\n\n"
                except Exception as stream_error:
                    # Headers are already sent, so an HTTP error status is no longer
                    # possible; report the failure in-band instead of dropping the stream.
                    logger.error(f"Error during streaming inference: {str(stream_error)}", exc_info=True)
                    yield f"data: {json.dumps({'error': f'Error processing chat: {stream_error}'})}\n\n"
                else:
                    # Log the full conversation and response after the stream completes
                    background_tasks.add_task(
                        log_prompt_to_db,
                        None,  # TODO: User ID not available yet
                        client_host,  # Client's IP address (may be None)
                        conversation_history,  # Full conversation history
                        full_response,  # AI-generated response
                    )
                # Signal the end of the stream with the same SSE framing
                yield f"data: {json.dumps({'done': True})}\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream; charset=utf-8",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                    "Content-Type": "text/event-stream; charset=utf-8",
                },
            )

        # Non-streaming response
        logger.info("Using non-streaming response")
        # NOTE(review): perform_inference appears to be synchronous; if so it blocks
        # the event loop for the whole generation. Consider asyncio.to_thread.
        response_content = prompt_helper.perform_inference(
            selected_llm,
            system_message_content,
            history,
            user_message,
            stream=False,
        )
        # Log the full conversation and response in the background
        background_tasks.add_task(
            log_prompt_to_db,
            None,  # TODO: User ID not available yet
            client_host,  # Client's IP address (may be None)
            conversation_history,  # Full conversation history
            response_content,  # AI-generated response
        )
        return {"response": response_content}
    except HTTPException:
        # Deliberate client errors (e.g. the 400s above) must keep their status
        # code instead of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error in chat inference: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}") from e
# DELETE /docs: remove every indexed document from the Qdrant collection.
# NOTE(review): no authentication is applied to this destructive endpoint here.
@app.delete("/docs")
async def delete_all_docs():
    """Delete all documents from the configured Qdrant collection.

    Delegates to ``qdrant_helper.delete_all_documents`` with the module-level
    client, collection name and vector size.

    Returns:
        dict: a confirmation message.

    Raises:
        HTTPException: 500 wrapping any error raised during deletion.
    """
    logger.info("Received request to delete all documents")
    try:
        qdrant_helper.delete_all_documents(qdrant_client, collection_name, vector_size)
    except Exception as err:
        logger.exception(f"Error in delete endpoint: {err}")
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {err}")
    return {"message": "All documents have been deleted from the database"}