first commit
This commit is contained in:
parent
5c5d88c92f
commit
eb4f62c56d
41 changed files with 3851 additions and 19 deletions
192
docker_svc/agent/app/libs/prompt_helper.py
Normal file
192
docker_svc/agent/app/libs/prompt_helper.py
Normal file
|
@ -0,0 +1,192 @@
|
|||
from llama_index.core.base.llms.types import ChatMessage as LlamaChatMessage
|
||||
import logging
|
||||
from libs.models import ChatMessage
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
import httpx
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def format_system_prompt(system_prompt_template: str, language_prompts: dict, language: str,
|
||||
retrieved_docs: str, is_medical: bool, personality: str = "supportive",
|
||||
personality_prompts: dict = {}) -> str:
|
||||
"""Formatta il prompt di sistema con il contenuto specifico della lingua, personalità e i documenti recuperati."""
|
||||
language_prompt = language_prompts[language]["prompt"]
|
||||
language_disclaimer = language_prompts[language]["disclaimer"]
|
||||
language_constraint = "" if language == "auto" else language_prompts[language]["constraint"]
|
||||
|
||||
# Miglioro il log e la gestione della personalità
|
||||
if personality not in personality_prompts:
|
||||
logger.warning(f"Personality '{personality}' not found in prompts, using default empty prompt")
|
||||
personality_prompt = ""
|
||||
else:
|
||||
personality_prompt = personality_prompts[personality]["prompt"]
|
||||
logger.info(f"Using '{personality}' personality: {personality_prompts[personality]['description'][:50]}...")
|
||||
|
||||
logger.info(f"Formatting system prompt with language {language}, personality {personality}")
|
||||
system_message_content = system_prompt_template.format(
|
||||
language_prompt=language_prompt,
|
||||
context=retrieved_docs,
|
||||
language_disclaimer=language_disclaimer if is_medical else "",
|
||||
personality_prompt=personality_prompt,
|
||||
language_constraint=language_constraint
|
||||
)
|
||||
logger.debug(f"System message content: {system_message_content[:200]}...")
|
||||
return system_message_content
|
||||
|
||||
async def perform_inference_streaming(
|
||||
llm,
|
||||
system_message: str,
|
||||
history: List[Dict],
|
||||
current_message: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream inference results from Ollama API"""
|
||||
base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")
|
||||
|
||||
# Prepare messages for Ollama API
|
||||
messages = []
|
||||
|
||||
# Add system message
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": system_message
|
||||
})
|
||||
|
||||
# Add history
|
||||
for msg in history:
|
||||
messages.append({
|
||||
"role": "user" if msg.role == "user" else "assistant",
|
||||
"content": msg.content
|
||||
})
|
||||
|
||||
# Add current user message
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": current_message
|
||||
})
|
||||
|
||||
# Prepare request payload
|
||||
payload = {
|
||||
"model": llm.model,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"options": {
|
||||
"temperature": llm.temperature
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(f"Sending streaming request to Ollama API: {base_url}/api/chat")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream("POST", f"{base_url}/api/chat", json=payload, timeout=60.0) as response:
|
||||
if response.status_code != 200:
|
||||
error_detail = await response.aread()
|
||||
logger.error(f"Error from Ollama API: {response.status_code}, {error_detail}")
|
||||
yield f"Error: Failed to get response from language model (Status {response.status_code})"
|
||||
return
|
||||
|
||||
# Variable to accumulate the full response
|
||||
full_response = ""
|
||||
|
||||
# Process the streaming response
|
||||
async for chunk in response.aiter_text():
|
||||
if not chunk.strip():
|
||||
continue
|
||||
|
||||
# Each chunk might contain one JSON object
|
||||
try:
|
||||
data = json.loads(chunk)
|
||||
# Process message content if available
|
||||
if 'message' in data and 'content' in data['message']:
|
||||
content = data['message']['content']
|
||||
full_response += content
|
||||
yield content
|
||||
|
||||
# Check if this is the final message with done flag
|
||||
if data.get('done', False):
|
||||
logger.debug("Streaming response completed")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse streaming response: {e}, chunk: {chunk}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during streaming inference: {str(e)}")
|
||||
yield f"Error: {str(e)}"
|
||||
|
||||
# Return empty string at the end to signal completion
|
||||
yield ""
|
||||
|
||||
def perform_inference(
|
||||
llm,
|
||||
system_message: str,
|
||||
history: List[Dict],
|
||||
current_message: str,
|
||||
stream: bool = False
|
||||
) -> str:
|
||||
"""Perform inference with the given LLM."""
|
||||
if stream:
|
||||
# This will be handled by the streaming endpoint
|
||||
raise ValueError("Streaming not supported in synchronous inference")
|
||||
|
||||
# Prepare messages for the API
|
||||
messages = []
|
||||
|
||||
# Add system message
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": system_message
|
||||
})
|
||||
|
||||
# Add history
|
||||
for msg in history:
|
||||
messages.append({
|
||||
"role": "user" if msg.role == "user" else "assistant",
|
||||
"content": msg.content
|
||||
})
|
||||
|
||||
# Add current user message
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": current_message
|
||||
})
|
||||
|
||||
# For non-streaming, we'll use the httpx client directly to call Ollama API
|
||||
base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")
|
||||
|
||||
# Prepare request payload
|
||||
payload = {
|
||||
"model": llm.model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": llm.temperature
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(f"Sending non-streaming request to Ollama API: {base_url}/api/chat")
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(f"{base_url}/api/chat", json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Error from Ollama API: {response.status_code}, {response.text}")
|
||||
return f"Error: Failed to get response from language model (Status {response.status_code})"
|
||||
|
||||
data = response.json()
|
||||
if 'message' in data and 'content' in data['message']:
|
||||
return data['message']['content']
|
||||
else:
|
||||
logger.error(f"Unexpected response format: {data}")
|
||||
return "Error: Unexpected response format from language model"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during non-streaming inference: {str(e)}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def select_llm(llm, llm_reasoning, reasoning: bool):
|
||||
"""Select the LLM model based on the reasoning flag."""
|
||||
selected_llm = llm_reasoning if reasoning else llm
|
||||
return selected_llm
|
Loading…
Add table
Add a link
Reference in a new issue