non-building checkpoint 1
backend/app/services/ai_service.py (new file, 337 lines)
@@ -0,0 +1,337 @@
"""
AI service for handling OpenAI/DeepSeek API calls.
"""

import json
from typing import List, Dict, Any, Optional
from openai import OpenAI, AsyncOpenAI

from app.core.config import settings
from app.core.logging import setup_logging

logger = setup_logging()


class AIService:
    """Service for handling AI API calls."""

    def __init__(self):
        """Initialize AI service."""
        api_key = settings.DEEPSEEK_API_KEY or settings.OPENAI_API_KEY
        if not api_key:
            raise ValueError("No API key found. Set DEEPSEEK_API_KEY or OPENAI_API_KEY in environment.")

        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=settings.API_BASE_URL
        )
        self.model = settings.MODEL

    def _clean_ai_response(self, response_content: str) -> str:
        """
        Clean up AI response content to handle common formatting issues.

        Handles:
        1. Leading/trailing backticks (```json ... ```)
        2. Leading "json" string on its own line
        3. Extra whitespace and newlines
        """
        content = response_content.strip()

        # Remove leading/trailing backticks (```json ... ```)
        if content.startswith('```'):
            lines = content.split('\n')
            if len(lines) > 1:
                first_line = lines[0].strip()
                if 'json' in first_line.lower() or first_line == '```':
                    content = '\n'.join(lines[1:])

        # Remove trailing backticks if present
        if content.endswith('```'):
            content = content[:-3].rstrip()

        # Remove leading "json" string on its own line (case-insensitive)
        lines = content.split('\n')
        if len(lines) > 0:
            first_line = lines[0].strip().lower()
            if first_line == 'json':
                content = '\n'.join(lines[1:])

        # Also handle the case where "json" might be at the beginning of the first line
        content = content.strip()
        if content.lower().startswith('json\n'):
            content = content[4:].strip()

        return content.strip()

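    # Illustrative note (editor's sketch, not part of the original commit):
    # _clean_ai_response is meant to turn a fenced model reply such as
    #     '```json\n["Write about a door you never opened."]\n```'
    # into the bare JSON string
    #     '["Write about a door you never opened."]'
    # so that json.loads() in the parsers below can consume it directly.
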
    async def generate_prompts(
        self,
        prompt_template: str,
        historic_prompts: List[Dict[str, str]],
        feedback_words: Optional[List[Dict[str, Any]]] = None,
        count: Optional[int] = None,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None
    ) -> List[str]:
        """
        Generate journal prompts using AI.

        Args:
            prompt_template: Base prompt template
            historic_prompts: List of historic prompts for context
            feedback_words: List of feedback words with weights
            count: Number of prompts to generate
            min_length: Minimum prompt length
            max_length: Maximum prompt length

        Returns:
            List of generated prompts
        """
        if count is None:
            count = settings.NUM_PROMPTS_PER_SESSION
        if min_length is None:
            min_length = settings.MIN_PROMPT_LENGTH
        if max_length is None:
            max_length = settings.MAX_PROMPT_LENGTH

        # Prepare the full prompt
        full_prompt = self._prepare_prompt(
            prompt_template,
            historic_prompts,
            feedback_words,
            count,
            min_length,
            max_length
        )

        logger.info(f"Generating {count} prompts with AI")

        try:
            # Call the AI API
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a creative writing assistant that generates journal prompts. Always respond with valid JSON."
                    },
                    {
                        "role": "user",
                        "content": full_prompt
                    }
                ],
                temperature=0.7,
                max_tokens=2000
            )

            response_content = response.choices[0].message.content
            logger.debug(f"AI response received: {len(response_content)} characters")

            # Parse the response
            prompts = self._parse_prompt_response(response_content, count)
            logger.info(f"Successfully parsed {len(prompts)} prompts from AI response")

            return prompts

        except Exception as e:
            logger.error(f"Error calling AI API: {e}")
            logger.debug(f"Full prompt sent to API: {full_prompt[:500]}...")
            raise

    def _prepare_prompt(
        self,
        template: str,
        historic_prompts: List[Dict[str, str]],
        feedback_words: Optional[List[Dict[str, Any]]],
        count: int,
        min_length: int,
        max_length: int
    ) -> str:
        """Prepare the full prompt with all context."""
        # Add the instruction for the specific number of prompts
        prompt_instruction = f"Please generate {count} writing prompts, each between {min_length} and {max_length} characters."

        # Start with template and instruction
        full_prompt = f"{template}\n\n{prompt_instruction}"

        # Add historic prompts if available
        if historic_prompts:
            historic_context = json.dumps(historic_prompts, indent=2)
            full_prompt = f"{full_prompt}\n\nPrevious prompts:\n{historic_context}"

        # Add feedback words if available
        if feedback_words:
            feedback_context = json.dumps(feedback_words, indent=2)
            full_prompt = f"{full_prompt}\n\nFeedback words:\n{feedback_context}"

        return full_prompt

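    # Illustrative note (editor's sketch, not part of the original commit):
    # for template "Suggest journal prompts.", count=2, min_length=80 and
    # max_length=200, with no history or feedback, _prepare_prompt returns roughly:
    #
    #     Suggest journal prompts.
    #
    #     Please generate 2 writing prompts, each between 80 and 200 characters.
    #
    # History and feedback, when present, are appended as indented JSON under the
    # "Previous prompts:" and "Feedback words:" headings.
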
    def _parse_prompt_response(self, response_content: str, expected_count: int) -> List[str]:
        """Parse AI response to extract prompts."""
        cleaned_content = self._clean_ai_response(response_content)

        try:
            data = json.loads(cleaned_content)

            if isinstance(data, list):
                if len(data) >= expected_count:
                    return data[:expected_count]
                else:
                    logger.warning(f"AI returned {len(data)} prompts, expected {expected_count}")
                    return data
            elif isinstance(data, dict):
                logger.warning("AI returned dictionary format, expected list format")
                prompts = []
                for i in range(expected_count):
                    key = f"newprompt{i}"
                    if key in data:
                        prompts.append(data[key])
                return prompts
            else:
                logger.warning(f"AI returned unexpected data type: {type(data)}")
                return []

        except json.JSONDecodeError:
            logger.warning("AI response is not valid JSON, attempting to extract prompts...")
            return self._extract_prompts_from_text(response_content, expected_count)

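    # Illustrative note (editor's sketch, not part of the original commit):
    # _parse_prompt_response accepts either of these response shapes:
    #     ["First prompt ...", "Second prompt ..."]                 # preferred: a JSON list
    #     {"newprompt0": "First prompt ...", "newprompt1": "..."}   # tolerated: dict keyed newprompt0..N
    # Anything else falls through to the plain-text extraction below.
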
    def _extract_prompts_from_text(self, text: str, expected_count: int) -> List[str]:
        """Extract prompts from plain text response."""
        lines = text.strip().split('\n')
        prompts = []

        # Scan every line (not just the first expected_count) so blank lines or
        # headings in the response do not reduce the number of prompts returned.
        for line in lines:
            line = line.strip()
            if line and len(line) > 50:  # Reasonable minimum length for a prompt
                prompts.append(line)
            if len(prompts) >= expected_count:
                break

        return prompts

    async def generate_theme_feedback_words(
        self,
        feedback_template: str,
        historic_prompts: List[Dict[str, str]],
        current_feedback_words: Optional[List[Dict[str, Any]]] = None,
        historic_feedback_words: Optional[List[Dict[str, str]]] = None
    ) -> List[str]:
        """
        Generate theme feedback words using AI.

        Args:
            feedback_template: Feedback analysis template
            historic_prompts: List of historic prompts for context
            current_feedback_words: Current feedback words with weights
            historic_feedback_words: Historic feedback words (just words)

        Returns:
            List of 6 theme words
        """
        # Prepare the full prompt
        full_prompt = self._prepare_feedback_prompt(
            feedback_template,
            historic_prompts,
            current_feedback_words,
            historic_feedback_words
        )

        logger.info("Generating theme feedback words with AI")

        try:
            # Call the AI API
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a creative writing assistant that analyzes writing prompts. Always respond with valid JSON."
                    },
                    {
                        "role": "user",
                        "content": full_prompt
                    }
                ],
                temperature=0.7,
                max_tokens=1000
            )

            response_content = response.choices[0].message.content
            logger.debug(f"AI feedback response received: {len(response_content)} characters")

            # Parse the response
            theme_words = self._parse_feedback_response(response_content)
            logger.info(f"Successfully parsed {len(theme_words)} theme words from AI response")

            if len(theme_words) != 6:
                logger.warning(f"Expected 6 theme words, got {len(theme_words)}")

            return theme_words

        except Exception as e:
            logger.error(f"Error calling AI API for feedback: {e}")
            logger.debug(f"Full feedback prompt sent to API: {full_prompt[:500]}...")
            raise

    def _prepare_feedback_prompt(
        self,
        template: str,
        historic_prompts: List[Dict[str, str]],
        current_feedback_words: Optional[List[Dict[str, Any]]],
        historic_feedback_words: Optional[List[Dict[str, str]]]
    ) -> str:
        """Prepare the full feedback prompt."""
        if not historic_prompts:
            raise ValueError("No historic prompts available for feedback analysis")

        full_prompt = f"{template}\n\nPrevious prompts:\n{json.dumps(historic_prompts, indent=2)}"

        # Add current feedback words if available
        if current_feedback_words:
            feedback_context = json.dumps(current_feedback_words, indent=2)
            full_prompt = f"{full_prompt}\n\nCurrent feedback themes (with weights):\n{feedback_context}"

        # Add historic feedback words if available
        if historic_feedback_words:
            feedback_historic_context = json.dumps(historic_feedback_words, indent=2)
            full_prompt = f"{full_prompt}\n\nHistoric feedback themes (just words):\n{feedback_historic_context}"

        return full_prompt

    def _parse_feedback_response(self, response_content: str) -> List[str]:
        """Parse AI response to extract theme words."""
        cleaned_content = self._clean_ai_response(response_content)

        try:
            data = json.loads(cleaned_content)

            if isinstance(data, list):
                theme_words = []
                for word in data:
                    if isinstance(word, str):
                        theme_words.append(word.lower().strip())
                    else:
                        theme_words.append(str(word).lower().strip())
                return theme_words
            else:
                logger.warning(f"AI returned unexpected data type for feedback: {type(data)}")
                return []

        except json.JSONDecodeError:
            logger.warning("AI feedback response is not valid JSON, attempting to extract theme words...")
            return self._extract_theme_words_from_text(response_content)

    def _extract_theme_words_from_text(self, text: str) -> List[str]:
        """Extract theme words from plain text response."""
        lines = text.strip().split('\n')
        theme_words = []

        for line in lines:
            line = line.strip()
            if line and len(line) < 50:  # Theme words should be short
                words = [w.lower().strip('.,;:!?()[]{}\"\'') for w in line.split()]
                theme_words.extend(words)

            if len(theme_words) >= 6:
                break

        return theme_words[:6]

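A minimal usage sketch for this service (editor's illustration, not part of the commit; it assumes the API key, model, and base URL referenced above are configured in app.core.config.settings):

import asyncio

from app.services.ai_service import AIService

async def main():
    ai = AIService()  # raises ValueError if neither DEEPSEEK_API_KEY nor OPENAI_API_KEY is set
    prompts = await ai.generate_prompts(
        prompt_template="Suggest reflective journal prompts.",
        historic_prompts=[],   # no prior prompts for context in this sketch
        feedback_words=None,
        count=3,
    )
    for p in prompts:
        print(p)

asyncio.run(main())
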
backend/app/services/data_service.py (new file, 187 lines)
@@ -0,0 +1,187 @@
"""
Data service for handling JSON file operations.
"""

import json
import os
import aiofiles
from typing import Any, List, Dict, Optional
from pathlib import Path

from app.core.config import settings
from app.core.logging import setup_logging

logger = setup_logging()


class DataService:
    """Service for handling data persistence in JSON files."""

    def __init__(self):
        """Initialize data service."""
        self.data_dir = Path(settings.DATA_DIR)
        self.data_dir.mkdir(exist_ok=True)

    def _get_file_path(self, filename: str) -> Path:
        """Get full path for a data file."""
        return self.data_dir / filename

    async def load_json(self, filename: str, default: Any = None) -> Any:
        """
        Load JSON data from file.

        Args:
            filename: Name of the JSON file
            default: Default value if file doesn't exist or is invalid

        Returns:
            Loaded data or default value
        """
        file_path = self._get_file_path(filename)

        if not file_path.exists():
            logger.warning(f"File {filename} not found, returning default")
            return default if default is not None else []

        try:
            async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
                content = await f.read()
                return json.loads(content)
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON from {filename}: {e}")
            return default if default is not None else []
        except Exception as e:
            logger.error(f"Error loading {filename}: {e}")
            return default if default is not None else []

    async def save_json(self, filename: str, data: Any) -> bool:
        """
        Save data to JSON file.

        Args:
            filename: Name of the JSON file
            data: Data to save

        Returns:
            True if successful, False otherwise
        """
        file_path = self._get_file_path(filename)

        try:
            # Create backup of existing file if it exists
            if file_path.exists():
                backup_path = file_path.with_suffix('.json.bak')
                async with aiofiles.open(file_path, 'r', encoding='utf-8') as src:
                    async with aiofiles.open(backup_path, 'w', encoding='utf-8') as dst:
                        await dst.write(await src.read())

            # Save new data
            async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
                await f.write(json.dumps(data, indent=2, ensure_ascii=False))

            logger.info(f"Saved data to {filename}")
            return True
        except Exception as e:
            logger.error(f"Error saving {filename}: {e}")
            return False

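    # Illustrative note (editor's sketch, not part of the original commit; the
    # filename is made up): save_json("prompts_pool.json", data) first copies any
    # existing prompts_pool.json to prompts_pool.json.bak and then overwrites the
    # original, so one previous version stays recoverable. load_json falls back to
    # `default` (or []) when the file is missing or holds invalid JSON.
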
    async def load_prompts_historic(self) -> List[Dict[str, str]]:
        """Load historic prompts from JSON file."""
        return await self.load_json(
            settings.PROMPTS_HISTORIC_FILE,
            default=[]
        )

    async def save_prompts_historic(self, prompts: List[Dict[str, str]]) -> bool:
        """Save historic prompts to JSON file."""
        return await self.save_json(settings.PROMPTS_HISTORIC_FILE, prompts)

    async def load_prompts_pool(self) -> List[str]:
        """Load prompt pool from JSON file."""
        return await self.load_json(
            settings.PROMPTS_POOL_FILE,
            default=[]
        )

    async def save_prompts_pool(self, prompts: List[str]) -> bool:
        """Save prompt pool to JSON file."""
        return await self.save_json(settings.PROMPTS_POOL_FILE, prompts)

    async def load_feedback_words(self) -> List[Dict[str, Any]]:
        """Load feedback words from JSON file."""
        return await self.load_json(
            settings.FEEDBACK_WORDS_FILE,
            default=[]
        )

    async def save_feedback_words(self, feedback_words: List[Dict[str, Any]]) -> bool:
        """Save feedback words to JSON file."""
        return await self.save_json(settings.FEEDBACK_WORDS_FILE, feedback_words)

    async def load_feedback_historic(self) -> List[Dict[str, str]]:
        """Load historic feedback words from JSON file."""
        return await self.load_json(
            settings.FEEDBACK_HISTORIC_FILE,
            default=[]
        )

    async def save_feedback_historic(self, feedback_words: List[Dict[str, str]]) -> bool:
        """Save historic feedback words to JSON file."""
        return await self.save_json(settings.FEEDBACK_HISTORIC_FILE, feedback_words)

    async def load_prompt_template(self) -> str:
        """Load prompt template from file."""
        template_path = Path(settings.PROMPT_TEMPLATE_PATH)
        if not template_path.exists():
            logger.error(f"Prompt template not found at {template_path}")
            return ""

        try:
            async with aiofiles.open(template_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except Exception as e:
            logger.error(f"Error loading prompt template: {e}")
            return ""

    async def load_feedback_template(self) -> str:
        """Load feedback template from file."""
        template_path = Path(settings.FEEDBACK_TEMPLATE_PATH)
        if not template_path.exists():
            logger.error(f"Feedback template not found at {template_path}")
            return ""

        try:
            async with aiofiles.open(template_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except Exception as e:
            logger.error(f"Error loading feedback template: {e}")
            return ""

    async def load_settings_config(self) -> Dict[str, Any]:
        """Load settings from config file."""
        config_path = Path(settings.SETTINGS_CONFIG_PATH)
        if not config_path.exists():
            logger.warning(f"Settings config not found at {config_path}")
            return {}

        try:
            import configparser
            config = configparser.ConfigParser()
            config.read(config_path)

            settings_dict = {}
            if 'prompts' in config:
                prompts_section = config['prompts']
                settings_dict['min_length'] = int(prompts_section.get('min_length', settings.MIN_PROMPT_LENGTH))
                settings_dict['max_length'] = int(prompts_section.get('max_length', settings.MAX_PROMPT_LENGTH))
                settings_dict['num_prompts'] = int(prompts_section.get('num_prompts', settings.NUM_PROMPTS_PER_SESSION))

            if 'prefetch' in config:
                prefetch_section = config['prefetch']
                settings_dict['cached_pool_volume'] = int(prefetch_section.get('cached_pool_volume', settings.CACHED_POOL_VOLUME))

            return settings_dict
        except Exception as e:
            logger.error(f"Error loading settings config: {e}")
            return {}

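For reference, a small self-contained sketch (editor's illustration, not part of the commit) of the INI shape load_settings_config expects; the section and key names come from the code above, the values are made up:

import configparser

sample = """
[prompts]
min_length = 80
max_length = 240
num_prompts = 3

[prefetch]
cached_pool_volume = 30
"""

config = configparser.ConfigParser()
config.read_string(sample)
print(int(config['prompts']['num_prompts']))          # 3
print(int(config['prefetch']['cached_pool_volume']))  # 30
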
backend/app/services/prompt_service.py (new file, 416 lines)
@@ -0,0 +1,416 @@
"""
Main prompt service that orchestrates prompt generation and management.
"""

from typing import List, Dict, Any, Optional
from datetime import datetime

from app.core.config import settings
from app.core.logging import setup_logging
from app.services.data_service import DataService
from app.services.ai_service import AIService
from app.models.prompt import (
    PromptResponse,
    PoolStatsResponse,
    HistoryStatsResponse,
    FeedbackWord,
    FeedbackHistoryItem
)

logger = setup_logging()


class PromptService:
    """Main service for prompt generation and management."""

    def __init__(self):
        """Initialize prompt service with dependencies."""
        self.data_service = DataService()
        self.ai_service = AIService()

        # Load settings from config file
        self.settings_config = {}

        # Cache for loaded data
        self._prompts_historic_cache = None
        self._prompts_pool_cache = None
        self._feedback_words_cache = None
        self._feedback_historic_cache = None
        self._prompt_template_cache = None
        self._feedback_template_cache = None

    async def _load_settings_config(self):
        """Load settings from config file if not already loaded."""
        if not self.settings_config:
            self.settings_config = await self.data_service.load_settings_config()

    async def _get_setting(self, key: str, default: Any) -> Any:
        """Get setting value, preferring config file over environment."""
        await self._load_settings_config()
        return self.settings_config.get(key, default)

    # Data loading methods with caching
    async def get_prompts_historic(self) -> List[Dict[str, str]]:
        """Get historic prompts with caching."""
        if self._prompts_historic_cache is None:
            self._prompts_historic_cache = await self.data_service.load_prompts_historic()
        return self._prompts_historic_cache

    async def get_prompts_pool(self) -> List[str]:
        """Get prompt pool with caching."""
        if self._prompts_pool_cache is None:
            self._prompts_pool_cache = await self.data_service.load_prompts_pool()
        return self._prompts_pool_cache

    async def get_feedback_words(self) -> List[Dict[str, Any]]:
        """Get feedback words with caching."""
        if self._feedback_words_cache is None:
            self._feedback_words_cache = await self.data_service.load_feedback_words()
        return self._feedback_words_cache

    async def get_feedback_historic(self) -> List[Dict[str, str]]:
        """Get historic feedback words with caching."""
        if self._feedback_historic_cache is None:
            self._feedback_historic_cache = await self.data_service.load_feedback_historic()
        return self._feedback_historic_cache

    async def get_prompt_template(self) -> str:
        """Get prompt template with caching."""
        if self._prompt_template_cache is None:
            self._prompt_template_cache = await self.data_service.load_prompt_template()
        return self._prompt_template_cache

    async def get_feedback_template(self) -> str:
        """Get feedback template with caching."""
        if self._feedback_template_cache is None:
            self._feedback_template_cache = await self.data_service.load_feedback_template()
        return self._feedback_template_cache

    # Core prompt operations
    async def draw_prompts_from_pool(self, count: Optional[int] = None) -> List[str]:
        """
        Draw prompts from the pool.

        Args:
            count: Number of prompts to draw

        Returns:
            List of drawn prompts
        """
        if count is None:
            count = await self._get_setting('num_prompts', settings.NUM_PROMPTS_PER_SESSION)

        pool = await self.get_prompts_pool()

        if len(pool) < count:
            raise ValueError(
                f"Pool only has {len(pool)} prompts, requested {count}. "
                f"Use fill-pool endpoint to add more prompts."
            )

        # Draw prompts from the beginning of the pool
        drawn_prompts = pool[:count]
        remaining_pool = pool[count:]

        # Update cache and save
        self._prompts_pool_cache = remaining_pool
        await self.data_service.save_prompts_pool(remaining_pool)

        logger.info(f"Drew {len(drawn_prompts)} prompts from pool, {len(remaining_pool)} remaining")
        return drawn_prompts

    async def fill_pool_to_target(self) -> int:
        """
        Fill the prompt pool to target volume.

        Returns:
            Number of prompts added
        """
        target_volume = await self._get_setting('cached_pool_volume', settings.CACHED_POOL_VOLUME)
        current_pool = await self.get_prompts_pool()
        current_size = len(current_pool)

        if current_size >= target_volume:
            logger.info(f"Pool already at target volume: {current_size}/{target_volume}")
            return 0

        prompts_needed = target_volume - current_size
        logger.info(f"Generating {prompts_needed} prompts to fill pool")

        # Generate prompts
        new_prompts = await self.generate_prompts(
            count=prompts_needed,
            use_history=True,
            use_feedback=True
        )

        if not new_prompts:
            logger.error("Failed to generate prompts for pool")
            return 0

        # Add to pool
        updated_pool = current_pool + new_prompts
        self._prompts_pool_cache = updated_pool
        await self.data_service.save_prompts_pool(updated_pool)

        added_count = len(new_prompts)
        logger.info(f"Added {added_count} prompts to pool, new size: {len(updated_pool)}")
        return added_count

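    # Illustrative note (editor's sketch, not part of the original commit):
    # a typical session flow against this service would be
    #     session_prompts = await service.draw_prompts_from_pool()   # FIFO draw from the front of the pool
    #     added = await service.fill_pool_to_target()                # top the pool back up to cached_pool_volume
    # where `service` is an existing PromptService instance.
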
    async def generate_prompts(
        self,
        count: Optional[int] = None,
        use_history: bool = True,
        use_feedback: bool = True
    ) -> List[str]:
        """
        Generate new prompts using AI.

        Args:
            count: Number of prompts to generate
            use_history: Whether to use historic prompts as context
            use_feedback: Whether to use feedback words as context

        Returns:
            List of generated prompts
        """
        if count is None:
            count = await self._get_setting('num_prompts', settings.NUM_PROMPTS_PER_SESSION)

        min_length = await self._get_setting('min_length', settings.MIN_PROMPT_LENGTH)
        max_length = await self._get_setting('max_length', settings.MAX_PROMPT_LENGTH)

        # Load templates and data
        prompt_template = await self.get_prompt_template()
        if not prompt_template:
            raise ValueError("Prompt template not found")

        historic_prompts = await self.get_prompts_historic() if use_history else []
        feedback_words = await self.get_feedback_words() if use_feedback else None

        # Generate prompts using AI
        new_prompts = await self.ai_service.generate_prompts(
            prompt_template=prompt_template,
            historic_prompts=historic_prompts,
            feedback_words=feedback_words,
            count=count,
            min_length=min_length,
            max_length=max_length
        )

        return new_prompts

    async def add_prompt_to_history(self, prompt_text: str) -> str:
        """
        Add a prompt to the historic prompts cyclic buffer.

        Args:
            prompt_text: Prompt text to add

        Returns:
            Position key of the added prompt (e.g., "prompt00")
        """
        historic_prompts = await self.get_prompts_historic()

        # Create the new prompt object
        new_prompt = {"prompt00": prompt_text}

        # Shift all existing prompts down by one position
        updated_prompts = [new_prompt]

        # Add all existing prompts, shifting their numbers down by one
        for i, prompt_dict in enumerate(historic_prompts):
            if i >= settings.HISTORY_BUFFER_SIZE - 1:  # Keep only HISTORY_BUFFER_SIZE prompts
                break

            # Get the prompt text (separate name so the prompt_text argument is not shadowed)
            prompt_key = list(prompt_dict.keys())[0]
            existing_text = prompt_dict[prompt_key]

            # Create prompt with new number (shifted down by one)
            new_prompt_key = f"prompt{i+1:02d}"
            updated_prompts.append({new_prompt_key: existing_text})

        # Update cache and save
        self._prompts_historic_cache = updated_prompts
        await self.data_service.save_prompts_historic(updated_prompts)

        logger.info(f"Added prompt to history as prompt00, history size: {len(updated_prompts)}")
        return "prompt00"

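    # Illustrative note (editor's sketch, not part of the original commit):
    # with HISTORY_BUFFER_SIZE = 3 (a made-up value), adding "C" to a history
    # holding "B" (newest) and "A" yields
    #     [{"prompt00": "C"}, {"prompt01": "B"}, {"prompt02": "A"}]
    # i.e. the newest prompt is always prompt00 and older entries shift toward
    # the end until they fall off the buffer.
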
    # Statistics methods
    async def get_pool_stats(self) -> PoolStatsResponse:
        """Get statistics about the prompt pool."""
        pool = await self.get_prompts_pool()
        total_prompts = len(pool)

        prompts_per_session = await self._get_setting('num_prompts', settings.NUM_PROMPTS_PER_SESSION)
        target_pool_size = await self._get_setting('cached_pool_volume', settings.CACHED_POOL_VOLUME)

        available_sessions = total_prompts // prompts_per_session if prompts_per_session > 0 else 0
        needs_refill = total_prompts < target_pool_size

        return PoolStatsResponse(
            total_prompts=total_prompts,
            prompts_per_session=prompts_per_session,
            target_pool_size=target_pool_size,
            available_sessions=available_sessions,
            needs_refill=needs_refill
        )

    async def get_history_stats(self) -> HistoryStatsResponse:
        """Get statistics about prompt history."""
        historic_prompts = await self.get_prompts_historic()
        total_prompts = len(historic_prompts)

        history_capacity = settings.HISTORY_BUFFER_SIZE
        available_slots = max(0, history_capacity - total_prompts)
        is_full = total_prompts >= history_capacity

        return HistoryStatsResponse(
            total_prompts=total_prompts,
            history_capacity=history_capacity,
            available_slots=available_slots,
            is_full=is_full
        )

    async def get_prompt_history(self, limit: Optional[int] = None) -> List[PromptResponse]:
        """
        Get prompt history.

        Args:
            limit: Maximum number of history items to return

        Returns:
            List of historical prompts
        """
        historic_prompts = await self.get_prompts_historic()

        if limit is not None:
            historic_prompts = historic_prompts[:limit]

        prompts = []
        for i, prompt_dict in enumerate(historic_prompts):
            prompt_key = list(prompt_dict.keys())[0]
            prompt_text = prompt_dict[prompt_key]

            prompts.append(PromptResponse(
                key=prompt_key,
                text=prompt_text,
                position=i
            ))

        return prompts

    # Feedback operations
    async def generate_theme_feedback_words(self) -> List[str]:
        """Generate 6 theme feedback words using AI."""
        feedback_template = await self.get_feedback_template()
        if not feedback_template:
            raise ValueError("Feedback template not found")

        historic_prompts = await self.get_prompts_historic()
        if not historic_prompts:
            raise ValueError("No historic prompts available for feedback analysis")

        current_feedback_words = await self.get_feedback_words()
        historic_feedback_words = await self.get_feedback_historic()

        theme_words = await self.ai_service.generate_theme_feedback_words(
            feedback_template=feedback_template,
            historic_prompts=historic_prompts,
            current_feedback_words=current_feedback_words,
            historic_feedback_words=historic_feedback_words
        )

        return theme_words

    async def update_feedback_words(self, ratings: Dict[str, int]) -> List[FeedbackWord]:
        """
        Update feedback words with new ratings.

        Args:
            ratings: Dictionary of word to rating (0-6)

        Returns:
            Updated feedback words
        """
        if len(ratings) != 6:
            raise ValueError(f"Expected 6 ratings, got {len(ratings)}")

        feedback_items = []
        for i, (word, rating) in enumerate(ratings.items()):
            if not 0 <= rating <= 6:
                raise ValueError(f"Rating for '{word}' must be between 0 and 6, got {rating}")

            feedback_key = f"feedback{i:02d}"
            feedback_items.append({
                feedback_key: word,
                "weight": rating
            })

        # Update cache and save
        self._feedback_words_cache = feedback_items
        await self.data_service.save_feedback_words(feedback_items)

        # Also add to historic feedback
        await self._add_feedback_words_to_history(feedback_items)

        # Convert to FeedbackWord models
        feedback_words = []
        for item in feedback_items:
            key = list(item.keys())[0]
            word = item[key]
            weight = item["weight"]
            feedback_words.append(FeedbackWord(key=key, word=word, weight=weight))

        logger.info(f"Updated feedback words with {len(feedback_words)} items")
        return feedback_words

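    # Illustrative note (editor's sketch, not part of the original commit; the
    # words and ratings are made up):
    #     update_feedback_words({"memory": 5, "travel": 2, "family": 6,
    #                            "work": 0, "nature": 4, "change": 3})
    # stores items shaped like {"feedback00": "memory", "weight": 5}, ...,
    # {"feedback05": "change", "weight": 3} and pushes the six words onto the
    # historic feedback buffer handled below.
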
    async def _add_feedback_words_to_history(self, feedback_items: List[Dict[str, Any]]) -> None:
        """Add feedback words to historic buffer."""
        historic_feedback = await self.get_feedback_historic()

        # Extract just the words from current feedback
        new_feedback_words = []
        for i, item in enumerate(feedback_items):
            feedback_key = f"feedback{i:02d}"
            if feedback_key in item:
                word = item[feedback_key]
                new_feedback_words.append({feedback_key: word})

        if len(new_feedback_words) != 6:
            logger.warning(f"Expected 6 feedback words, got {len(new_feedback_words)}. Not adding to history.")
            return

        # Shift all existing feedback words down by 6 positions
        updated_feedback_historic = new_feedback_words

        # Add all existing feedback words, shifting their numbers down by 6
        for i, feedback_dict in enumerate(historic_feedback):
            if i >= settings.FEEDBACK_HISTORY_SIZE - 6:  # Keep only FEEDBACK_HISTORY_SIZE items
                break

            feedback_key = list(feedback_dict.keys())[0]
            word = feedback_dict[feedback_key]

            new_feedback_key = f"feedback{i+6:02d}"
            updated_feedback_historic.append({new_feedback_key: word})

        # Update cache and save
        self._feedback_historic_cache = updated_feedback_historic
        await self.data_service.save_feedback_historic(updated_feedback_historic)

        logger.info(f"Added 6 feedback words to history, history size: {len(updated_feedback_historic)}")

    # Utility methods for API endpoints
    def get_pool_size(self) -> int:
        """Get current pool size (synchronous for API endpoints)."""
        if self._prompts_pool_cache is None:
            raise RuntimeError("Pool cache not initialized")
        return len(self._prompts_pool_cache)

    def get_target_volume(self) -> int:
        """Get target pool volume (synchronous for API endpoints)."""
        return settings.CACHED_POOL_VOLUME

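A minimal end-to-end sketch of how the three services fit together (editor's illustration, not part of the commit; it assumes the data files, templates, and API key referenced above are in place):

import asyncio

from app.services.prompt_service import PromptService

async def main():
    service = PromptService()

    # Keep the cached pool topped up, then draw a session's worth of prompts.
    await service.fill_pool_to_target()
    session_prompts = await service.draw_prompts_from_pool()

    # Record the prompt the user actually wrote about.
    await service.add_prompt_to_history(session_prompts[0])

    # Ask the AI for six theme words and store the user's 0-6 ratings for them.
    theme_words = await service.generate_theme_feedback_words()
    ratings = {word: 3 for word in theme_words}  # placeholder ratings
    await service.update_feedback_words(ratings)

asyncio.run(main())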