from llama_index.core.base.llms.types import ChatMessage as LlamaChatMessage
import logging
from libs.models import ChatMessage
from typing import List, Dict, Any, Optional, AsyncGenerator
import httpx
import json
import os
import asyncio

logger = logging.getLogger(__name__)


def format_system_prompt(system_prompt_template: str, language_prompts: dict, language: str,
                         retrieved_docs: str, is_medical: bool, personality: str = "supportive",
                         personality_prompts: dict = {}) -> str:
    """Format the system prompt with the language-specific content, personality and retrieved documents."""
    language_prompt = language_prompts[language]["prompt"]
    language_disclaimer = language_prompts[language]["disclaimer"]
    language_constraint = "" if language == "auto" else language_prompts[language]["constraint"]

    # Handle personality selection, logging and falling back to an empty prompt if unknown
    if personality not in personality_prompts:
        logger.warning(f"Personality '{personality}' not found in prompts, using default empty prompt")
        personality_prompt = ""
    else:
        personality_prompt = personality_prompts[personality]["prompt"]
        logger.info(f"Using '{personality}' personality: {personality_prompts[personality]['description'][:50]}...")

    logger.info(f"Formatting system prompt with language {language}, personality {personality}")
    system_message_content = system_prompt_template.format(
        language_prompt=language_prompt,
        context=retrieved_docs,
        language_disclaimer=language_disclaimer if is_medical else "",
        personality_prompt=personality_prompt,
        language_constraint=language_constraint
    )
    logger.debug(f"System message content: {system_message_content[:200]}...")
    return system_message_content
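

# Illustrative only: a minimal sketch of the shapes this module assumes for the
# prompt template and the language/personality dictionaries. The key names
# ("prompt", "disclaimer", "constraint", "description") mirror the lookups in
# format_system_prompt above; the concrete texts and variable names are
# hypothetical placeholders, not values from this codebase.
#
#     _EXAMPLE_TEMPLATE = (
#         "{language_prompt}\n{personality_prompt}\n"
#         "Context:\n{context}\n{language_disclaimer}\n{language_constraint}"
#     )
#     _EXAMPLE_LANGUAGE_PROMPTS = {
#         "en": {"prompt": "Answer in English.",
#                "disclaimer": "This is general information, not medical advice.",
#                "constraint": "Reply only in English."},
#     }
#     _EXAMPLE_PERSONALITY_PROMPTS = {
#         "supportive": {"prompt": "Be warm and encouraging.",
#                        "description": "Supportive, empathetic tone"},
#     }
#     system_message = format_system_prompt(
#         _EXAMPLE_TEMPLATE, _EXAMPLE_LANGUAGE_PROMPTS, "en",
#         retrieved_docs="(retrieved passages)", is_medical=False,
#         personality="supportive", personality_prompts=_EXAMPLE_PERSONALITY_PROMPTS,
#     )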


async def perform_inference_streaming(
    llm,
    system_message: str,
    history: List[ChatMessage],
    current_message: str
) -> AsyncGenerator[str, None]:
    """Stream inference results from the Ollama API."""
    base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")

    # Prepare messages for the Ollama API
    messages = []

    # Add system message
    messages.append({
        "role": "system",
        "content": system_message
    })

    # Add history
    for msg in history:
        messages.append({
            "role": "user" if msg.role == "user" else "assistant",
            "content": msg.content
        })

    # Add current user message
    messages.append({
        "role": "user",
        "content": current_message
    })

    # Prepare request payload
    payload = {
        "model": llm.model,
        "messages": messages,
        "stream": True,
        "options": {
            "temperature": llm.temperature
        }
    }

    logger.debug(f"Sending streaming request to Ollama API: {base_url}/api/chat")

    try:
        async with httpx.AsyncClient() as client:
            async with client.stream("POST", f"{base_url}/api/chat", json=payload, timeout=60.0) as response:
                if response.status_code != 200:
                    error_detail = await response.aread()
                    logger.error(f"Error from Ollama API: {response.status_code}, {error_detail}")
                    yield f"Error: Failed to get response from language model (Status {response.status_code})"
                    return

                # Accumulate the full response as it streams in
                full_response = ""

                # Ollama streams newline-delimited JSON objects, so parse line by line
                async for line in response.aiter_lines():
                    if not line.strip():
                        continue

                    try:
                        data = json.loads(line)
                        # Process message content if available
                        if 'message' in data and 'content' in data['message']:
                            content = data['message']['content']
                            full_response += content
                            yield content

                        # Check if this is the final message with the done flag
                        if data.get('done', False):
                            logger.debug("Streaming response completed")
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse streaming response: {e}, line: {line}")

    except Exception as e:
        logger.error(f"Error during streaming inference: {str(e)}")
        yield f"Error: {str(e)}"

    # Yield an empty string at the end to signal completion
    yield ""
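

# Illustrative only: one way a caller might consume the streaming generator above.
# "llm" stands for any object exposing .model and .temperature attributes (an
# assumption based on how the request payload is built); the helper name and the
# call site below are hypothetical.
#
#     async def _demo_stream(llm, system_message, history, user_message):
#         async for token in perform_inference_streaming(llm, system_message, history, user_message):
#             print(token, end="", flush=True)
#
#     # asyncio.run(_demo_stream(llm, system_message, [], "Hello"))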


def perform_inference(
    llm,
    system_message: str,
    history: List[ChatMessage],
    current_message: str,
    stream: bool = False
) -> str:
    """Perform inference with the given LLM."""
    if stream:
        # Streaming is handled by perform_inference_streaming via the streaming endpoint
        raise ValueError("Streaming not supported in synchronous inference")

    # Prepare messages for the API
    messages = []

    # Add system message
    messages.append({
        "role": "system",
        "content": system_message
    })

    # Add history
    for msg in history:
        messages.append({
            "role": "user" if msg.role == "user" else "assistant",
            "content": msg.content
        })

    # Add current user message
    messages.append({
        "role": "user",
        "content": current_message
    })

    # For non-streaming, call the Ollama API directly with the httpx client
    base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")

    # Prepare request payload
    payload = {
        "model": llm.model,
        "messages": messages,
        "stream": False,
        "options": {
            "temperature": llm.temperature
        }
    }

    logger.debug(f"Sending non-streaming request to Ollama API: {base_url}/api/chat")

    try:
        with httpx.Client(timeout=60.0) as client:
            response = client.post(f"{base_url}/api/chat", json=payload)

            if response.status_code != 200:
                logger.error(f"Error from Ollama API: {response.status_code}, {response.text}")
                return f"Error: Failed to get response from language model (Status {response.status_code})"

            data = response.json()
            if 'message' in data and 'content' in data['message']:
                return data['message']['content']
            else:
                logger.error(f"Unexpected response format: {data}")
                return "Error: Unexpected response format from language model"

    except Exception as e:
        logger.error(f"Error during non-streaming inference: {str(e)}")
        return f"Error: {str(e)}"


def select_llm(llm, llm_reasoning, reasoning: bool):
    """Select the LLM model based on the reasoning flag."""
    selected_llm = llm_reasoning if reasoning else llm
    return selected_llm
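

# Illustrative only: a hypothetical call site showing how select_llm and
# perform_inference might be wired together; "reasoning_requested" and the
# surrounding variables are placeholders, not names from this codebase.
#
#     active_llm = select_llm(llm, llm_reasoning, reasoning=reasoning_requested)
#     reply = perform_inference(active_llm, system_message, history, current_message)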