"""
WebSocket consumer for AI chat.
"""

import json
import logging
from typing import Any, Dict, List, Optional

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer

from core.chat.service import ChatService
from core.models import TeamProfile
from core.models.chat import ChatConversation


# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(name=__name__)


class ChatConsumer(AsyncWebsocketConsumer):
    """
    WebSocket consumer for AI chat with Claude.

    Handles:
    - Connection authentication (via OryWebSocketAuthMiddleware)
    - Message streaming
    - Conversation history

    Protocol: the client sends JSON objects whose ``type`` is one of
    ``"chat"``, ``"history"``, ``"conversations"`` or ``"new_conversation"``.
    Every server message is a JSON object with a ``type`` key; failures are
    reported as ``{"type": "error", "error": <message>}``.
    """

    # Close codes sent when a connection is refused (4000-4999 is the
    # application-defined range of WebSocket close codes).
    CLOSE_UNAUTHENTICATED = 4401
    CLOSE_FORBIDDEN = 4403

    # Maximum number of conversations returned for a "conversations" request.
    MAX_LISTED_CONVERSATIONS = 50

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both are populated by connect() only once the connection has been
        # authorised; receive() treats a missing chat_service as "refused".
        self.profile: Optional[TeamProfile] = None
        self.chat_service: Optional[ChatService] = None

    async def connect(self):
        """Authenticate the connection, accept it and send the greeting.

        The profile is read from ``scope['profile']``, which is expected to
        be set by OryWebSocketAuthMiddleware. A missing profile is refused
        with ``CLOSE_UNAUTHENTICATED``; a profile that is not a
        ``TeamProfile`` is refused with ``CLOSE_FORBIDDEN``. On success a
        ``connected`` event and an ``intro`` event are sent.
        """
        profile = self.scope.get('profile')

        if not profile:
            logger.warning("Chat connection rejected - no profile")
            await self._reject(self.CLOSE_UNAUTHENTICATED)
            return

        # Only allow team profiles
        if not isinstance(profile, TeamProfile):
            logger.warning("Chat connection rejected - not a team profile")
            await self._reject(self.CLOSE_FORBIDDEN)
            return

        self.profile = profile
        self.chat_service = ChatService(profile)

        await self.accept()

        # Send welcome message
        await self.send_json({
            "type": "connected",
            "user": {
                "id": str(profile.id),
                "name": self._display_name(),
                "email": profile.email,
            }
        })

        # Send role-based intro message
        await self.send_json({
            "type": "intro",
            "content": self._get_intro_message()
        })

    async def _reject(self, code: int) -> None:
        """Refuse the connection so that the client actually receives ``code``.

        Per the ASGI WebSocket spec, a close sent before the handshake is
        accepted is turned into an HTTP 403 and the close code is lost
        (browsers typically report 1006). Accepting first and then closing
        delivers the application-defined code.
        """
        await self.accept()
        await self.close(code=code)

    def _display_name(self) -> str:
        """Return "first last" for the profile, skipping empty or missing parts."""
        name_parts = (self.profile.first_name, self.profile.last_name)
        return " ".join(part for part in name_parts if part).strip()

    def _get_intro_message(self) -> str:
        """Get intro message based on user role.

        ``'ADMIN'`` and ``'TEAM_LEADER'`` get dedicated messages; any other
        role (or no ``role`` attribute at all) gets the team-member message.
        """
        first_name = self.profile.first_name or "there"
        role = getattr(self.profile, 'role', None)

        if role == 'ADMIN':
            return (
                f"Hey {first_name}! I'm your Nexus assistant. As an admin, I can help you with:\n\n"
                "• **View & manage** all services, projects, and team assignments\n"
                "• **Create & schedule** new services and projects\n"
                "• **Access reports** and system statistics\n"
                "• **Manage notifications** and team settings\n\n"
                "What would you like to do today?"
            )
        elif role == 'TEAM_LEADER':
            return (
                f"Hey {first_name}! I'm your Nexus assistant. As a team leader, I can help you with:\n\n"
                "• **View schedules** for you and your team\n"
                "• **Check service & project details** across accounts\n"
                "• **Track work sessions** and task completion\n"
                "• **Access customer and account information**\n\n"
                "What can I help you with?"
            )
        else:  # TEAM_MEMBER
            return (
                f"Hey {first_name}! I'm your Nexus assistant. I can help you with:\n\n"
                "• **View your schedule** and assigned work\n"
                "• **Check service & project details** for your assignments\n"
                "• **Manage work sessions** and mark tasks complete\n"
                "• **Track your notifications**\n\n"
                "What do you need help with?"
            )

    async def disconnect(self, close_code):
        """Handle WebSocket disconnection by logging the close code."""
        logger.info("Chat disconnected: %s", close_code)

    async def receive(self, text_data=None, bytes_data=None):
        """Parse an incoming frame and dispatch it by its ``type`` field.

        Args:
            text_data: Payload of a text frame; must be a JSON object.
            bytes_data: Payload of a binary frame. Channels passes binary
                frames through this keyword; they are not part of this
                protocol and are answered with an error.
        """
        if self.chat_service is None:
            # The connection was refused in connect(); ignore any frame that
            # arrived before the close completed.
            return

        if text_data is None:
            await self.send_json({"type": "error", "error": "Binary messages are not supported"})
            return

        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            await self.send_json({"type": "error", "error": "Invalid JSON"})
            return

        # Valid JSON can still be a list, number or string; every handler
        # below expects a mapping.
        if not isinstance(data, dict):
            await self.send_json({"type": "error", "error": "Message must be a JSON object"})
            return

        message_type = data.get("type")

        if message_type == "chat":
            await self.handle_chat(data)
        elif message_type == "history":
            await self.handle_history(data)
        elif message_type == "conversations":
            await self.handle_list_conversations()
        elif message_type == "new_conversation":
            await self.handle_new_conversation()
        else:
            await self.send_json({"type": "error", "error": f"Unknown message type: {message_type}"})

    @staticmethod
    def _conversation_summary(conversation) -> Dict[str, Any]:
        """Serialize the id/title/created_at fields shared by every conversation event."""
        return {
            "id": str(conversation.id),
            "title": conversation.title or "New Conversation",
            "created_at": conversation.created_at.isoformat(),
        }

    async def handle_chat(self, data):
        """Handle a chat message and stream the assistant's reply.

        Expects ``data["content"]`` to be a non-blank string and optionally
        ``data["conversation_id"]``. When no conversation id is supplied, a
        ``conversation_created`` event is sent before the streamed events.
        """
        raw_content = data.get("content")
        # Missing, null or non-string content is treated as empty instead of
        # crashing on .strip().
        content = raw_content.strip() if isinstance(raw_content, str) else ""
        conversation_id = data.get("conversation_id")

        if not content:
            await self.send_json({"type": "error", "error": "Message content is required"})
            return

        try:
            # Get or create conversation
            conversation = await self.chat_service.get_or_create_conversation(conversation_id)

            # If new conversation, send conversation_created event
            if not conversation_id:
                await self.send_json({
                    "type": "conversation_created",
                    "conversation": self._conversation_summary(conversation),
                })

            # Stream response events straight through to the client.
            async for event in self.chat_service.stream_response(conversation, content):
                await self.send_json(event)

        except Exception as e:
            # NOTE(review): str(e) is forwarded to the client verbatim, which
            # may expose internal error details; confirm this is intended.
            logger.exception("Error handling chat message")
            await self.send_json({"type": "error", "error": str(e)})

    @database_sync_to_async
    def _load_conversation_history(self, conversation_id) -> Optional[Dict[str, Any]]:
        """Return an active conversation of this profile with its messages, or None."""
        try:
            conv = ChatConversation.objects.get(
                id=conversation_id,
                team_profile=self.profile,
                is_active=True,
            )
        except ChatConversation.DoesNotExist:
            return None

        # One ordered query. A prefetch_related('messages') would be thrown
        # away here, because chaining .order_by() issues a fresh query.
        messages = conv.messages.order_by('created_at')
        return {
            **self._conversation_summary(conv),
            "messages": [
                {
                    "id": str(msg.id),
                    "role": msg.role,
                    "content": msg.content,
                    "tool_calls": msg.tool_calls,
                    "tool_results": msg.tool_results,
                    "created_at": msg.created_at.isoformat(),
                }
                for msg in messages
            ],
        }

    async def handle_history(self, data):
        """Send the full message history of one of the user's active conversations."""
        conversation_id = data.get("conversation_id")

        if not conversation_id:
            await self.send_json({"type": "error", "error": "conversation_id is required"})
            return

        try:
            conversation = await self._load_conversation_history(conversation_id)

            if conversation:
                await self.send_json({
                    "type": "history",
                    "conversation": conversation
                })
            else:
                await self.send_json({"type": "error", "error": "Conversation not found"})

        except Exception as e:
            logger.exception("Error fetching history")
            await self.send_json({"type": "error", "error": str(e)})

    @database_sync_to_async
    def _load_conversation_list(self) -> List[Dict[str, Any]]:
        """Return the profile's most recently updated active conversations."""
        conversations = ChatConversation.objects.filter(
            team_profile=self.profile,
            is_active=True,
        ).order_by('-updated_at')[:self.MAX_LISTED_CONVERSATIONS]

        return [
            {
                **self._conversation_summary(conv),
                "updated_at": conv.updated_at.isoformat(),
            }
            for conv in conversations
        ]

    async def handle_list_conversations(self):
        """Handle request to list the user's conversations (newest first)."""
        try:
            conversations = await self._load_conversation_list()

            await self.send_json({
                "type": "conversations",
                "conversations": conversations
            })

        except Exception as e:
            logger.exception("Error listing conversations")
            await self.send_json({"type": "error", "error": str(e)})

    async def handle_new_conversation(self):
        """Handle request to create a new conversation and announce it."""
        try:
            conversation = await self.chat_service.get_or_create_conversation()

            await self.send_json({
                "type": "conversation_created",
                "conversation": self._conversation_summary(conversation),
            })

        except Exception as e:
            logger.exception("Error creating conversation")
            await self.send_json({"type": "error", "error": str(e)})

    async def send_json(self, data):
        """Serialize ``data`` with ``json.dumps`` and send it as a text frame."""
        await self.send(text_data=json.dumps(data))