first commit
parent 5c5d88c92f
commit eb4f62c56d
41 changed files with 3851 additions and 19 deletions
304
docker_svc/agent/app/main.py
Normal file
@@ -0,0 +1,304 @@
from fastapi import FastAPI, File, UploadFile, HTTPException, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from typing import Literal, List
from pydantic import BaseModel
from langdetect import DetectorFactory
from qdrant_client import QdrantClient
import os
import uuid
import yaml
from dotenv import load_dotenv
import logging
import asyncio
import json

from libs.check_medical import is_medical_query
import libs.manage_languages as manage_languages
import libs.qdrant_helper as qdrant_helper
from libs.models import ChatMessage, ChatRequest
import libs.prompt_helper as prompt_helper
from libs.log_prompts import log_prompt_to_db

# Set seed for reproducibility of language detection
DetectorFactory.seed = 0

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load environment variables from .env file
load_dotenv()

# Initialize FastAPI app
app = FastAPI(
    docs_url="/docs",
    redoc_url="/redoc",
    max_request_body_size=100 * 1024 * 1024  # 100MB
)

# Get CORS origins from environment or use default
cors_origins = os.getenv("CORS_ORIGINS", "http://localhost:3000,http://127.0.0.1:3000").split(",")

# Add CORS middleware with proper configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "X-Content-Type-Options"],
    max_age=600,  # 10 minutes for preflight cache
)

# Load custom OpenAPI schema
def load_custom_openapi():
    with open("openapi.json", "r") as f:
        custom_openapi = yaml.safe_load(f)
    default_openapi = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes,
    )
    default_openapi["info"] = custom_openapi.get("info", default_openapi["info"])
    default_openapi["paths"].update(custom_openapi.get("paths", {}))
    return default_openapi

app.openapi = load_custom_openapi

with open("prompts.yaml", "r") as f:
    prompts = yaml.safe_load(f)
SYSTEM_PROMPT_TEMPLATE = prompts["system_prompt"]
LANGUAGE_PROMPTS = prompts["languages"]
PERSONALITY_PROMPTS = prompts["personalities"]

# Configuration of models and services using .env variables
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "https://ollama.kube-ext.isc.heia-fr.ch")
logger.info(f"Starting application with OLLAMA_BASE_URL: {OLLAMA_BASE_URL}")

# Embedding model using Ollama
embed_model = OllamaEmbedding(
    model_name=os.getenv("EMBED_MODEL_NAME", "mxbai-embed-large"),
    base_url=OLLAMA_BASE_URL,
    request_timeout=float(os.getenv("TIMEOUT_REQUEST_EMBED", "20.0"))
)
logger.info("OllamaEmbedding initialized with model: " + os.getenv("EMBED_MODEL_NAME", "mxbai-embed-large"))

# Direct inference model
llm = Ollama(
    model=os.getenv("LLM_MODEL_NAME", "llama3"),
    base_url=OLLAMA_BASE_URL,
    temperature=float(os.getenv("TEMPERATURE", "0.7")),
    request_timeout=float(os.getenv("TIMEOUT_REQUEST_CHAT_DIRECT", "30.0"))
)
logger.info(f"Ollama LLM initialized with model: {llm.model} "
            f"with temperature: {llm.temperature}")

# Reasoning model
llm_reasoning = Ollama(
    model=os.getenv("LLM_MODEL_NAME_THINKING", "deepseek-r1:14b"),
    base_url=OLLAMA_BASE_URL,
    temperature=float(os.getenv("TEMPERATURE", "0.7")),
    request_timeout=float(os.getenv("TIMEOUT_REQUEST_CHAT_REASON", "60.0"))
)
logger.info(f"Ollama reasoning LLM initialized with model: {llm_reasoning.model} "
            f"with temperature: {llm_reasoning.temperature}")

# Qdrant configuration
qdrant_client = QdrantClient(
    host=os.getenv("QDRANT_HOST", "localhost"),
    port=int(os.getenv("QDRANT_PORT", "6333"))
)
collection_name = os.getenv("COLLECTION_NAME", "default_collection")
vector_size = int(os.getenv("VECTOR_SIZE", "1024"))
logger.info(f"Qdrant client initialized with host: {os.getenv('QDRANT_HOST')} and collection: {collection_name}")

# Ensure Qdrant collection exists
qdrant_helper.ensure_collection_exists(qdrant_client, collection_name, vector_size)

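# Illustrative .env sketch for the variables read above (the values shown are
# just the code's own defaults, not the actual deployment configuration):
#
#   OLLAMA_BASE_URL=https://ollama.kube-ext.isc.heia-fr.ch
#   EMBED_MODEL_NAME=mxbai-embed-large
#   TIMEOUT_REQUEST_EMBED=20.0
#   LLM_MODEL_NAME=llama3
#   TIMEOUT_REQUEST_CHAT_DIRECT=30.0
#   LLM_MODEL_NAME_THINKING=deepseek-r1:14b
#   TIMEOUT_REQUEST_CHAT_REASON=60.0
#   TEMPERATURE=0.7
#   QDRANT_HOST=localhost
#   QDRANT_PORT=6333
#   COLLECTION_NAME=default_collection
#   VECTOR_SIZE=1024
#   CORS_ORIGINS=http://localhost:3000,http://127.0.0.1:3000
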
# Endpoint to upload PDFs
@app.post("/upload")
async def upload_pdfs(files: List[UploadFile] = File(...)):
    logger.info("Received upload request")
    try:
        uploaded_files_count = len(files)
        logger.debug(f"Number of files to upload: {uploaded_files_count}")

        for file in files:
            file_id = str(uuid.uuid4())
            file_path = f"./pdfs/{file_id}.pdf"
            logger.debug(f"Processing file: {file.filename}, saving as {file_path}")

            with open(file_path, "wb") as f:
                f.write(await file.read())
            logger.debug(f"File {file.filename} saved successfully")

            documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
            logger.debug(f"Loaded {len(documents)} documents from {file.filename}")

            qdrant_helper.index_documents(qdrant_client, collection_name, embed_model, documents)

        return {"message": f"{uploaded_files_count} files processed and indexed successfully"}
    except Exception as e:
        logger.error(f"Error in upload endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}")

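# Illustrative call (assumes the service is exposed on localhost:8000, which is
# not defined in this file):
#   curl -X POST http://localhost:8000/upload \
#        -F "files=@first.pdf" -F "files=@second.pdf"
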
# Chat endpoint with language, temperature, and reasoning support
@app.post("/chat")
async def chat_inference(chat_request: ChatRequest, http_request: Request, background_tasks: BackgroundTasks):
    logger.info("Received chat request")
    try:
        if not chat_request.messages:
            logger.warning("No messages provided in the request")
            raise HTTPException(status_code=400, detail="No messages provided")

        # Log the complete request object to inspect its contents
        logger.debug(f"Complete chat request object: {chat_request.dict()}")

        logger.debug(f"Request messages: {chat_request.messages}")
        logger.debug(f"Requested language: {chat_request.language}")
        logger.debug(f"Requested temperature: {chat_request.temperature}")
        logger.debug(f"Requested reasoning: {chat_request.reasoning}")
        logger.debug(f"Requested streaming: {chat_request.stream}")
        logger.debug(f"Requested personality: {chat_request.personality}")

        # Validate language
        manage_languages.validate_language(chat_request.language)

        # More detailed logging of the personality
        logger.info(f"Processing request with personality: {chat_request.personality}")

        # Validate personality
        if chat_request.personality not in ["cool", "cynical", "supportive"]:
            logger.warning(f"Invalid personality: {chat_request.personality}, using 'supportive' as default")
            chat_request.personality = "supportive"

        # Validate temperature
        if not (0 < chat_request.temperature < 1):
            raise HTTPException(status_code=400, detail="Temperature must be between 0 and 1 (exclusive)")

        # Prepare message data
        current_message = chat_request.messages[-1].content.lower()
        history = chat_request.messages[:-1]
        logger.debug(f"Current user message: {current_message}")
        logger.debug(f"Message history: {history}")

        # Prepare full conversation history as a concatenated string
        conversation_history = "\n".join([f"{msg.role}: {msg.content}" for msg in chat_request.messages])
        logger.debug(f"Full conversation history: {conversation_history}")

        # Detect language if "auto"
        if chat_request.language == "auto":
            chat_request.language = manage_languages.detect_language(current_message)
            logger.info(f"Detected language using inference: {chat_request.language}")

        # Check if the query is medical-related
        is_medical = is_medical_query(current_message)
        logger.debug(f"Is medical-related query? {is_medical}")

        # Select LLM and set temperature
        selected_llm = prompt_helper.select_llm(llm, llm_reasoning, chat_request.reasoning)
        selected_llm.temperature = chat_request.temperature
        logger.info(f"Using LLM model: {selected_llm.model} with temperature: {selected_llm.temperature}")

        # Retrieve documents from Qdrant
        retrieved_docs = qdrant_helper.retrieve_documents(qdrant_client, collection_name, embed_model, current_message)

        # Format system prompt with personality, making sure it is passed through correctly
        system_message_content = prompt_helper.format_system_prompt(
            SYSTEM_PROMPT_TEMPLATE,
            LANGUAGE_PROMPTS,
            chat_request.language,
            retrieved_docs,
            is_medical,
            chat_request.personality,  # personality passed through
            PERSONALITY_PROMPTS  # personality prompt dictionary passed through
        )

        # Decide whether to stream or return a synchronous response
        if chat_request.stream:
            # Streaming response
            logger.info("Using streaming response")

            async def generate():
                full_response = ""
                async for content in prompt_helper.perform_inference_streaming(
                    selected_llm,
                    system_message_content,
                    history,
                    chat_request.messages[-1].content
                ):
                    if content:
                        full_response += content
                        # Standard SSE format, terminated with \n\n to delimit events
                        yield f"data: {json.dumps({'content': content, 'full': full_response})}\n\n"

                # Log the full conversation and response
                background_tasks.add_task(
                    log_prompt_to_db,
                    None,  # TODO: User ID not available yet
                    http_request.client.host,  # Client's IP address
                    conversation_history,  # Full conversation history
                    full_response  # AI-generated response
                )

                # Signal the end of the stream, using the same SSE format
                yield f"data: {json.dumps({'done': True})}\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream; charset=utf-8",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                    "Content-Type": "text/event-stream; charset=utf-8",
                }
            )
        else:
            # Non-streaming response
            logger.info("Using non-streaming response")

            response_content = prompt_helper.perform_inference(
                selected_llm,
                system_message_content,
                history,
                chat_request.messages[-1].content,
                stream=False
            )

            # Log the full conversation and response in the background
            background_tasks.add_task(
                log_prompt_to_db,
                None,  # TODO: User ID not available yet
                http_request.client.host,  # Client's IP address
                conversation_history,  # Full conversation history
                response_content  # AI-generated response
            )

            return {"response": response_content}

    except Exception as e:
        logger.error(f"Error in chat inference: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}")

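# Illustrative request body for /chat (field names follow how ChatRequest is
# accessed above; the actual model is defined in libs.models):
#   {
#     "messages": [{"role": "user", "content": "What helps against a mild headache?"}],
#     "language": "auto",
#     "temperature": 0.7,
#     "reasoning": false,
#     "stream": true,
#     "personality": "supportive"
#   }
# With "stream": true the endpoint emits Server-Sent Events shaped like:
#   data: {"content": "<chunk>", "full": "<text so far>"}
#   data: {"done": true}
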
# Add new DELETE endpoint to clear all documents
@app.delete("/docs")
async def delete_all_docs():
    logger.info("Received request to delete all documents")
    try:
        qdrant_helper.delete_all_documents(qdrant_client, collection_name, vector_size)
        return {"message": "All documents have been deleted from the database"}
    except Exception as e:
        logger.error(f"Error in delete endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(e)}")
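# Note: GET /docs still serves the Swagger UI (docs_url above); only the DELETE
# method on this path clears the vector store.
# Illustrative call (host/port are an assumption):
#   curl -X DELETE http://localhost:8000/docs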