"""
Prompt-related API endpoints.
"""

from typing import List, Optional

from fastapi import APIRouter, HTTPException, Depends, status
from pydantic import BaseModel

from app.services.prompt_service import PromptService
from app.models.prompt import PromptResponse, PoolStatsResponse, HistoryStatsResponse

# Create router
router = APIRouter()
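

# Illustrative mounting sketch (not part of this module's behavior): the routes
# below use relative paths ("/draw", "/fill-pool", ...), so the URL prefix comes
# from wherever this router is included. The import path and prefix here are
# assumptions for illustration only, e.g. in app/main.py:
#
#     from fastapi import FastAPI
#     from app.api.prompts import router as prompts_router  # hypothetical import path
#
#     app = FastAPI()
#     app.include_router(prompts_router, prefix="/prompts", tags=["prompts"])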


# Response models
class DrawPromptsResponse(BaseModel):
    """Response model for drawing prompts."""
    prompts: List[str]
    count: int
    remaining_in_pool: int


class FillPoolResponse(BaseModel):
    """Response model for filling prompt pool."""
    added: int
    total_in_pool: int
    target_volume: int


class SelectPromptRequest(BaseModel):
    """Request model for selecting a prompt."""
    prompt_text: str


class SelectPromptResponse(BaseModel):
    """Response model for selecting a prompt."""
    selected_prompt: str
    position_in_history: str  # e.g., "prompt00"
    history_size: int


# Service dependency
async def get_prompt_service() -> PromptService:
    """Dependency to get PromptService instance."""
    return PromptService()
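

# Testing note (illustrative only, nothing below depends on it): because every
# endpoint receives its PromptService via Depends(get_prompt_service), a test
# can swap in a stub with FastAPI's dependency_overrides, e.g.:
#
#     app.dependency_overrides[get_prompt_service] = lambda: FakePromptService()  # FakePromptService is hypothetical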
@router.get("/draw", response_model=DrawPromptsResponse)
|
|
async def draw_prompts(
|
|
count: Optional[int] = None,
|
|
prompt_service: PromptService = Depends(get_prompt_service)
|
|
):
|
|
"""
|
|
Draw prompts from the pool.
|
|
|
|
Args:
|
|
count: Number of prompts to draw (defaults to settings.NUM_PROMPTS_PER_SESSION)
|
|
prompt_service: PromptService instance
|
|
|
|
Returns:
|
|
List of prompts drawn from pool
|
|
"""
|
|
try:
|
|
prompts = await prompt_service.draw_prompts_from_pool(count)
|
|
pool_size = prompt_service.get_pool_size()
|
|
|
|
return DrawPromptsResponse(
|
|
prompts=prompts,
|
|
count=len(prompts),
|
|
remaining_in_pool=pool_size
|
|
)
|
|
except ValueError as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=str(e)
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Error drawing prompts: {str(e)}"
|
|
)
|
|
|
|
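

# Illustrative request/response for the endpoint above (values are made up; the
# "/prompts" prefix assumes how the router is mounted):
#
#   GET /prompts/draw?count=3
#   200 -> {"prompts": ["...", "...", "..."], "count": 3, "remaining_in_pool": 47}
#   400 -> returned when the service raises ValueError (e.g. an invalid count)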
@router.post("/fill-pool", response_model=FillPoolResponse)
|
|
async def fill_prompt_pool(
|
|
prompt_service: PromptService = Depends(get_prompt_service)
|
|
):
|
|
"""
|
|
Fill the prompt pool to target volume using AI.
|
|
|
|
Returns:
|
|
Information about added prompts
|
|
"""
|
|
try:
|
|
added_count = await prompt_service.fill_pool_to_target()
|
|
pool_size = prompt_service.get_pool_size()
|
|
target_volume = prompt_service.get_target_volume()
|
|
|
|
return FillPoolResponse(
|
|
added=added_count,
|
|
total_in_pool=pool_size,
|
|
target_volume=target_volume
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Error filling prompt pool: {str(e)}"
|
|
)
|
|
|
|
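

# Note on the endpoint above: the handler awaits fill_pool_to_target() before
# responding, so any AI generation time is spent inside this request rather than
# in a background task. Illustrative call (values are made up):
#
#   POST /prompts/fill-pool
#   200 -> {"added": 10, "total_in_pool": 50, "target_volume": 50}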
@router.get("/stats", response_model=PoolStatsResponse)
|
|
async def get_pool_stats(
|
|
prompt_service: PromptService = Depends(get_prompt_service)
|
|
):
|
|
"""
|
|
Get statistics about the prompt pool.
|
|
|
|
Returns:
|
|
Pool statistics
|
|
"""
|
|
try:
|
|
return await prompt_service.get_pool_stats()
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Error getting pool stats: {str(e)}"
|
|
)
|
|
|
|
@router.get("/history/stats", response_model=HistoryStatsResponse)
|
|
async def get_history_stats(
|
|
prompt_service: PromptService = Depends(get_prompt_service)
|
|
):
|
|
"""
|
|
Get statistics about prompt history.
|
|
|
|
Returns:
|
|
History statistics
|
|
"""
|
|
try:
|
|
return await prompt_service.get_history_stats()
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Error getting history stats: {str(e)}"
|
|
)
|
|
|
|
@router.get("/history", response_model=List[PromptResponse])
|
|
async def get_prompt_history(
|
|
limit: Optional[int] = None,
|
|
prompt_service: PromptService = Depends(get_prompt_service)
|
|
):
|
|
"""
|
|
Get prompt history.
|
|
|
|
Args:
|
|
limit: Maximum number of history items to return
|
|
|
|
Returns:
|
|
List of historical prompts
|
|
"""
|
|
try:
|
|
return await prompt_service.get_prompt_history(limit)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Error getting prompt history: {str(e)}"
|
|
)
|
|
|
|
@router.post("/select", response_model=SelectPromptResponse)
|
|
async def select_prompt(
|
|
request: SelectPromptRequest,
|
|
prompt_service: PromptService = Depends(get_prompt_service)
|
|
):
|
|
"""
|
|
Select a prompt to add to history.
|
|
|
|
Adds the provided prompt text to the historic prompts cyclic buffer.
|
|
The prompt will be added at position 0 (most recent), shifting existing prompts down.
|
|
|
|
Args:
|
|
request: SelectPromptRequest containing the prompt text
|
|
|
|
Returns:
|
|
Confirmation of prompt selection with position in history
|
|
"""
|
|
try:
|
|
# Add the prompt to history
|
|
position_key = await prompt_service.add_prompt_to_history(request.prompt_text)
|
|
|
|
# Get updated history stats
|
|
history_stats = await prompt_service.get_history_stats()
|
|
|
|
return SelectPromptResponse(
|
|
selected_prompt=request.prompt_text,
|
|
position_in_history=position_key,
|
|
history_size=history_stats.total_prompts
|
|
)
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Error selecting prompt: {str(e)}"
|
|
)
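

# Illustrative call for the endpoint above (values are made up; the "/prompts"
# prefix assumes how the router is mounted):
#
#   POST /prompts/select   {"prompt_text": "Write a haiku about rain"}
#   200 -> {"selected_prompt": "Write a haiku about rain",
#           "position_in_history": "prompt00",
#           "history_size": 12}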