from fastapi import FastAPI, File, UploadFile, HTTPException, Request, BackgroundTasks from fastapi.responses import JSONResponse, StreamingResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi from llama_index.core import SimpleDirectoryReader from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.llms.ollama import Ollama from typing import Literal, List from pydantic import BaseModel from langdetect import DetectorFactory from qdrant_client import QdrantClient import os from typing import List import uuid import yaml from dotenv import load_dotenv import logging import asyncio import json from libs.check_medical import is_medical_query import libs.manage_languages as manage_languages import libs.qdrant_helper as qdrant_helper from libs.models import ChatMessage, ChatRequest import libs.prompt_helper as prompt_helper from libs.log_prompts import log_prompt_to_db # Set seed for reproducibility of language detection DetectorFactory.seed = 0 # Configure logging logging.basicConfig( level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) # Load environment variables from .env file load_dotenv() # Initialize FastAPI app app = FastAPI( docs_url="/docs", redoc_url="/redoc", max_request_body_size=100 * 1024 * 1024 # 100MB ) # Get CORS origins from environment or use default cors_origins = os.getenv("CORS_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(",") # Add CORS middleware with proper configuration app.add_middleware( CORSMiddleware, allow_origins=cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], expose_headers=["Content-Type", "X-Content-Type-Options"], max_age=600, # 10 minutes for preflight cache ) # Load custom OpenAPI schema def load_custom_openapi(): with open("openapi.json", "r") as f: custom_openapi = yaml.safe_load(f) default_openapi = get_openapi( title=app.title, version=app.version, openapi_version=app.openapi_version, description=app.description, routes=app.routes, ) default_openapi["info"] = custom_openapi.get("info", default_openapi["info"]) default_openapi["paths"].update(custom_openapi.get("paths", {})) return default_openapi app.openapi = load_custom_openapi with open("prompts.yaml", "r") as f: prompts = yaml.safe_load(f) SYSTEM_PROMPT_TEMPLATE = prompts["system_prompt"] LANGUAGE_PROMPTS = prompts["languages"] PERSONALITY_PROMPTS = prompts["personalities"] # Configuration of models and services using .env variables OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "https://ollama.kube-ext.isc.heia-fr.ch") logger.info(f"Starting application with OLLAMA_BASE_URL: {OLLAMA_BASE_URL}") # Embedding model using Ollama embed_model = OllamaEmbedding( model_name=os.getenv("EMBED_MODEL_NAME", "mxbai-embed-large"), base_url=OLLAMA_BASE_URL, request_timeout=os.getenv("TIMEOUT_REQUEST_EMBED", 20.0) ) logger.info("OllamaEmbedding initialized with model: " + os.getenv("EMBED_MODEL_NAME", "mxbai-embed-large")) # Direct inference model llm = Ollama( model=os.getenv("LLM_MODEL_NAME", "llama3"), base_url=OLLAMA_BASE_URL, temperature=float(os.getenv("TEMPERATURE", "0.7")), request_timeout=os.getenv("TIMEOUT_REQUEST_CHAT_DIRECT", 30.0) ) logger.info(f"Ollama LLM initialized with model: {llm.model} " f"with temperature: {llm.temperature}") # Reasoning model llm_reasoning = Ollama( model=os.getenv("LLM_MODEL_NAME_THINKING", "deepseek-r1:14b"), base_url=OLLAMA_BASE_URL, temperature=float(os.getenv("TEMPERATURE", "0.7")), request_timeout=os.getenv("TIMEOUT_REQUEST_CHAT_REASON", 60.0) ) logger.info(f"Ollama reasoning LLM initialized with model: {llm_reasoning.model} " f"with temperature: {llm_reasoning.temperature}") # Qdrant configuration qdrant_client = QdrantClient( host=os.getenv("QDRANT_HOST", "localhost"), port=int(os.getenv("QDRANT_PORT", "6333")) ) collection_name = os.getenv("COLLECTION_NAME", "default_collection") vector_size = int(os.getenv("VECTOR_SIZE", "1024")) logger.info(f"Qdrant client initialized with host: {os.getenv('QDRANT_HOST')} and collection: {collection_name}") # Ensure Qdrant collection exists qdrant_helper.ensure_collection_exists(qdrant_client, collection_name, vector_size) # Endpoint to upload PDFs @app.post("/upload") async def upload_pdfs(files: List[UploadFile] = File(...)): logger.info("Received upload request") try: uploaded_files_count = len(files) logger.debug(f"Number of files to upload: {uploaded_files_count}") for file in files: file_id = str(uuid.uuid4()) file_path = f"./pdfs/{file_id}.pdf" logger.debug(f"Processing file: {file.filename}, saving as {file_path}") with open(file_path, "wb") as f: f.write(await file.read()) logger.debug(f"File {file.filename} saved successfully") documents = SimpleDirectoryReader(input_files=[file_path]).load_data() logger.debug(f"Loaded {len(documents)} documents from {file.filename}") qdrant_helper.index_documents(qdrant_client, collection_name, embed_model, documents) return {"message": f"{uploaded_files_count} files processed and indexed successfully"} except Exception as e: logger.error(f"Error in upload endpoint: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}") # Chat endpoint with language, temperature, and reasoning support @app.post("/chat") async def chat_inference(chat_request: ChatRequest, http_request: Request, background_tasks: BackgroundTasks): logger.info("Received chat request") try: if not chat_request.messages: logger.warning("No messages provided in the request") raise HTTPException(status_code=400, detail="No messages provided") # Log the complete request object to inspect its contents logger.debug(f"Complete chat request object: {chat_request.dict()}") logger.debug(f"Request messages: {chat_request.messages}") logger.debug(f"Requested language: {chat_request.language}") logger.debug(f"Requested temperature: {chat_request.temperature}") logger.debug(f"Requested reasoning: {chat_request.reasoning}") logger.debug(f"Requested streaming: {chat_request.stream}") logger.debug(f"Requested personality: {chat_request.personality}") # Validate language manage_languages.validate_language(chat_request.language) # Log più dettagliato della personalità logger.info(f"Processing request with personality: {chat_request.personality}") # Validate personality if chat_request.personality not in ["cool", "cynical", "supportive"]: logger.warning(f"Invalid personality: {chat_request.personality}, using 'supportive' as default") chat_request.personality = "supportive" # Validate temperature if not (0 < chat_request.temperature < 1): raise HTTPException(status_code=400, detail="Temperature must be between 0 and 1 (exclusive)") # Prepare message data current_message = chat_request.messages[-1].content.lower() history = chat_request.messages[:-1] logger.debug(f"Current user message: {current_message}") logger.debug(f"Message history: {history}") # Prepare full conversation history as a concatenated string conversation_history = "\n".join([f"{msg.role}: {msg.content}" for msg in chat_request.messages]) logger.debug(f"Full conversation history: {conversation_history}") # Detect language if "auto" if chat_request.language == "auto": chat_request.language = manage_languages.detect_language(current_message) logger.info(f"Detected language using inference: {chat_request.language}") # Check if the query is medical-related is_medical = is_medical_query(current_message) logger.debug(f"Is medical-related query? {is_medical}") # Select LLM and set temperature selected_llm = prompt_helper.select_llm(llm, llm_reasoning, chat_request.reasoning) selected_llm.temperature = chat_request.temperature logger.info(f"Using LLM model: {selected_llm.model} with temperature: {selected_llm.temperature}") # Retrieve documents from Qdrant retrieved_docs = qdrant_helper.retrieve_documents(qdrant_client, collection_name, embed_model, current_message) # Format system prompt with personality - verifico passaggio corretto system_message_content = prompt_helper.format_system_prompt( SYSTEM_PROMPT_TEMPLATE, LANGUAGE_PROMPTS, chat_request.language, retrieved_docs, is_medical, chat_request.personality, # Confermo passaggio personalità PERSONALITY_PROMPTS # Confermo passaggio dizionario personalità ) # Decidiamo se utilizzare lo streaming o la risposta sincrona if chat_request.stream: # Streaming response logger.info("Using streaming response") async def generate(): full_response = "" async for content in prompt_helper.perform_inference_streaming( selected_llm, system_message_content, history, chat_request.messages[-1].content ): if content: full_response += content # Formato SSE standard con \n\n alla fine per delimitare gli eventi yield f"data: {json.dumps({'content': content, 'full': full_response})}\n\n" # Log the full conversation and response background_tasks.add_task( log_prompt_to_db, None, # TODO: User ID not available yet http_request.client.host, # Client's IP address conversation_history, # Full conversation history full_response # AI-generated response ) # Signal the end of the stream con formato SSE consistente yield f"data: {json.dumps({'done': True})}\n\n" return StreamingResponse( generate(), media_type="text/event-stream; charset=utf-8", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", "Content-Type": "text/event-stream; charset=utf-8", } ) else: # Non-streaming response logger.info("Using non-streaming response") response_content = prompt_helper.perform_inference( selected_llm, system_message_content, history, chat_request.messages[-1].content, stream=False ) # Log the full conversation and response in the background background_tasks.add_task( log_prompt_to_db, None, # TODO: User ID not available yet http_request.client.host, # Client's IP address conversation_history, # Full conversation history response_content # AI-generated response ) return {"response": response_content} except Exception as e: logger.error(f"Error in chat inference: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}") # Add new DELETE endpoint to clear all documents @app.delete("/docs") async def delete_all_docs(): logger.info("Received request to delete all documents") try: qdrant_helper.delete_all_documents(qdrant_client, collection_name, vector_size) return {"message": "All documents have been deleted from the database"} except Exception as e: logger.error(f"Error in delete endpoint: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(e)}")