from fastapi import FastAPI, File, UploadFile, HTTPException, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from typing import Literal, List
from pydantic import BaseModel
from langdetect import DetectorFactory
from qdrant_client import QdrantClient
import os
import uuid
import yaml
from dotenv import load_dotenv
import logging
import asyncio
import json

from libs.check_medical import is_medical_query
import libs.manage_languages as manage_languages
import libs.qdrant_helper as qdrant_helper
from libs.models import ChatMessage, ChatRequest
import libs.prompt_helper as prompt_helper
from libs.log_prompts import log_prompt_to_db

# Set seed for reproducibility of language detection
DetectorFactory.seed = 0

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load environment variables from .env file
load_dotenv()

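# Environment variables read by this module (all optional; defaults are given at each call site):
#   CORS_ORIGINS, OLLAMA_BASE_URL, EMBED_MODEL_NAME, LLM_MODEL_NAME, LLM_MODEL_NAME_THINKING,
#   TEMPERATURE, TIMEOUT_REQUEST_EMBED, TIMEOUT_REQUEST_CHAT_DIRECT, TIMEOUT_REQUEST_CHAT_REASON,
#   QDRANT_HOST, QDRANT_PORT, COLLECTION_NAME, VECTOR_SIZE
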
# Initialize FastAPI app
# NOTE: FastAPI has no built-in `max_request_body_size` parameter; the keyword below is
# accepted as extra metadata but not enforced by the framework. Body-size limits are
# typically applied by the ASGI server or a reverse proxy in front of the app.
app = FastAPI(
    docs_url="/docs",
    redoc_url="/redoc",
    max_request_body_size=100 * 1024 * 1024  # 100MB (see note above)
)

# Get CORS origins from environment or use default
cors_origins = os.getenv("CORS_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(",")

# Add CORS middleware with proper configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "X-Content-Type-Options"],
    max_age=600,  # 10 minutes for preflight cache
)

# Load custom OpenAPI schema and merge it with the auto-generated one
def load_custom_openapi():
    # openapi.json is plain JSON, so parse it with the json module
    with open("openapi.json", "r") as f:
        custom_openapi = json.load(f)
    default_openapi = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes,
    )
    default_openapi["info"] = custom_openapi.get("info", default_openapi["info"])
    default_openapi["paths"].update(custom_openapi.get("paths", {}))
    return default_openapi

app.openapi = load_custom_openapi

with open("prompts.yaml", "r") as f:
|
||
|
prompts = yaml.safe_load(f)
|
||
|
SYSTEM_PROMPT_TEMPLATE = prompts["system_prompt"]
|
||
|
LANGUAGE_PROMPTS = prompts["languages"]
|
||
|
PERSONALITY_PROMPTS = prompts["personalities"]
|
||
|
|
||
|
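# Sketch of the expected prompts.yaml layout, inferred from the lookups above and the
# personality validation in the /chat endpoint (illustrative, not authoritative):
#   system_prompt: "<base template>"
#   languages: { <language code>: "<language-specific instructions>", ... }
#   personalities: { cool: "...", cynical: "...", supportive: "..." }
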
# Configuration of models and services using .env variables
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "https://ollama.kube-ext.isc.heia-fr.ch")
logger.info(f"Starting application with OLLAMA_BASE_URL: {OLLAMA_BASE_URL}")

# Embedding model using Ollama
embed_model = OllamaEmbedding(
    model_name=os.getenv("EMBED_MODEL_NAME", "mxbai-embed-large"),
    base_url=OLLAMA_BASE_URL,
    request_timeout=float(os.getenv("TIMEOUT_REQUEST_EMBED", "20.0"))  # env values are strings, so cast
)
logger.info("OllamaEmbedding initialized with model: " + os.getenv("EMBED_MODEL_NAME", "mxbai-embed-large"))

# Direct inference model
llm = Ollama(
    model=os.getenv("LLM_MODEL_NAME", "llama3"),
    base_url=OLLAMA_BASE_URL,
    temperature=float(os.getenv("TEMPERATURE", "0.7")),
    request_timeout=float(os.getenv("TIMEOUT_REQUEST_CHAT_DIRECT", "30.0"))  # env values are strings, so cast
)
logger.info(f"Ollama LLM initialized with model: {llm.model} "
            f"with temperature: {llm.temperature}")

# Reasoning model
llm_reasoning = Ollama(
    model=os.getenv("LLM_MODEL_NAME_THINKING", "deepseek-r1:14b"),
    base_url=OLLAMA_BASE_URL,
    temperature=float(os.getenv("TEMPERATURE", "0.7")),
    request_timeout=float(os.getenv("TIMEOUT_REQUEST_CHAT_REASON", "60.0"))  # env values are strings, so cast
)
logger.info(f"Ollama reasoning LLM initialized with model: {llm_reasoning.model} "
            f"with temperature: {llm_reasoning.temperature}")

# Qdrant configuration
qdrant_client = QdrantClient(
    host=os.getenv("QDRANT_HOST", "localhost"),
    port=int(os.getenv("QDRANT_PORT", "6333"))
)
collection_name = os.getenv("COLLECTION_NAME", "default_collection")
vector_size = int(os.getenv("VECTOR_SIZE", "1024"))
logger.info(f"Qdrant client initialized with host: {os.getenv('QDRANT_HOST', 'localhost')} and collection: {collection_name}")

# Ensure Qdrant collection exists
qdrant_helper.ensure_collection_exists(qdrant_client, collection_name, vector_size)

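# Example upload request (illustrative; assumes the API is served locally on port 8000):
#   curl -X POST http://localhost:8000/upload -F "files=@document.pdf"
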
# Endpoint to upload PDFs
@app.post("/upload")
async def upload_pdfs(files: List[UploadFile] = File(...)):
    logger.info("Received upload request")
    try:
        uploaded_files_count = len(files)
        logger.debug(f"Number of files to upload: {uploaded_files_count}")

        # Make sure the target directory exists before writing uploads to it (no-op if it already does)
        os.makedirs("./pdfs", exist_ok=True)

        for file in files:
            file_id = str(uuid.uuid4())
            file_path = f"./pdfs/{file_id}.pdf"
            logger.debug(f"Processing file: {file.filename}, saving as {file_path}")

            with open(file_path, "wb") as f:
                f.write(await file.read())
            logger.debug(f"File {file.filename} saved successfully")

            documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
            logger.debug(f"Loaded {len(documents)} documents from {file.filename}")

            qdrant_helper.index_documents(qdrant_client, collection_name, embed_model, documents)

        return {"message": f"{uploaded_files_count} files processed and indexed successfully"}
    except Exception as e:
        logger.error(f"Error in upload endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}")

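# Example /chat request body (illustrative; the authoritative schema is ChatRequest in libs.models):
#   {
#     "messages": [{"role": "user", "content": "What is hypertension?"}],
#     "language": "auto",
#     "temperature": 0.7,
#     "reasoning": false,
#     "stream": true,
#     "personality": "supportive"
#   }
# When "stream" is true, the response is sent as SSE events of the form
#   data: {"content": "<delta>", "full": "<text so far>"}
# followed by a final  data: {"done": true}  event.
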
# Chat endpoint with language, temperature, and reasoning support
@app.post("/chat")
async def chat_inference(chat_request: ChatRequest, http_request: Request, background_tasks: BackgroundTasks):
    logger.info("Received chat request")
    try:
        if not chat_request.messages:
            logger.warning("No messages provided in the request")
            raise HTTPException(status_code=400, detail="No messages provided")

        # Log the complete request object to inspect its contents
        logger.debug(f"Complete chat request object: {chat_request.dict()}")

        logger.debug(f"Request messages: {chat_request.messages}")
        logger.debug(f"Requested language: {chat_request.language}")
        logger.debug(f"Requested temperature: {chat_request.temperature}")
        logger.debug(f"Requested reasoning: {chat_request.reasoning}")
        logger.debug(f"Requested streaming: {chat_request.stream}")
        logger.debug(f"Requested personality: {chat_request.personality}")

        # Validate language
        manage_languages.validate_language(chat_request.language)

        # More detailed logging of the selected personality
        logger.info(f"Processing request with personality: {chat_request.personality}")

        # Validate personality
        if chat_request.personality not in ["cool", "cynical", "supportive"]:
            logger.warning(f"Invalid personality: {chat_request.personality}, using 'supportive' as default")
            chat_request.personality = "supportive"

        # Validate temperature
        if not (0 < chat_request.temperature < 1):
            raise HTTPException(status_code=400, detail="Temperature must be between 0 and 1 (exclusive)")

        # Prepare message data
        current_message = chat_request.messages[-1].content.lower()
        history = chat_request.messages[:-1]
        logger.debug(f"Current user message: {current_message}")
        logger.debug(f"Message history: {history}")

        # Prepare full conversation history as a concatenated string
        conversation_history = "\n".join([f"{msg.role}: {msg.content}" for msg in chat_request.messages])
        logger.debug(f"Full conversation history: {conversation_history}")

        # Detect language if "auto"
        if chat_request.language == "auto":
            chat_request.language = manage_languages.detect_language(current_message)
            logger.info(f"Detected language using inference: {chat_request.language}")

        # Check if the query is medical-related
        is_medical = is_medical_query(current_message)
        logger.debug(f"Is medical-related query? {is_medical}")

        # Select LLM and set temperature
        selected_llm = prompt_helper.select_llm(llm, llm_reasoning, chat_request.reasoning)
        selected_llm.temperature = chat_request.temperature
        logger.info(f"Using LLM model: {selected_llm.model} with temperature: {selected_llm.temperature}")

        # Retrieve documents from Qdrant
        retrieved_docs = qdrant_helper.retrieve_documents(qdrant_client, collection_name, embed_model, current_message)

        # Format system prompt with the selected personality
        system_message_content = prompt_helper.format_system_prompt(
            SYSTEM_PROMPT_TEMPLATE,
            LANGUAGE_PROMPTS,
            chat_request.language,
            retrieved_docs,
            is_medical,
            chat_request.personality,  # personality key validated above
            PERSONALITY_PROMPTS  # personality prompt dictionary
        )

        # Decide whether to stream the response or return it synchronously
        if chat_request.stream:
            # Streaming response
            logger.info("Using streaming response")

            async def generate():
                full_response = ""
                async for content in prompt_helper.perform_inference_streaming(
                    selected_llm,
                    system_message_content,
                    history,
                    chat_request.messages[-1].content
                ):
                    if content:
                        full_response += content
                        # Standard SSE format: each event is terminated by a blank line (\n\n)
                        yield f"data: {json.dumps({'content': content, 'full': full_response})}\n\n"

                # Log the full conversation and response
                background_tasks.add_task(
                    log_prompt_to_db,
                    None,  # TODO: User ID not available yet
                    http_request.client.host,  # Client's IP address
                    conversation_history,  # Full conversation history
                    full_response  # AI-generated response
                )

                # Signal the end of the stream with a consistent SSE event
                yield f"data: {json.dumps({'done': True})}\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream; charset=utf-8",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",  # disable proxy buffering (e.g. nginx) so events flush immediately
                    "Content-Type": "text/event-stream; charset=utf-8",
                },
            )
        else:
            # Non-streaming response
            logger.info("Using non-streaming response")

            response_content = prompt_helper.perform_inference(
                selected_llm,
                system_message_content,
                history,
                chat_request.messages[-1].content,
                stream=False
            )

            # Log the full conversation and response in the background
            background_tasks.add_task(
                log_prompt_to_db,
                None,  # TODO: User ID not available yet
                http_request.client.host,  # Client's IP address
                conversation_history,  # Full conversation history
                response_content  # AI-generated response
            )

            return {"response": response_content}

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 400s above) instead of masking them as 500s
        raise
    except Exception as e:
        logger.error(f"Error in chat inference: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}")

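# Note: DELETE /docs shares its path with the interactive Swagger UI, which FastAPI serves
# at GET /docs (see docs_url above); only the HTTP method distinguishes the two routes.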
# DELETE endpoint to clear all documents
@app.delete("/docs")
async def delete_all_docs():
    logger.info("Received request to delete all documents")
    try:
        qdrant_helper.delete_all_documents(qdrant_client, collection_name, vector_size)
        return {"message": "All documents have been deleted from the database"}
    except Exception as e:
        logger.error(f"Error in delete endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(e)}")