"""
|
|
AI service for handling OpenAI/DeepSeek API calls.
|
|
"""
|
|
|
|
import json
|
|
from typing import List, Dict, Any, Optional
|
|
from openai import OpenAI, AsyncOpenAI
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import setup_logging
|
|
|
|
logger = setup_logging()
|
|
|
|
|
|
class AIService:
    """Service for handling AI API calls."""

    def __init__(self):
        """Initialize AI service."""
        api_key = settings.DEEPSEEK_API_KEY or settings.OPENAI_API_KEY
        if not api_key:
            raise ValueError("No API key found. Set DEEPSEEK_API_KEY or OPENAI_API_KEY in environment.")

        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=settings.API_BASE_URL
        )
        self.model = settings.MODEL

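    # Illustrative construction (a sketch, not executed here). Assumes the
    # settings object supplies DEEPSEEK_API_KEY or OPENAI_API_KEY plus
    # API_BASE_URL and MODEL, e.g. loaded from the environment by app.core.config:
    #
    #   service = AIService()   # raises ValueError when no API key is configured
    #   # service.client is an AsyncOpenAI client pointed at settings.API_BASE_URL
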
    def _clean_ai_response(self, response_content: str) -> str:
        """
        Clean up AI response content to handle common formatting issues.

        Handles:
        1. Leading/trailing backticks (```json ... ```)
        2. Leading "json" string on its own line
        3. Extra whitespace and newlines
        """
        content = response_content.strip()

        # Remove leading/trailing backticks (```json ... ```)
        if content.startswith('```'):
            lines = content.split('\n')
            if len(lines) > 1:
                first_line = lines[0].strip()
                if 'json' in first_line.lower() or first_line == '```':
                    content = '\n'.join(lines[1:])

        # Remove trailing backticks if present
        if content.endswith('```'):
            content = content[:-3].rstrip()

        # Remove leading "json" string on its own line (case-insensitive)
        lines = content.split('\n')
        if len(lines) > 0:
            first_line = lines[0].strip().lower()
            if first_line == 'json':
                content = '\n'.join(lines[1:])

        # Also handle the case where "json" might be at the beginning of the first line
        content = content.strip()
        if content.lower().startswith('json\n'):
            content = content[4:].strip()

        return content.strip()

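    # Illustrative round-trip for _clean_ai_response (a sketch of intended
    # behaviour, not executed here):
    #
    #   raw = '```json\n["prompt one", "prompt two"]\n```'
    #   self._clean_ai_response(raw)   # -> '["prompt one", "prompt two"]'
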
    async def generate_prompts(
        self,
        prompt_template: str,
        historic_prompts: List[Dict[str, str]],
        feedback_words: Optional[List[Dict[str, Any]]] = None,
        count: Optional[int] = None,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None
    ) -> List[str]:
        """
        Generate journal prompts using AI.

        Args:
            prompt_template: Base prompt template
            historic_prompts: List of historic prompts for context
            feedback_words: List of feedback words with weights
            count: Number of prompts to generate
            min_length: Minimum prompt length
            max_length: Maximum prompt length

        Returns:
            List of generated prompts
        """
        if count is None:
            count = settings.NUM_PROMPTS_PER_SESSION
        if min_length is None:
            min_length = settings.MIN_PROMPT_LENGTH
        if max_length is None:
            max_length = settings.MAX_PROMPT_LENGTH

        # Prepare the full prompt
        full_prompt = self._prepare_prompt(
            prompt_template,
            historic_prompts,
            feedback_words,
            count,
            min_length,
            max_length
        )

        logger.info(f"Generating {count} prompts with AI")

        try:
            # Call the AI API
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a creative writing assistant that generates journal prompts. Always respond with valid JSON."
                    },
                    {
                        "role": "user",
                        "content": full_prompt
                    }
                ],
                temperature=0.7,
                max_tokens=2000
            )

            response_content = response.choices[0].message.content
            logger.debug(f"AI response received: {len(response_content)} characters")

            # Parse the response
            prompts = self._parse_prompt_response(response_content, count)
            logger.info(f"Successfully parsed {len(prompts)} prompts from AI response")

            return prompts

        except Exception as e:
            logger.error(f"Error calling AI API: {e}")
            logger.debug(f"Full prompt sent to API: {full_prompt[:500]}...")
            raise

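    # Illustrative call site (a sketch; assumes an async caller, and the key
    # names inside historic_prompts/feedback_words are examples only, since the
    # dicts are passed through as JSON context):
    #
    #   prompts = await service.generate_prompts(
    #       prompt_template="Generate reflective journal prompts.",
    #       historic_prompts=[{"prompt0": "Describe a place that feels like home."}],
    #       feedback_words=[{"word": "belonging", "weight": 4}],
    #   )
    #   # count/min_length/max_length default to the values in settings
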
    def _prepare_prompt(
        self,
        template: str,
        historic_prompts: List[Dict[str, str]],
        feedback_words: Optional[List[Dict[str, Any]]],
        count: int,
        min_length: int,
        max_length: int
    ) -> str:
        """Prepare the full prompt with all context."""
        # Add the instruction for the specific number of prompts
        prompt_instruction = f"Please generate {count} writing prompts, each between {min_length} and {max_length} characters."

        # Start with template and instruction
        full_prompt = f"{template}\n\n{prompt_instruction}"

        # Add historic prompts if available
        if historic_prompts:
            historic_context = json.dumps(historic_prompts, indent=2)
            full_prompt = f"{full_prompt}\n\nPrevious prompts:\n{historic_context}"

        # Add feedback words if available
        if feedback_words:
            feedback_context = json.dumps(feedback_words, indent=2)
            full_prompt = f"{full_prompt}\n\nFeedback words:\n{feedback_context}"

        return full_prompt

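    # Shape of the assembled prompt (sections are appended only when the
    # corresponding data is provided):
    #
    #   <template>
    #
    #   Please generate <count> writing prompts, each between <min_length> and <max_length> characters.
    #
    #   Previous prompts:
    #   <JSON-encoded historic_prompts>
    #
    #   Feedback words:
    #   <JSON-encoded feedback_words>
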
    def _parse_prompt_response(self, response_content: str, expected_count: int) -> List[str]:
        """Parse AI response to extract prompts."""
        cleaned_content = self._clean_ai_response(response_content)

        try:
            data = json.loads(cleaned_content)

            if isinstance(data, list):
                if len(data) >= expected_count:
                    return data[:expected_count]
                else:
                    logger.warning(f"AI returned {len(data)} prompts, expected {expected_count}")
                    return data
            elif isinstance(data, dict):
                logger.warning("AI returned dictionary format, expected list format")
                prompts = []
                for i in range(expected_count):
                    key = f"newprompt{i}"
                    if key in data:
                        prompts.append(data[key])
                return prompts
            else:
                logger.warning(f"AI returned unexpected data type: {type(data)}")
                return []

        except json.JSONDecodeError:
            logger.warning("AI response is not valid JSON, attempting to extract prompts...")
            return self._extract_prompts_from_text(response_content, expected_count)

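    # Response shapes accepted by _parse_prompt_response (illustrative):
    #
    #   ["prompt one", "prompt two", "prompt three"]                # preferred list form
    #   {"newprompt0": "prompt one", "newprompt1": "prompt two"}    # tolerated dict form
    #
    # Anything else, or invalid JSON, falls back to _extract_prompts_from_text.
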
    def _extract_prompts_from_text(self, text: str, expected_count: int) -> List[str]:
        """Extract prompts from plain text response."""
        lines = text.strip().split('\n')
        prompts = []

        # Scan all lines (not just the first expected_count) so blank or short
        # lines do not reduce the number of prompts collected.
        for line in lines:
            line = line.strip()
            if line and len(line) > 50:  # Reasonable minimum length for a prompt
                prompts.append(line)
            if len(prompts) >= expected_count:
                break

        return prompts

    async def generate_theme_feedback_words(
        self,
        feedback_template: str,
        historic_prompts: List[Dict[str, str]],
        queued_feedback_words: Optional[List[Dict[str, Any]]] = None,
        historic_feedback_words: Optional[List[Dict[str, Any]]] = None
    ) -> List[str]:
        """
        Generate theme feedback words using AI.

        Args:
            feedback_template: Feedback analysis template
            historic_prompts: List of historic prompts for context
            queued_feedback_words: Queued feedback words with weights (positions 0-5)
            historic_feedback_words: Historic feedback words with weights (all positions)

        Returns:
            List of 6 theme words
        """
        # Prepare the full prompt
        full_prompt = self._prepare_feedback_prompt(
            feedback_template,
            historic_prompts,
            queued_feedback_words,
            historic_feedback_words
        )

        logger.info("Generating theme feedback words with AI")

        try:
            # Call the AI API
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a creative writing assistant that analyzes writing prompts. Always respond with valid JSON."
                    },
                    {
                        "role": "user",
                        "content": full_prompt
                    }
                ],
                temperature=0.7,
                max_tokens=1000
            )

            response_content = response.choices[0].message.content
            logger.debug(f"AI feedback response received: {len(response_content)} characters")

            # Parse the response
            theme_words = self._parse_feedback_response(response_content)
            logger.info(f"Successfully parsed {len(theme_words)} theme words from AI response")

            if len(theme_words) != 6:
                logger.warning(f"Expected 6 theme words, got {len(theme_words)}")

            return theme_words

        except Exception as e:
            logger.error(f"Error calling AI API for feedback: {e}")
            logger.debug(f"Full feedback prompt sent to API: {full_prompt[:500]}...")
            raise

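    # Illustrative call site (a sketch; the key names inside the word dicts are
    # examples -- only the "weight" key is treated specially downstream):
    #
    #   words = await service.generate_theme_feedback_words(
    #       feedback_template="Identify six recurring themes in these prompts.",
    #       historic_prompts=[{"prompt0": "Write about a difficult goodbye."}],
    #       queued_feedback_words=[{"themeword0": "loss", "weight": 5}],
    #   )
    #   # -> a list of lowercase theme words; a warning is logged if not exactly 6
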
    def _prepare_feedback_prompt(
        self,
        template: str,
        historic_prompts: List[Dict[str, str]],
        queued_feedback_words: Optional[List[Dict[str, Any]]],
        historic_feedback_words: Optional[List[Dict[str, Any]]]
    ) -> str:
        """Prepare the full feedback prompt."""
        if not historic_prompts:
            raise ValueError("No historic prompts available for feedback analysis")

        full_prompt = f"{template}\n\nPrevious prompts:\n{json.dumps(historic_prompts, indent=2)}"

        # Add queued feedback words if available (these have user-adjusted weights)
        if queued_feedback_words:
            # Extract just the words and weights for clarity. The word is stored
            # under the item's first non-"weight" key, so the lookup does not
            # depend on dictionary key order.
            queued_words_with_weights = []
            for item in queued_feedback_words:
                key = next((k for k in item if k != "weight"), None)
                if key is None:
                    continue
                word = item[key]
                weight = item.get("weight", 3)
                queued_words_with_weights.append({"word": word, "weight": weight})

            feedback_context = json.dumps(queued_words_with_weights, indent=2)
            full_prompt = f"{full_prompt}\n\nQueued feedback themes (with user-adjusted weights):\n{feedback_context}"

        # Add historic feedback words if available (these may have weights too)
        if historic_feedback_words:
            # Extract just the words for historic context
            historic_words = []
            for item in historic_feedback_words:
                key = next((k for k in item if k != "weight"), None)
                if key is None:
                    continue
                historic_words.append(item[key])

            feedback_historic_context = json.dumps(historic_words, indent=2)
            full_prompt = f"{full_prompt}\n\nHistoric feedback themes (just words):\n{feedback_historic_context}"

        return full_prompt

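    # Expected shape of the word items (illustrative; the word key name is not
    # fixed -- the first non-"weight" key is treated as the word):
    #
    #   queued_feedback_words   = [{"themeword0": "loss", "weight": 5}, ...]
    #   historic_feedback_words = [{"themeword2": "memory", "weight": 3}, ...]
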
    def _parse_feedback_response(self, response_content: str) -> List[str]:
        """Parse AI response to extract theme words."""
        cleaned_content = self._clean_ai_response(response_content)

        try:
            data = json.loads(cleaned_content)

            if isinstance(data, list):
                theme_words = []
                for word in data:
                    if isinstance(word, str):
                        theme_words.append(word.lower().strip())
                    else:
                        theme_words.append(str(word).lower().strip())
                return theme_words
            else:
                logger.warning(f"AI returned unexpected data type for feedback: {type(data)}")
                return []

        except json.JSONDecodeError:
            logger.warning("AI feedback response is not valid JSON, attempting to extract theme words...")
            return self._extract_theme_words_from_text(response_content)

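    # Expected feedback response (illustrative): a JSON array of six short theme
    # words, e.g. ["loss", "memory", "growth", "stillness", "home", "risk"].
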
    def _extract_theme_words_from_text(self, text: str) -> List[str]:
        """Extract theme words from plain text response."""
        lines = text.strip().split('\n')
        theme_words = []

        for line in lines:
            line = line.strip()
            if line and len(line) < 50:  # Theme words should be short
                words = [w.lower().strip('.,;:!?()[]{}\"\'') for w in line.split()]
                # Skip tokens that were pure punctuation before stripping
                theme_words.extend(w for w in words if w)

            if len(theme_words) >= 6:
                break

        return theme_words[:6]