diff --git a/backend/Dockerfile b/backend/Dockerfile index 44c53f0365..180229dcfd 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -40,4 +40,12 @@ COPY ./app /app/app RUN --mount=type=cache,target=/root/.cache/uv \ uv sync -CMD ["fastapi", "run", "--workers", "4", "app/main.py"] +# Copy the entrypoint script +COPY ./docker-entrypoint.sh /app/docker-entrypoint.sh +RUN chmod +x /app/docker-entrypoint.sh + +# Remove the old CMD +# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "${PORT:-8000}", "--workers", "4"] + +# Set the entrypoint +ENTRYPOINT ["/app/docker-entrypoint.sh"] diff --git a/backend/app/alembic/env.py b/backend/app/alembic/env.py index 7f29c04680..797efdb6af 100755 --- a/backend/app/alembic/env.py +++ b/backend/app/alembic/env.py @@ -1,9 +1,16 @@ import os +import sys +from pathlib import Path from logging.config import fileConfig from alembic import context from sqlalchemy import engine_from_config, pool +# Add project root to sys.path +# Assumes env.py is in backend/app/alembic/ +project_root = Path(__file__).parents[2] +sys.path.insert(0, str(project_root)) + # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config diff --git a/backend/app/alembic/versions/47b3823eff09_add_on_delete_cascade_to_message_.py b/backend/app/alembic/versions/47b3823eff09_add_on_delete_cascade_to_message_.py new file mode 100644 index 0000000000..e3ff947e83 --- /dev/null +++ b/backend/app/alembic/versions/47b3823eff09_add_on_delete_cascade_to_message_.py @@ -0,0 +1,55 @@ +"""Add ON DELETE CASCADE to message.conversation_id + +Revision ID: 47b3823eff09 +Revises: cba8d126b9ac +Create Date: 2025-04-19 06:25:27.717138 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '47b3823eff09' +down_revision = 'cba8d126b9ac' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Manually added to apply ON DELETE CASCADE + op.drop_constraint( + 'message_conversation_id_fkey', # Default constraint name might vary, check DB if needed + 'message', + type_='foreignkey' + ) + op.create_foreign_key( + 'message_conversation_id_fkey', + 'message', + 'conversation', + ['conversation_id'], + ['id'], + ondelete='CASCADE' + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Manually added to revert ON DELETE CASCADE + op.drop_constraint( + 'message_conversation_id_fkey', + 'message', + type_='foreignkey' + ) + op.create_foreign_key( + 'message_conversation_id_fkey', + 'message', + 'conversation', + ['conversation_id'], + ['id'] + # No ondelete here + ) + # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/cba8d126b9ac_add_character_conversation_message_.py b/backend/app/alembic/versions/cba8d126b9ac_add_character_conversation_message_.py new file mode 100644 index 0000000000..9608d62aa6 --- /dev/null +++ b/backend/app/alembic/versions/cba8d126b9ac_add_character_conversation_message_.py @@ -0,0 +1,61 @@ +"""Add Character, Conversation, Message models + +Revision ID: cba8d126b9ac +Revises: 1a31ce608336 +Create Date: 2025-04-19 06:13:38.566914 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = 'cba8d126b9ac' +down_revision = '1a31ce608336' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('character', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True), + sa.Column('image_url', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('greeting_message', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('status', sa.Enum('PENDING', 'APPROVED', 'REJECTED', name='characterstatus'), nullable=False), + sa.Column('creator_id', sa.Uuid(), nullable=False), + sa.ForeignKeyConstraint(['creator_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_character_name'), 'character', ['name'], unique=False) + op.create_table('conversation', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('character_id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['character_id'], ['character.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('message', + sa.Column('content', sqlmodel.sql.sqltypes.AutoString(length=5000), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('conversation_id', sa.Uuid(), nullable=False), + sa.Column('sender', sa.Enum('USER', 'AI', name='messagesender'), nullable=False), + sa.Column('timestamp', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['conversation_id'], ['conversation.id'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('message') + op.drop_table('conversation') + op.drop_index(op.f('ix_character_name'), table_name='character') + op.drop_table('character') + # ### end Alembic commands ### diff --git a/backend/app/api/main.py b/backend/app/api/main.py index eac18c8e8f..fc5cac2ae6 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from app.api.routes import items, login, private, users, utils +from app.api.routes import items, login, private, users, utils, characters, admin_characters, conversations from app.core.config import settings api_router = APIRouter() @@ -8,6 +8,9 @@ api_router.include_router(users.router) api_router.include_router(utils.router) api_router.include_router(items.router) +api_router.include_router(characters.router) +api_router.include_router(admin_characters.router) +api_router.include_router(conversations.router) if settings.ENVIRONMENT == "local": diff --git a/backend/app/api/routes/admin_characters.py b/backend/app/api/routes/admin_characters.py new file mode 100644 index 0000000000..70b5724e72 --- /dev/null +++ b/backend/app/api/routes/admin_characters.py @@ -0,0 +1,113 @@ +# Placeholder for admin character management routes + +import uuid +from typing import Any + +from fastapi import APIRouter, HTTPException, Depends +from sqlmodel import Session, select, func + +from app.api.deps import SessionDep, CurrentUser, get_current_active_superuser +from app import crud +from app.models import ( + Character, CharacterUpdate, CharacterPublic, CharactersPublic, CharacterStatus, Message +) + +router = APIRouter(prefix="/admin/characters", tags=["admin-characters"], + dependencies=[Depends(get_current_active_superuser)]) + + +@router.get("/", response_model=CharactersPublic) +def list_all_characters( + session: SessionDep, skip: int = 0, limit: int = 100, status: CharacterStatus | None = None +) -> Any: + """ + Retrieve all characters (admin view). + Optionally filter by status. + """ + count = crud.characters.get_characters_count(session=session, status=status) + characters = crud.characters.get_characters( + session=session, skip=skip, limit=limit, status=status + ) + return CharactersPublic(data=characters, count=count) + + +@router.get("/pending", response_model=CharactersPublic) +def list_pending_characters( + session: SessionDep, skip: int = 0, limit: int = 100 +) -> Any: + """ + Retrieve characters pending approval. + """ + count = crud.characters.get_characters_count(session=session, status=CharacterStatus.PENDING) + characters = crud.characters.get_characters( + session=session, skip=skip, limit=limit, status=CharacterStatus.PENDING + ) + return CharactersPublic(data=characters, count=count) + + +@router.patch("/{id}/approve", response_model=CharacterPublic) +def approve_character(session: SessionDep, id: uuid.UUID) -> Any: + """ + Approve a character submission. + """ + db_character = crud.characters.get_character(session=session, character_id=id) + if not db_character: + raise HTTPException(status_code=404, detail="Character not found") + if db_character.status != CharacterStatus.PENDING: + raise HTTPException(status_code=400, detail="Character is not pending approval") + + character_update = CharacterUpdate(status=CharacterStatus.APPROVED) + character = crud.characters.update_character( + session=session, db_character=db_character, character_in=character_update + ) + return character + + +@router.patch("/{id}/reject", response_model=CharacterPublic) +def reject_character(session: SessionDep, id: uuid.UUID) -> Any: + """ + Reject a character submission. + """ + db_character = crud.characters.get_character(session=session, character_id=id) + if not db_character: + raise HTTPException(status_code=404, detail="Character not found") + if db_character.status != CharacterStatus.PENDING: + raise HTTPException(status_code=400, detail="Character is not pending approval") + + character_update = CharacterUpdate(status=CharacterStatus.REJECTED) + character = crud.characters.update_character( + session=session, db_character=db_character, character_in=character_update + ) + return character + + +@router.put("/{id}", response_model=CharacterPublic) +def update_character_admin( + *, session: SessionDep, id: uuid.UUID, character_in: CharacterUpdate +) -> Any: + """ + Update any character (admin only). + """ + db_character = crud.characters.get_character(session=session, character_id=id) + if not db_character: + raise HTTPException(status_code=404, detail="Character not found") + + character = crud.characters.update_character( + session=session, db_character=db_character, character_in=character_in + ) + return character + + +@router.delete("/{id}") +def delete_character_admin( + session: SessionDep, id: uuid.UUID +) -> Message: + """ + Delete a character (admin only). + """ + db_character = crud.characters.get_character(session=session, character_id=id) + if not db_character: + raise HTTPException(status_code=404, detail="Character not found") + + crud.characters.delete_character(session=session, db_character=db_character) + return Message(message="Character deleted successfully") \ No newline at end of file diff --git a/backend/app/api/routes/characters.py b/backend/app/api/routes/characters.py new file mode 100644 index 0000000000..552db8e26d --- /dev/null +++ b/backend/app/api/routes/characters.py @@ -0,0 +1,78 @@ +# Placeholder for character submission and listing routes + +import uuid +from typing import Any + +from fastapi import APIRouter, HTTPException, Depends +from sqlmodel import Session, select, func + +from app.api.deps import SessionDep, CurrentUser +from app import crud +from app.models import ( + Character, CharacterCreate, CharacterPublic, CharactersPublic, CharacterStatus, Message +) + +router = APIRouter(prefix="/characters", tags=["characters"]) + + +@router.get("/", response_model=CharactersPublic) +def list_approved_characters( + session: SessionDep, skip: int = 0, limit: int = 100 +) -> Any: + """ + Retrieve approved characters. + """ + count = crud.characters.get_characters_count( + session=session, status=CharacterStatus.APPROVED + ) + characters = crud.characters.get_characters( + session=session, skip=skip, limit=limit, status=CharacterStatus.APPROVED + ) + return CharactersPublic(data=characters, count=count) + + +@router.get("/my-submissions", response_model=CharactersPublic) +def list_my_character_submissions( + session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100 +) -> Any: + """ + Retrieve characters submitted by the current user. + Includes characters with any status (pending, approved, rejected). + """ + count = crud.characters.get_characters_count( + session=session, creator_id=current_user.id + ) + characters = crud.characters.get_characters( + session=session, creator_id=current_user.id, skip=skip, limit=limit + ) + return CharactersPublic(data=characters, count=count) + + +@router.get("/{id}", response_model=CharacterPublic) +def get_approved_character( + session: SessionDep, id: uuid.UUID +) -> Any: + """ + Get a specific approved character by ID. + """ + character = crud.characters.get_character(session=session, character_id=id) + if not character: + raise HTTPException(status_code=404, detail="Character not found") + if character.status != CharacterStatus.APPROVED: + # Hide non-approved characters from this public endpoint + raise HTTPException(status_code=404, detail="Character not found") # Or 403 Forbidden + return character + + +@router.post("/submit", response_model=CharacterPublic, status_code=201) +def submit_character( + *, session: SessionDep, current_user: CurrentUser, character_in: CharacterCreate +) -> Any: + """ + Submit a new character for review. + Status defaults to 'pending'. + """ + character = crud.characters.create_character( + session=session, character_create=character_in, creator_id=current_user.id + ) + return character \ No newline at end of file diff --git a/backend/app/api/routes/conversations.py b/backend/app/api/routes/conversations.py new file mode 100644 index 0000000000..aabccecc44 --- /dev/null +++ b/backend/app/api/routes/conversations.py @@ -0,0 +1,165 @@ +# Placeholder for conversation management routes + +import uuid +from typing import Any + +from fastapi import APIRouter, HTTPException, Depends +from sqlmodel import Session + +from app.api.deps import SessionDep, CurrentUser +from app import crud +from app.models import ( + Conversation, ConversationCreate, ConversationPublic, ConversationsPublic, + Message, MessageCreate, MessagePublic, MessagesPublic, MessageSender, + CharacterStatus, Character # Import Character +) +# Import AI service +from app.services import ai_service + +router = APIRouter(prefix="/conversations", tags=["conversations"]) + + +@router.post("/", response_model=ConversationPublic, status_code=201) +def start_conversation( + *, session: SessionDep, current_user: CurrentUser, conversation_in: ConversationCreate +) -> Any: + """ + Start a new conversation with an approved character. + """ + # Check if character exists and is approved + character = crud.characters.get_character( + session=session, character_id=conversation_in.character_id + ) + if not character or character.status != CharacterStatus.APPROVED: + raise HTTPException(status_code=404, detail="Approved character not found") + + try: + conversation = crud.conversations.create_conversation( + session=session, conversation_create=conversation_in, user_id=current_user.id + ) + except ValueError as e: + # Catch potential errors from CRUD (like character not found again, just in case) + raise HTTPException(status_code=404, detail=str(e)) + + # Optionally: Add the character's greeting message as the first AI message + if character.greeting_message: + crud.conversations.create_message( + session=session, + message_create=MessageCreate(content=character.greeting_message), + conversation_id=conversation.id, + sender=MessageSender.AI + ) + session.refresh(conversation) # Refresh to potentially load the new message relationship + + return conversation + + +@router.get("/", response_model=ConversationsPublic) +def list_my_conversations( + session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100 +) -> Any: + """ + Retrieve conversations for the current user. + """ + count = crud.conversations.get_user_conversations_count( + session=session, user_id=current_user.id + ) + conversations = crud.conversations.get_user_conversations( + session=session, user_id=current_user.id, skip=skip, limit=limit + ) + return ConversationsPublic(data=conversations, count=count) + + +@router.get("/{conversation_id}/messages", response_model=MessagesPublic) +def get_conversation_messages_route( + session: SessionDep, current_user: CurrentUser, conversation_id: uuid.UUID, skip: int = 0, limit: int = 100 +) -> Any: + """ + Retrieve messages for a specific conversation owned by the current user. + """ + conversation = crud.conversations.get_conversation( + session=session, conversation_id=conversation_id + ) + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + if conversation.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to view these messages") + + count = crud.conversations.get_conversation_messages_count( + session=session, conversation_id=conversation_id + ) + messages = crud.conversations.get_conversation_messages( + session=session, conversation_id=conversation_id, skip=skip, limit=limit + ) + return MessagesPublic(data=messages, count=count) + + +@router.post("/{conversation_id}/messages", response_model=MessagePublic) +def send_message( + *, session: SessionDep, current_user: CurrentUser, conversation_id: uuid.UUID, message_in: MessageCreate +) -> Any: + """ + Send a message from the user to a conversation and get an AI response. + """ + conversation = crud.conversations.get_conversation( + session=session, conversation_id=conversation_id + ) + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + if conversation.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to send messages to this conversation") + if not conversation.character: # Ensure character relationship is loaded or handle if None + # This might require adjusting how conversations are fetched or created if lazy loading isn't setup + # For now, assume it exists if conversation exists + raise HTTPException(status_code=500, detail="Character details missing for conversation") + + # 1. Save the user's message + user_message = crud.conversations.create_message( + session=session, + message_create=message_in, + conversation_id=conversation_id, + sender=MessageSender.USER + ) + + # 2. Prepare context and get AI response + # Get recent messages (including the one just sent by the user) + # Adjust limit as needed for AI context window + message_history = crud.conversations.get_conversation_messages( + session=session, conversation_id=conversation_id, limit=20 + ) + + ai_response_text = ai_service.get_ai_response( + character=conversation.character, + history=message_history + ) + + # 3. Save the AI's message + ai_message = crud.conversations.create_message( + session=session, + message_create=MessageCreate(content=ai_response_text), + conversation_id=conversation_id, + sender=MessageSender.AI + ) + + # Return the AI's response message + return ai_message + + +@router.delete("/{conversation_id}", status_code=204) +def delete_conversation_route( + session: SessionDep, current_user: CurrentUser, conversation_id: uuid.UUID +) -> None: + """ + Delete a conversation owned by the current user. + """ + conversation = crud.conversations.get_conversation( + session=session, conversation_id=conversation_id + ) + if not conversation: + # Idempotent delete: if not found, act as if deleted + return None + if conversation.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to delete this conversation") + + crud.conversations.delete_conversation(session=session, db_conversation=conversation) + return None # No content response \ No newline at end of file diff --git a/backend/app/core/config.py b/backend/app/core/config.py index d58e03c87d..a195c1cf00 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -51,23 +51,46 @@ def all_cors_origins(self) -> list[str]: PROJECT_NAME: str SENTRY_DSN: HttpUrl | None = None - POSTGRES_SERVER: str - POSTGRES_PORT: int = 5432 - POSTGRES_USER: str - POSTGRES_PASSWORD: str = "" - POSTGRES_DB: str = "" + + # Add DATABASE_URL as an optional field + DATABASE_URL: PostgresDsn | None = None + + # Make individual PG fields optional + POSTGRES_SERVER: str | None = None + POSTGRES_PORT: int | None = 5432 + POSTGRES_USER: str | None = None + POSTGRES_PASSWORD: str | None = None + POSTGRES_DB: str | None = None @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: - return MultiHostUrl.build( - scheme="postgresql+psycopg", - username=self.POSTGRES_USER, - password=self.POSTGRES_PASSWORD, - host=self.POSTGRES_SERVER, - port=self.POSTGRES_PORT, - path=self.POSTGRES_DB, - ) + if self.DATABASE_URL: + print("Using DATABASE_URL from environment") + return self.DATABASE_URL + elif ( + self.POSTGRES_SERVER + and self.POSTGRES_USER + # Password can be None or empty + and self.POSTGRES_DB + and self.POSTGRES_PORT + ): + print("Using individual POSTGRES variables from environment") + # Ensure password is a string, even if None was provided + pg_password = self.POSTGRES_PASSWORD or "" + return MultiHostUrl.build( + scheme="postgresql+psycopg", + username=self.POSTGRES_USER, + password=pg_password, + host=self.POSTGRES_SERVER, + port=self.POSTGRES_PORT, + path=self.POSTGRES_DB, + ) + else: + raise ValueError( + "Database configuration error: Set DATABASE_URL or all individual " + "POSTGRES_SERVER, POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_DB, POSTGRES_PORT variables." + ) SMTP_TLS: bool = True SMTP_SSL: bool = False @@ -109,7 +132,9 @@ def _check_default_secret(self, var_name: str, value: str | None) -> None: @model_validator(mode="after") def _enforce_non_default_secrets(self) -> Self: self._check_default_secret("SECRET_KEY", self.SECRET_KEY) - self._check_default_secret("POSTGRES_PASSWORD", self.POSTGRES_PASSWORD) + # Only check PG password if individual variables are used + if not self.DATABASE_URL and self.POSTGRES_PASSWORD: + self._check_default_secret("POSTGRES_PASSWORD", self.POSTGRES_PASSWORD) self._check_default_secret( "FIRST_SUPERUSER_PASSWORD", self.FIRST_SUPERUSER_PASSWORD ) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py new file mode 100644 index 0000000000..1f254ed39f --- /dev/null +++ b/backend/app/crud/__init__.py @@ -0,0 +1,24 @@ +from . import characters +from . import conversations + +# Import functions from the original crud file (now named base.py) +from .base import ( + create_user, + update_user, + get_user_by_email, + authenticate, + create_item, +) + +# Now, when other modules do 'from app import crud', +# they can access crud.create_user, crud.authenticate, +# crud.characters.create_character, etc. + +# You can also import specific functions if preferred, e.g.: +# from .users import get_user_by_email, create_user, authenticate, update_user # Assuming users crud is also needed +# from .items import create_item # Assuming items crud is also needed + +# We need to import the existing user and item crud functions as well +# Check where user/item crud functions are defined (likely crud.py initially) +# Let's assume they are in a top-level crud.py that needs to be refactored or imported here +# For now, just importing the new modules: \ No newline at end of file diff --git a/backend/app/crud/base.py b/backend/app/crud/base.py new file mode 100644 index 0000000000..7afced6226 --- /dev/null +++ b/backend/app/crud/base.py @@ -0,0 +1,54 @@ +import uuid +from typing import Any + +from sqlmodel import Session, select + +from app.core.security import get_password_hash, verify_password +from app.models import Item, ItemCreate, User, UserCreate, UserUpdate + + +def create_user(*, session: Session, user_create: UserCreate) -> User: + db_obj = User.model_validate( + user_create, update={"hashed_password": get_password_hash(user_create.password)} + ) + session.add(db_obj) + session.commit() + session.refresh(db_obj) + return db_obj + + +def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any: + user_data = user_in.model_dump(exclude_unset=True) + extra_data = {} + if "password" in user_data: + password = user_data["password"] + hashed_password = get_password_hash(password) + extra_data["hashed_password"] = hashed_password + db_user.sqlmodel_update(user_data, update=extra_data) + session.add(db_user) + session.commit() + session.refresh(db_user) + return db_user + + +def get_user_by_email(*, session: Session, email: str) -> User | None: + statement = select(User).where(User.email == email) + session_user = session.exec(statement).first() + return session_user + + +def authenticate(*, session: Session, email: str, password: str) -> User | None: + db_user = get_user_by_email(session=session, email=email) + if not db_user: + return None + if not verify_password(password, db_user.hashed_password): + return None + return db_user + + +def create_item(*, session: Session, item_in: ItemCreate, owner_id: uuid.UUID) -> Item: + db_item = Item.model_validate(item_in, update={"owner_id": owner_id}) + session.add(db_item) + session.commit() + session.refresh(db_item) + return db_item \ No newline at end of file diff --git a/backend/app/crud/characters.py b/backend/app/crud/characters.py new file mode 100644 index 0000000000..4ba2a644a7 --- /dev/null +++ b/backend/app/crud/characters.py @@ -0,0 +1,81 @@ +# Placeholder for character CRUD operations + +import uuid +from typing import Sequence + +from sqlmodel import Session, select, col, func + +from app.models import Character, CharacterCreate, CharacterUpdate, CharacterStatus, User + + +def create_character( + *, session: Session, character_create: CharacterCreate, creator_id: uuid.UUID +) -> Character: + """Creates a new character with status 'pending'.""" + db_obj = Character.model_validate( + character_create, update={"creator_id": creator_id, "status": CharacterStatus.PENDING} + ) + session.add(db_obj) + session.commit() + session.refresh(db_obj) + return db_obj + + +def get_character(*, session: Session, character_id: uuid.UUID) -> Character | None: + """Gets a single character by its ID.""" + return session.get(Character, character_id) + + +def get_characters( + *, + session: Session, + skip: int = 0, + limit: int = 100, + status: CharacterStatus | None = None, + creator_id: uuid.UUID | None = None, +) -> Sequence[Character]: + """Gets a list of characters with optional filters.""" + statement = select(Character).offset(skip).limit(limit) + if status is not None: + statement = statement.where(Character.status == status) + if creator_id is not None: + statement = statement.where(Character.creator_id == creator_id) + + characters = session.exec(statement).all() + return characters + + +def get_characters_count( + *, + session: Session, + status: CharacterStatus | None = None, + creator_id: uuid.UUID | None = None, +) -> int: + """Gets the total count of characters with optional filters.""" + statement = select(func.count(Character.id)) + if status is not None: + statement = statement.where(Character.status == status) + if creator_id is not None: + statement = statement.where(Character.creator_id == creator_id) + + count = session.exec(statement).one() + return count + + +def update_character( + *, session: Session, db_character: Character, character_in: CharacterUpdate +) -> Character: + """Updates an existing character.""" + update_data = character_in.model_dump(exclude_unset=True) + db_character.sqlmodel_update(update_data) + session.add(db_character) + session.commit() + session.refresh(db_character) + return db_character + + +def delete_character(*, session: Session, db_character: Character) -> None: + """Deletes a character.""" + # Add cascade delete for conversations/messages if needed, handled by relationship cascade + session.delete(db_character) + session.commit() \ No newline at end of file diff --git a/backend/app/crud/conversations.py b/backend/app/crud/conversations.py new file mode 100644 index 0000000000..e321c2aa49 --- /dev/null +++ b/backend/app/crud/conversations.py @@ -0,0 +1,108 @@ +# Placeholder for conversation and message CRUD operations + +import uuid +from typing import Sequence + +from sqlmodel import Session, select, func + +from app.models import ( + Conversation, ConversationCreate, User, Character, + Message, MessageCreate, MessageSender +) + + +# --- Conversation CRUD --- + +def create_conversation( + *, session: Session, conversation_create: ConversationCreate, user_id: uuid.UUID +) -> Conversation: + """Creates a new conversation between a user and a character.""" + # Validate if character exists (optional, could be done at API level too) + character = session.get(Character, conversation_create.character_id) + if not character: + raise ValueError("Character not found") # Or handle appropriately + + db_obj = Conversation.model_validate( + conversation_create, update={"user_id": user_id} + ) + session.add(db_obj) + session.commit() + session.refresh(db_obj) + return db_obj + + +def get_conversation(*, session: Session, conversation_id: uuid.UUID) -> Conversation | None: + """Gets a single conversation by its ID.""" + return session.get(Conversation, conversation_id) + + +def get_user_conversations( + *, session: Session, user_id: uuid.UUID, skip: int = 0, limit: int = 100 +) -> Sequence[Conversation]: + """Gets a list of conversations for a specific user.""" + statement = ( + select(Conversation) + .where(Conversation.user_id == user_id) + .offset(skip) + .limit(limit) + ) + conversations = session.exec(statement).all() + return conversations + + +def get_user_conversations_count(*, session: Session, user_id: uuid.UUID) -> int: + """Gets the total count of conversations for a specific user.""" + statement = select(func.count(Conversation.id)).where( + Conversation.user_id == user_id + ) + count = session.exec(statement).one() + return count + + +def delete_conversation(*, session: Session, db_conversation: Conversation) -> None: + """Deletes a conversation and its associated messages (via cascade).""" + session.delete(db_conversation) + session.commit() + + +# --- Message CRUD --- + +def create_message( + *, + session: Session, + message_create: MessageCreate, + conversation_id: uuid.UUID, + sender: MessageSender, +) -> Message: + """Adds a new message to a conversation.""" + db_obj = Message.model_validate( + message_create, update={"conversation_id": conversation_id, "sender": sender} + ) + session.add(db_obj) + session.commit() + session.refresh(db_obj) + return db_obj + + +def get_conversation_messages( + *, session: Session, conversation_id: uuid.UUID, skip: int = 0, limit: int = 1000 # Usually get more messages +) -> Sequence[Message]: + """Gets messages for a specific conversation, ordered by timestamp.""" + statement = ( + select(Message) + .where(Message.conversation_id == conversation_id) + .order_by(Message.timestamp) + .offset(skip) + .limit(limit) + ) + messages = session.exec(statement).all() + return messages + + +def get_conversation_messages_count(*, session: Session, conversation_id: uuid.UUID) -> int: + """Gets the total count of messages for a specific conversation.""" + statement = select(func.count(Message.id)).where( + Message.conversation_id == conversation_id + ) + count = session.exec(statement).one() + return count \ No newline at end of file diff --git a/backend/app/models.py b/backend/app/models.py index 2389b4a532..8454adf7b2 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,4 +1,6 @@ import uuid +import datetime +from enum import Enum from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -44,6 +46,12 @@ class User(UserBase, table=True): id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) hashed_password: str items: list["Item"] = Relationship(back_populates="owner", cascade_delete=True) + created_characters: list["Character"] = Relationship( + back_populates="creator", cascade_delete=True + ) + conversations: list["Conversation"] = Relationship( + back_populates="user", cascade_delete=True + ) # Properties to return via API, id is always required @@ -111,3 +119,156 @@ class TokenPayload(SQLModel): class NewPassword(SQLModel): token: str new_password: str = Field(min_length=8, max_length=40) + +# ------------------------Character Model---------------------------- + +# Shared properties +class CharacterBase(SQLModel): + name: str = Field(index=True, max_length=100) + description: str | None = Field(default=None, max_length=1000) + image_url: str | None = Field(default=None, max_length=255) + greeting_message: str | None = Field(default=None, max_length=1000) +# More field + scenario: str | None = Field(default=None, max_length=2000) + category: str | None = Field(default=None, max_length=255) + greeting: str | None = Field(default=None, max_length=1000) + + voice_id: str | None = Field(default=None, max_length=255) + language: str | None = Field(default=None, max_length=50) + tags: list[str] | None = Field(default=None) + popularity_score: float | None = Field(default=None) + is_featured: bool = Field(default=False) + + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class CharacterStatus(str, Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +# Properties to receive via API on creation (user submission) +class CharacterCreate(CharacterBase): + pass + + +# Properties to receive via API on update (admin only) +class CharacterUpdate(CharacterBase): + name: str | None = Field(default=None, max_length=100) + description: str | None = Field(default=None, max_length=1000) + image_url: str | None = Field(default=None, max_length=255) + greeting_message: str | None = Field(default=None, max_length=1000) +# More fields + scenario: str | None = Field(default=None, max_length=2000) + greeting: str | None = Field(default=None, max_length=1000) + category: str | None = Field(default=None, max_length=255) + voice_id: str | None = Field(default=None, max_length=255) + language: str | None = Field(default=None, max_length=50) + tags: list[str] | None = Field(default=None) + + is_featured: bool = Field(default=False) + created_at: datetime | None = None + status: CharacterStatus | None = None + + +# Database model +class Character(CharacterBase, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + + status: CharacterStatus = Field(default=CharacterStatus.PENDING) + + total_messages: int = Field(default=0) + + + creator_id: uuid.UUID = Field(foreign_key="user.id", nullable=False) + creator: User = Relationship(back_populates="created_characters") + conversations: list["Conversation"] = Relationship(back_populates="character") + +# Properties to return via API +class CharacterPublic(CharacterBase): + id: uuid.UUID + status: CharacterStatus + creator_id: uuid.UUID + + + total_messages: int + created_at: datetime + + +class CharactersPublic(SQLModel): + data: list[CharacterPublic] + count: int + + +# ---------------- Conversation Models ---------------- + + +class ConversationBase(SQLModel): + pass # No shared fields initially, maybe add title later? + + +class ConversationCreate(SQLModel): + character_id: uuid.UUID + + +# Database model +class Conversation(ConversationBase, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + user_id: uuid.UUID = Field(foreign_key="user.id", nullable=False) + character_id: uuid.UUID = Field(foreign_key="character.id", nullable=False) + created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + + user: User = Relationship(back_populates="conversations") + character: Character = Relationship(back_populates="conversations") + messages: list["Message"] = Relationship(back_populates="conversation") + + +class ConversationPublic(ConversationBase): + id: uuid.UUID + user_id: uuid.UUID + character_id: uuid.UUID + created_at: datetime.datetime + + +class ConversationsPublic(SQLModel): + data: list[ConversationPublic] + count: int + + +# ---------------- Message Models ---------------- + + +class MessageSender(str, Enum): + USER = "user" + AI = "ai" + + +class MessageBase(SQLModel): + content: str = Field(max_length=5000) # Limit message length + + +class MessageCreate(MessageBase): + pass # Content is the main input + + +# Database model +class Message(MessageBase, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + conversation_id: uuid.UUID = Field(foreign_key="conversation.id", nullable=False) + sender: MessageSender + timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + + conversation: Conversation = Relationship(back_populates="messages") + + +class MessagePublic(MessageBase): + id: uuid.UUID + conversation_id: uuid.UUID + sender: MessageSender + timestamp: datetime.datetime + + +class MessagesPublic(SQLModel): + data: list[MessagePublic] + count: int diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000000..0519ecba6e --- /dev/null +++ b/backend/app/services/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/backend/app/services/ai_service.py b/backend/app/services/ai_service.py new file mode 100644 index 0000000000..773347b396 --- /dev/null +++ b/backend/app/services/ai_service.py @@ -0,0 +1,50 @@ +import os +from typing import Optional + +from gemini_provider import GeminiAIProvider +from openai_provider import OpenAIProvider +from base import AIProvider +def get_active_provider() -> Optional[AIProvider]: + """ + Trả về một AIProvider đã khởi tạo, dựa trên cấu hình môi trường. + Ưu tiên OpenAI nếu có, fallback về Gemini nếu không. + """ + gemini_key = os.getenv("GEMINI_API_KEY") + if gemini_key: + print("[AI Service] Using GeminiAIProvider.") + return GeminiAIProvider(api_key=gemini_key, model_name="gemini-2.5-flash") + + + openai_key = os.getenv("OPENAI_API_KEY") + if openai_key: + print("[AI Service] Using OpenAIProvider.") + return OpenAIProvider(model_name="gpt-4o") + print("[AI Service] No valid AI provider found.") + return None +def get_ai_response( + *, character: Character, history: list[Message] +) -> str: + """ + Gọi phản hồi AI dựa trên character và lịch sử hội thoại. + + Args: + character: Đối tượng nhân vật (Character) dùng để cung cấp system_prompt. + history: Danh sách các tin nhắn lịch sử, bao gồm user message cuối. + + Returns: + Chuỗi phản hồi từ AI. + """ + provider = get_active_provider() + + if not provider: + return "AI service is not available due to missing API configuration." + + print(f"\n[AI Service] Character: {character.name}") + print(f"[AI Service] Provider: {type(provider).__name__}") + + try: + response = provider.get_completion(character=character, history=history) + return response + except Exception as e: + print(f"[AI Service] Error during response generation: {e}") + return "An error occurred while generating the AI response." \ No newline at end of file diff --git a/backend/app/services/gemini_provider.py b/backend/app/services/gemini_provider.py new file mode 100644 index 0000000000..6b725694a5 --- /dev/null +++ b/backend/app/services/gemini_provider.py @@ -0,0 +1,71 @@ + +import os +import google.generativeai as genai +from typing import Sequence, Any + +from .ai_service import AIProvider, Character, Message + + +class GeminiAIProvider: + def __init__(self, api_key: str, model_name: str = 'gemini-2.5-flash'): + genai.configure(api_key=api_key) + self.model = genai.GenerativeModel(model_name) + self.model_name = model_name + + def _format_history_for_gemini(self, history: Sequence[Message]) -> list[dict[str, Any]]: + gemini_history = [] + for message in history: + role = 'user' if message.sender == 'user' else 'model' + gemini_history.append({'role': role, 'parts': [message.content]}) + return gemini_history + + def get_completion(self, character: Character, history: Sequence[Message]) -> str: + if not history: + return "Error: No conversation history." + + last_message = history[-1] + + if last_message.sender != 'user': + print(f"Warning: Gemini provider expects last message from user, but got {last_message.sender}.") + last_user_message_content = None + context_history_for_api = [] + found_last_user = False + + for msg in reversed(history): + if msg.sender == 'user' and not found_last_user: + last_user_message_content = msg.content + found_last_user = True + elif found_last_user: + role = 'user' if msg.sender == 'user' else 'model' + context_history_for_api.append({'role': role, 'parts': [msg.content]}) + + if not found_last_user or last_user_message_content is None: + return "Error: No user message found in history for Gemini." + + context_history_for_api.reverse() + + formatted_context_history = context_history_for_api + current_user_prompt = last_user_message_content + + else: + formatted_context_history = self._format_history_for_gemini(history[:-1]) + current_user_prompt = last_message.content + + + system_instruction = character.system_prompt + + try: + chat = self.model.start_chat(history=formatted_context_history, + system_instruction=system_instruction) + response = chat.send_message(current_user_prompt) + ai_response_text = response.text + + if not ai_response_text: + print("Warning: Gemini API returned no text.") + return "I cannot generate a response for that request." + + return ai_response_text + + except Exception as e: + print(f"Error calling Gemini API: {e}") + return "Sorry, I encountered an error with the AI service." \ No newline at end of file diff --git a/backend/app/services/openai_provider.py b/backend/app/services/openai_provider.py new file mode 100644 index 0000000000..7f611d8f3a --- /dev/null +++ b/backend/app/services/openai_provider.py @@ -0,0 +1,46 @@ + +import os +from openai import OpenAI +from typing import Sequence, Any + +from .ai_service import AIProvider, Character, Message + + +class OpenAIProvider: + def __init__(self, api_key: str, model_name: str = 'gpt-4.5-turbo'): + self.client = OpenAI(api_key=api_key) + self.model_name = model_name + + def _format_history_for_openai(self, character: Character, history: Sequence[Message]) -> list[dict[str, str]]: + openai_history = [] + if character.system_prompt: + openai_history.append({"role": "system", "content": character.system_prompt}) + + for message in history: + role = 'user' if message.sender == 'user' else 'assistant' + openai_history.append({"role": role, "content": message.content}) + + return openai_history + + def get_completion(self, character: Character, history: Sequence[Message]) -> str: + if not history: + return "Error: No conversation history provided." + messages = self._format_history_for_openai(character, history) + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages + ) + ai_response_text = response.choices[0].message.content + + if not ai_response_text: + print("Warning: OpenAI API returned no text.") + return "I cannot generate a response for that request." + + + return ai_response_text + + except Exception as e: + print(f"Error calling OpenAI API: {e}") + return "Sorry, I encountered an erro." \ No newline at end of file diff --git a/backend/docker-entrypoint.sh b/backend/docker-entrypoint.sh new file mode 100755 index 0000000000..c814eb0dab --- /dev/null +++ b/backend/docker-entrypoint.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +# Exit immediately if a command exits with a non-zero status. +set -e + +# Run database migrations +echo "Running database migrations..." +alembic upgrade head + +echo "Migrations finished. Starting server..." +# Execute the command passed as arguments to this script (which will be the Docker CMD) +# Or, directly execute the intended Uvicorn command if CMD is removed/changed +exec uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-8000} --workers 4 \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 1c77b83ded..08b776c378 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "alembic<2.0.0,>=1.12.1", "httpx<1.0.0,>=0.25.1", "psycopg[binary]<4.0.0,>=3.1.13", + "psycopg2-binary", "sqlmodel<1.0.0,>=0.0.21", # Pin bcrypt until passlib supports the latest "bcrypt==4.0.1",