"""
|
|
Main prompt service that orchestrates prompt generation and management.
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import datetime
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import setup_logging
|
|
from app.services.data_service import DataService
|
|
from app.services.ai_service import AIService
|
|
from app.models.prompt import (
|
|
PromptResponse,
|
|
PoolStatsResponse,
|
|
HistoryStatsResponse,
|
|
FeedbackWord,
|
|
FeedbackHistoryItem
|
|
)
|
|
|
|
logger = setup_logging()
|
|
|
|
|
|


class PromptService:
    """Main service for prompt generation and management."""

    def __init__(self):
        """Initialize prompt service with dependencies."""
        self.data_service = DataService()
        self.ai_service = AIService()

        # Settings loaded lazily from the config file
        self.settings_config = {}

        # Caches for loaded data
        self._prompts_historic_cache = None
        self._prompts_pool_cache = None
        self._feedback_words_cache = None
        self._feedback_historic_cache = None
        self._prompt_template_cache = None
        self._feedback_template_cache = None

    async def _load_settings_config(self):
        """Load settings from the config file if not already loaded."""
        if not self.settings_config:
            self.settings_config = await self.data_service.load_settings_config()

    async def _get_setting(self, key: str, default: Any) -> Any:
        """Get a setting value, preferring the config file over the environment."""
        await self._load_settings_config()
        return self.settings_config.get(key, default)
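
    # Precedence example (hypothetical values): with {"num_prompts": 5} in
    # the config file, `await self._get_setting('num_prompts',
    # settings.NUM_PROMPTS_PER_SESSION)` returns 5; when the key is absent,
    # the environment-backed default wins.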

    # Data loading methods with caching
    async def get_prompts_historic(self) -> List[Dict[str, str]]:
        """Get historic prompts with caching."""
        if self._prompts_historic_cache is None:
            self._prompts_historic_cache = await self.data_service.load_prompts_historic()
        return self._prompts_historic_cache

    async def get_prompts_pool(self) -> List[str]:
        """Get the prompt pool with caching."""
        if self._prompts_pool_cache is None:
            self._prompts_pool_cache = await self.data_service.load_prompts_pool()
        return self._prompts_pool_cache

    async def get_feedback_historic(self) -> List[Dict[str, Any]]:
        """Get historic feedback words with caching."""
        if self._feedback_historic_cache is None:
            self._feedback_historic_cache = await self.data_service.load_feedback_historic()
        return self._feedback_historic_cache

    async def get_feedback_queued_words(self) -> List[Dict[str, Any]]:
        """Get queued feedback words (positions 0-5) awaiting user weighting."""
        feedback_historic = await self.get_feedback_historic()
        # Slicing already handles lists shorter than six items
        return feedback_historic[:6]

    async def get_feedback_active_words(self) -> List[Dict[str, Any]]:
        """Get active feedback words (positions 6-11) used for prompt generation."""
        feedback_historic = await self.get_feedback_historic()
        # Returns an empty or partial slice when fewer than 12 items exist
        return feedback_historic[6:12]
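
    # Illustrative layout of feedback_historic as consumed above (words and
    # weights are hypothetical; the real data comes from DataService):
    #
    #   [
    #       {"feedback00": "moody", "weight": 3},    # positions 0-5: queued,
    #       ...                                      #   awaiting user ratings
    #       {"feedback06": "vibrant", "weight": 5},  # positions 6-11: active,
    #       ...                                      #   fed into generation
    #   ]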

    async def get_prompt_template(self) -> str:
        """Get the prompt template with caching."""
        if self._prompt_template_cache is None:
            self._prompt_template_cache = await self.data_service.load_prompt_template()
        return self._prompt_template_cache

    async def get_feedback_template(self) -> str:
        """Get the feedback template with caching."""
        if self._feedback_template_cache is None:
            self._feedback_template_cache = await self.data_service.load_feedback_template()
        return self._feedback_template_cache

    # Core prompt operations
    async def draw_prompts_from_pool(self, count: Optional[int] = None) -> List[str]:
        """
        Draw prompts from the pool.

        Args:
            count: Number of prompts to draw; defaults to the configured
                prompts-per-session setting.

        Returns:
            List of drawn prompts

        Raises:
            ValueError: If the pool holds fewer prompts than requested.
        """
        if count is None:
            count = await self._get_setting('num_prompts', settings.NUM_PROMPTS_PER_SESSION)

        pool = await self.get_prompts_pool()

        if len(pool) < count:
            raise ValueError(
                f"Pool only has {len(pool)} prompts, requested {count}. "
                f"Use the fill-pool endpoint to add more prompts."
            )

        # Draw prompts from the beginning of the pool
        drawn_prompts = pool[:count]
        remaining_pool = pool[count:]

        # Update cache and persist the shrunken pool
        self._prompts_pool_cache = remaining_pool
        await self.data_service.save_prompts_pool(remaining_pool)

        logger.info(f"Drew {len(drawn_prompts)} prompts from pool, {len(remaining_pool)} remaining")
        return drawn_prompts
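
    # Usage sketch (hypothetical count; drawing is destructive, so the
    # shrunken pool is saved immediately):
    #
    #   drawn = await service.draw_prompts_from_pool(count=3)
    #   # len(drawn) == 3, and the saved pool is 3 prompts shorter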

    async def fill_pool_to_target(self) -> int:
        """
        Fill the prompt pool to the target volume.

        Returns:
            Number of prompts added
        """
        target_volume = await self._get_setting('cached_pool_volume', settings.CACHED_POOL_VOLUME)
        current_pool = await self.get_prompts_pool()
        current_size = len(current_pool)

        if current_size >= target_volume:
            logger.info(f"Pool already at target volume: {current_size}/{target_volume}")
            return 0

        prompts_needed = target_volume - current_size
        logger.info(f"Generating {prompts_needed} prompts to fill pool")

        # Generate prompts
        new_prompts = await self.generate_prompts(
            count=prompts_needed,
            use_history=True,
            use_feedback=True
        )

        if not new_prompts:
            logger.error("Failed to generate prompts for pool")
            return 0

        # Add to pool
        updated_pool = current_pool + new_prompts
        self._prompts_pool_cache = updated_pool
        await self.data_service.save_prompts_pool(updated_pool)

        added_count = len(new_prompts)
        logger.info(f"Added {added_count} prompts to pool, new size: {len(updated_pool)}")
        return added_count
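
    # Worked example (hypothetical numbers): with cached_pool_volume = 20 and
    # 14 prompts currently cached, prompts_needed = 20 - 14 = 6, so six new
    # prompts are generated and appended; a pool already at or above 20
    # returns 0 without calling the AI service.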

    async def generate_prompts(
        self,
        count: Optional[int] = None,
        use_history: bool = True,
        use_feedback: bool = True
    ) -> List[str]:
        """
        Generate new prompts using AI.

        Args:
            count: Number of prompts to generate; defaults to the configured
                prompts-per-session setting.
            use_history: Whether to use historic prompts as context
            use_feedback: Whether to use feedback words as context

        Returns:
            List of generated prompts
        """
        if count is None:
            count = await self._get_setting('num_prompts', settings.NUM_PROMPTS_PER_SESSION)

        min_length = await self._get_setting('min_length', settings.MIN_PROMPT_LENGTH)
        max_length = await self._get_setting('max_length', settings.MAX_PROMPT_LENGTH)

        # Load templates and data
        prompt_template = await self.get_prompt_template()
        if not prompt_template:
            raise ValueError("Prompt template not found")

        historic_prompts = await self.get_prompts_historic() if use_history else []
        feedback_words = await self.get_feedback_active_words() if use_feedback else None

        # Filter out feedback words with weight 0
        if feedback_words:
            feedback_words = [
                word for word in feedback_words
                if word.get("weight", 3) != 0  # Default weight is 3 if not specified
            ]
            # If every word had weight 0, pass None instead of an empty list
            if not feedback_words:
                feedback_words = None

        # Generate prompts using AI
        new_prompts = await self.ai_service.generate_prompts(
            prompt_template=prompt_template,
            historic_prompts=historic_prompts,
            feedback_words=feedback_words,
            count=count,
            min_length=min_length,
            max_length=max_length
        )

        return new_prompts
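
    # Filter sketch (hypothetical words): given active words
    #   [{"feedback06": "muted", "weight": 0}, {"feedback07": "warm", "weight": 4}]
    # only the "warm" entry reaches the AI service; if every entry had
    # weight 0, feedback_words would be passed as None.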

    async def add_prompt_to_history(self, prompt_text: str) -> str:
        """
        Add a prompt to the historic prompts cyclic buffer.

        Args:
            prompt_text: Prompt text to add

        Returns:
            Position key of the added prompt (e.g., "prompt00")
        """
        historic_prompts = await self.get_prompts_historic()

        # The new prompt always enters at position 0
        new_prompt = {"prompt00": prompt_text}

        # Shift all existing prompts down by one position
        updated_prompts = [new_prompt]

        for i, prompt_dict in enumerate(historic_prompts):
            if i >= settings.HISTORY_BUFFER_SIZE - 1:  # Keep only HISTORY_BUFFER_SIZE prompts
                break

            # Get the existing prompt text (named to avoid shadowing the parameter)
            prompt_key = list(prompt_dict.keys())[0]
            existing_text = prompt_dict[prompt_key]

            # Re-key the prompt one position further down the buffer
            new_prompt_key = f"prompt{i + 1:02d}"
            updated_prompts.append({new_prompt_key: existing_text})

        # Update cache and save
        self._prompts_historic_cache = updated_prompts
        await self.data_service.save_prompts_historic(updated_prompts)

        logger.info(f"Added prompt to history as prompt00, history size: {len(updated_prompts)}")
        return "prompt00"
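
    # Shift example (hypothetical texts): given history
    #   [{"prompt00": "a"}, {"prompt01": "b"}]
    # add_prompt_to_history("c") produces
    #   [{"prompt00": "c"}, {"prompt01": "a"}, {"prompt02": "b"}]
    # and once the buffer reaches HISTORY_BUFFER_SIZE, the oldest entry drops.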

    # Statistics methods
    async def get_pool_stats(self) -> PoolStatsResponse:
        """Get statistics about the prompt pool."""
        pool = await self.get_prompts_pool()
        total_prompts = len(pool)

        prompts_per_session = await self._get_setting('num_prompts', settings.NUM_PROMPTS_PER_SESSION)
        target_pool_size = await self._get_setting('cached_pool_volume', settings.CACHED_POOL_VOLUME)

        available_sessions = total_prompts // prompts_per_session if prompts_per_session > 0 else 0
        needs_refill = total_prompts < target_pool_size

        return PoolStatsResponse(
            total_prompts=total_prompts,
            prompts_per_session=prompts_per_session,
            target_pool_size=target_pool_size,
            available_sessions=available_sessions,
            needs_refill=needs_refill
        )

    async def get_history_stats(self) -> HistoryStatsResponse:
        """Get statistics about the prompt history."""
        historic_prompts = await self.get_prompts_historic()
        total_prompts = len(historic_prompts)

        history_capacity = settings.HISTORY_BUFFER_SIZE
        available_slots = max(0, history_capacity - total_prompts)
        is_full = total_prompts >= history_capacity

        return HistoryStatsResponse(
            total_prompts=total_prompts,
            history_capacity=history_capacity,
            available_slots=available_slots,
            is_full=is_full
        )

    async def get_prompt_history(self, limit: Optional[int] = None) -> List[PromptResponse]:
        """
        Get prompt history.

        Args:
            limit: Maximum number of history items to return

        Returns:
            List of historical prompts
        """
        historic_prompts = await self.get_prompts_historic()

        if limit is not None:
            historic_prompts = historic_prompts[:limit]

        prompts = []
        for i, prompt_dict in enumerate(historic_prompts):
            prompt_key = list(prompt_dict.keys())[0]
            prompt_text = prompt_dict[prompt_key]

            prompts.append(PromptResponse(
                key=prompt_key,
                text=prompt_text,
                position=i
            ))

        return prompts

    # Feedback operations
    async def generate_theme_feedback_words(self) -> List[str]:
        """Generate 6 theme feedback words using AI."""
        feedback_template = await self.get_feedback_template()
        if not feedback_template:
            raise ValueError("Feedback template not found")

        historic_prompts = await self.get_prompts_historic()
        if not historic_prompts:
            raise ValueError("No historic prompts available for feedback analysis")

        queued_feedback_words = await self.get_feedback_queued_words()
        historic_feedback_words = await self.get_feedback_historic()

        theme_words = await self.ai_service.generate_theme_feedback_words(
            feedback_template=feedback_template,
            historic_prompts=historic_prompts,
            queued_feedback_words=queued_feedback_words,
            historic_feedback_words=historic_feedback_words
        )

        return theme_words

    async def update_feedback_words(self, ratings: Dict[str, int]) -> List[FeedbackWord]:
        """
        Update feedback words with new ratings.

        Args:
            ratings: Mapping of word to rating (0-6). Must contain exactly six
                entries whose insertion order matches queued positions 0-5.

        Returns:
            Updated queued feedback words (the newly generated batch)
        """
        if len(ratings) != 6:
            raise ValueError(f"Expected 6 ratings, got {len(ratings)}")

        # Get the current feedback history
        feedback_historic = await self.get_feedback_historic()

        # Update weights for queued words (positions 0-5); dicts preserve
        # insertion order, so the i-th rating maps to the i-th queued word
        for i, (word, rating) in enumerate(ratings.items()):
            if not 0 <= rating <= 6:
                raise ValueError(f"Rating for '{word}' must be between 0 and 6, got {rating}")

            if i < len(feedback_historic):
                # Reuse the existing feedback key (any key other than "weight")
                existing_item = feedback_historic[i]
                existing_keys = [k for k in existing_item.keys() if k != "weight"]
                if existing_keys:
                    existing_key = existing_keys[0]
                else:
                    # Fall back to generating a key from the position
                    existing_key = f"feedback{i:02d}"

                # Replace the item: existing key, rated word, new weight
                feedback_historic[i] = {
                    existing_key: word,
                    "weight": rating
                }
            else:
                # Not enough items yet; append a new one
                feedback_key = f"feedback{i:02d}"
                feedback_historic.append({
                    feedback_key: word,
                    "weight": rating
                })

        # Update cache and save
        self._feedback_historic_cache = feedback_historic
        await self.data_service.save_feedback_historic(feedback_historic)

        # Generate new feedback words and insert them at position 0; the
        # helper replaces the cached history with a re-keyed list
        await self._generate_and_insert_new_feedback_words(feedback_historic)

        # Re-read from the refreshed cache so the response reflects the newly
        # queued words at positions 0-5 (the local list above does not see
        # the insertion)
        updated_history = await self.get_feedback_historic()
        updated_queued_words = updated_history[:6]

        # Convert to FeedbackWord models
        feedback_words = []
        for item in updated_queued_words:
            key = list(item.keys())[0]
            word = item[key]
            weight = item.get("weight", 3)  # Default weight is 3
            feedback_words.append(FeedbackWord(key=key, word=word, weight=weight))

        logger.info(f"Updated feedback words with {len(feedback_words)} items")
        return feedback_words
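
    # Example call (hypothetical words and ratings; exactly six entries, each
    # in 0-6, in the same order as the queued words at positions 0-5):
    #
    #   await service.update_feedback_words({
    #       "moody": 5, "pastel": 0, "grainy": 3,
    #       "vivid": 6, "stark": 2, "dreamy": 4,
    #   })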

    async def _generate_and_insert_new_feedback_words(self, feedback_historic: List[Dict[str, Any]]) -> None:
        """Generate new feedback words and insert them at position 0."""
        try:
            # Generate 6 new feedback words
            new_words = await self.generate_theme_feedback_words()

            if len(new_words) != 6:
                logger.warning(f"Expected 6 new feedback words, got {len(new_words)}. Not inserting.")
                return

            # Create new feedback items with the default weight of 3;
            # new items occupy positions 0-5, so key them by those indices
            new_feedback_items = []
            for i, word in enumerate(new_words):
                feedback_key = f"feedback{i:02d}"
                new_feedback_items.append({
                    feedback_key: word,
                    "weight": 3  # Default weight
                })

            # Insert the new words at position 0 and trim the buffer to
            # FEEDBACK_HISTORY_SIZE items
            updated_feedback_historic = new_feedback_items + feedback_historic
            if len(updated_feedback_historic) > settings.FEEDBACK_HISTORY_SIZE:
                updated_feedback_historic = updated_feedback_historic[:settings.FEEDBACK_HISTORY_SIZE]

            # Re-key all items so keys stay unique and mirror buffer positions.
            # Each item has the structure {"feedbackXX": "word", "weight": N};
            # look up the word key explicitly rather than relying on key order.
            for i, item in enumerate(updated_feedback_historic):
                word_keys = [k for k in item.keys() if k != "weight"]
                if not word_keys:
                    # Malformed item with no feedback key; leave it untouched
                    continue
                word = item[word_keys[0]]
                weight = item.get("weight", 3)

                # Replace the item with a key derived from its position
                updated_feedback_historic[i] = {
                    f"feedback{i:02d}": word,
                    "weight": weight
                }

            # Update cache and save
            self._feedback_historic_cache = updated_feedback_historic
            await self.data_service.save_feedback_historic(updated_feedback_historic)

            logger.info(f"Inserted 6 new feedback words at position 0, history size: {len(updated_feedback_historic)}")

        except Exception as e:
            logger.error(f"Error generating and inserting new feedback words: {e}")
            raise
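
    # Re-keying sketch: after six new items are prepended, an item that was
    # {"feedback02": "warm", "weight": 4} at position 2 becomes
    # {"feedback08": "warm", "weight": 4} at position 8, so keys always
    # mirror buffer positions.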

    # Utility methods for API endpoints
    def get_pool_size(self) -> int:
        """Get the current pool size (synchronous, for API endpoints)."""
        if self._prompts_pool_cache is None:
            raise RuntimeError("Pool cache not initialized")
        return len(self._prompts_pool_cache)

    def get_target_volume(self) -> int:
        """
        Get the target pool volume (synchronous, for API endpoints).

        Note: unlike the async paths, this reads the environment-backed
        setting directly and does not consult the config-file override.
        """
        return settings.CACHED_POOL_VOLUME
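

if __name__ == "__main__":
    # Hypothetical manual smoke test (a minimal sketch, assuming the app
    # package and its JSON data files are in place; not part of the API):
    import asyncio

    async def _demo() -> None:
        service = PromptService()
        stats = await service.get_pool_stats()
        if stats.needs_refill:
            await service.fill_pool_to_target()
        drawn = await service.draw_prompts_from_pool()
        for prompt in drawn:
            await service.add_prompt_to_history(prompt)
        print(f"Drew {len(drawn)} prompts; pool size now {service.get_pool_size()}")

    asyncio.run(_demo())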