328 lines
13 KiB
Python
328 lines
13 KiB
Python
import calendar
|
||
import datetime
|
||
from typing import List, cast
|
||
from uuid import UUID
|
||
import strawberry
|
||
from strawberry.types import Info
|
||
from asgiref.sync import sync_to_async
|
||
from channels.db import database_sync_to_async
|
||
from django.db import transaction
|
||
from core.graphql.inputs.service import ServiceInput, ServiceUpdateInput, ServiceGenerationInput
|
||
from core.graphql.pubsub import pubsub
|
||
from core.graphql.types.service import ServiceType
|
||
from core.graphql.utils import create_object, update_object, delete_object, _is_holiday
|
||
from core.models.account import AccountAddress
|
||
from core.models.profile import TeamProfile
|
||
from core.models.schedule import Schedule
|
||
from core.models.service import Service
|
||
from core.services.events import (
|
||
publish_service_created, publish_service_deleted,
|
||
publish_service_status_changed, publish_service_completed, publish_service_cancelled,
|
||
publish_services_bulk_generated, publish_service_dispatched,
|
||
)
|
||
|
||
|
||
async def _get_admin_profile():
    """Return the first TeamProfile whose role is 'ADMIN', or None if there is none."""
    admin_profiles = TeamProfile.objects.filter(role='ADMIN')
    # afirst() is the async counterpart of first(); it evaluates the query off the event loop.
    return await admin_profiles.afirst()
|
||
|
||
|
||
# Helper to check if admin is in team member IDs (handles GlobalID objects)
|
||
def _admin_in_team_members(admin_id, team_member_ids):
|
||
if not team_member_ids or not admin_id:
|
||
return False
|
||
# team_member_ids may be GlobalID objects with .node_id attribute
|
||
member_uuids = []
|
||
for mid in team_member_ids:
|
||
if hasattr(mid, 'node_id'):
|
||
member_uuids.append(str(mid.node_id))
|
||
else:
|
||
member_uuids.append(str(mid))
|
||
return str(admin_id) in member_uuids
|
||
|
||
|
||
async def _get_old_team_member_ids(instance):
    """Return the ids of ``instance``'s current team members as a set of strings.

    Only the ``id`` column is fetched (``values_list(..., flat=True)``) rather
    than whole related rows, since callers only compare ids.
    """
    return await sync_to_async(
        lambda: {str(member_id) for member_id in instance.team_members.values_list('id', flat=True)}
    )()
|
||
|
||
|
||
async def _fetch_account_address(account_address_id):
    """Load an AccountAddress with its ``account`` relation joined.

    Returns None without querying when ``account_address_id`` is falsy.
    Propagates ``AccountAddress.DoesNotExist`` for an id with no matching row.
    """
    if not account_address_id:
        return None
    return await AccountAddress.objects.select_related('account').aget(id=account_address_id)


def _account_name(account_address):
    """Return the name of the address's account, or None if either is missing."""
    if account_address is None or not account_address.account:
        return None
    return account_address.account.name


async def _publish_service_dispatched(instance, profile, account_address):
    """Publish the dispatched-service notification event for ``instance``.

    ``account_address`` is the row already loaded for this service via
    ``_fetch_account_address`` (None when the service has no address);
    passing it in avoids querying the same row again.
    """
    account_address_id = str(instance.account_address_id) if instance.account_address_id else None
    await publish_service_dispatched(
        service_id=str(instance.id),
        triggered_by=profile,
        metadata={
            'service_id': str(instance.id),
            'account_address_id': account_address_id,
            'account_name': _account_name(account_address),
            'date': str(instance.date),
            'status': instance.status
        }
    )


@strawberry.type
class Mutation:
    @strawberry.mutation(description="Create a new service visit")
    async def create_service(self, input: ServiceInput, info: Info) -> ServiceType:
        """Create a service visit and publish its realtime/notification events.

        Also publishes a dispatched event when the admin profile is among
        ``input.team_member_ids``.
        """
        # Exclude m2m id fields from model constructor
        payload = {k: v for k, v in input.__dict__.items() if k not in {"team_member_ids"}}
        m2m_data = {"team_members": input.team_member_ids}
        instance = await create_object(payload, Service, m2m_data)
        await pubsub.publish("service_created", instance.id)

        # Publish event for notifications
        profile = getattr(info.context.request, 'profile', None)
        # Load the address (with its account) once; both events below use it.
        account_address = await _fetch_account_address(instance.account_address_id)
        account_id = (
            str(account_address.account_id)
            if account_address is not None and account_address.account_id
            else None
        )
        await publish_service_created(
            service_id=str(instance.id),
            triggered_by=profile,
            metadata={
                'account_id': account_id,
                'date': str(instance.date),
                'status': instance.status
            }
        )

        # Admin among the team members means the service was dispatched.
        admin = await _get_admin_profile()
        if admin and _admin_in_team_members(admin.id, input.team_member_ids):
            await _publish_service_dispatched(instance, profile, account_address)

        return cast(ServiceType, instance)

    @strawberry.mutation(description="Update an existing service visit")
    async def update_service(self, input: ServiceUpdateInput, info: Info) -> ServiceType:
        """Update a service visit and publish the matching notification events.

        Emits a completed/cancelled/status-changed event when ``input.status``
        differs from the stored status, and a dispatched event when the admin
        profile was newly added to the team members.
        """
        # Get old service data for comparison
        old_service = await database_sync_to_async(Service.objects.get)(pk=input.id.node_id)
        old_status = old_service.status

        # Get old team member IDs before update (for dispatched detection)
        old_team_member_ids = await _get_old_team_member_ids(old_service)

        # Read once so the m2m update and the dispatch check see the same value.
        team_member_ids = getattr(input, "team_member_ids", None)

        # Keep id and non-m2m fields; drop m2m *_ids from the update payload
        payload = {k: v for k, v in input.__dict__.items() if k not in {"team_member_ids"}}
        m2m_data = {"team_members": team_member_ids}
        instance = await update_object(payload, Service, m2m_data)
        await pubsub.publish("service_updated", instance.id)

        # Publish events for notifications
        profile = getattr(info.context.request, 'profile', None)
        account_address = None  # loaded lazily, at most once per request

        # NOTE(review): compares the raw input value with the stored status;
        # if input.status is an enum this relies on it comparing equal to the
        # stored value — confirm.
        if hasattr(input, 'status') and input.status and input.status != old_status:
            account_address = await _fetch_account_address(instance.account_address_id)
            account_name = _account_name(account_address)

            if instance.status == 'COMPLETED':
                await publish_service_completed(
                    service_id=str(instance.id),
                    triggered_by=profile,
                    metadata={
                        'date': str(instance.date),
                        'account_name': account_name
                    }
                )
            elif instance.status == 'CANCELLED':
                await publish_service_cancelled(
                    service_id=str(instance.id),
                    triggered_by=profile,
                    metadata={
                        'date': str(instance.date),
                        'account_name': account_name
                    }
                )
            else:
                await publish_service_status_changed(
                    service_id=str(instance.id),
                    old_status=old_status,
                    new_status=instance.status,
                    triggered_by=profile
                )

        # Check if admin was newly added (dispatched)
        if team_member_ids is not None:
            admin = await _get_admin_profile()
            if admin:
                admin_was_in_old = str(admin.id) in old_team_member_ids
                admin_in_new = _admin_in_team_members(admin.id, team_member_ids)
                if not admin_was_in_old and admin_in_new:
                    # Reuse the address loaded for the status event, if any.
                    if account_address is None:
                        account_address = await _fetch_account_address(instance.account_address_id)
                    await _publish_service_dispatched(instance, profile, account_address)

        return cast(ServiceType, instance)

    @strawberry.mutation(description="Delete an existing service visit")
    async def delete_service(self, id: strawberry.ID, info: Info) -> strawberry.ID:
        """Delete a service visit, publish deletion events, and return its id.

        Raises ValueError when ``delete_object`` returns nothing for ``id``.
        """
        instance = await delete_object(id, Service)
        if not instance:
            raise ValueError(f"Service with ID {id} does not exist")

        await pubsub.publish("service_deleted", id)

        # Publish event for notifications
        profile = getattr(info.context.request, 'profile', None)
        await publish_service_deleted(
            service_id=str(id),
            triggered_by=profile,
            metadata={'date': str(instance.date)}
        )
        return id

    @strawberry.mutation(description="Generate service visits for a given month (all-or-nothing)")
    async def generate_services_by_month(self, input: ServiceGenerationInput, info: Info) -> List[ServiceType]:
        """Create one Service per scheduled day of ``input.month`` / ``input.year``.

        A day is targeted when it lies within the schedule's start/end dates,
        is not a holiday (per ``_is_holiday``), and its weekday is enabled.
        With ``weekend_service`` set, Friday gets a single visit noted as the
        Fri–Sun window and Saturday/Sunday are skipped. Creation runs in one
        transaction: if any service already exists for the address on a
        target date, nothing is created.

        Raises:
            ValueError: month outside 1..12, schedule not belonging to the
                address, or a service already existing on a target date.
        """
        if input.month < 1 or input.month > 12:
            raise ValueError("month must be in range 1..12")

        year = input.year
        month_num = input.month

        # Fetch the AccountAddress and Schedule by their IDs
        address = await AccountAddress.objects.aget(id=input.account_address_id.node_id)
        schedule = await Schedule.objects.aget(id=input.schedule_id.node_id)

        # Ensure the schedule belongs to this address
        if getattr(schedule, "account_address_id", None) != address.id:
            raise ValueError("Schedule does not belong to the provided account address")

        cal = calendar.Calendar(firstweekday=calendar.MONDAY)
        # itermonthdates pads with days of adjacent months; keep only this month.
        days_in_month = [d for d in cal.itermonthdates(year, month_num) if d.month == month_num]

        # Per-weekday enable flags indexed by date.weekday() (Mon=0 ... Sun=6);
        # built once since the schedule does not change during the loop.
        day_flags = (
            schedule.monday_service,
            schedule.tuesday_service,
            schedule.wednesday_service,
            schedule.thursday_service,
            schedule.friday_service,
            schedule.saturday_service,
            schedule.sunday_service,
        )

        def is_within_schedule(dt: datetime.date) -> bool:
            if dt < schedule.start_date:
                return False
            if schedule.end_date and dt > schedule.end_date:
                return False
            return True

        targets: list[tuple[datetime.date, str | None]] = []
        for day in days_in_month:
            if not is_within_schedule(day) or _is_holiday(day):
                continue

            wd = day.weekday()  # Mon=0...Sun=6
            note: str | None = None
            if wd == 4 and schedule.weekend_service:
                # Friday carries the whole weekend window.
                scheduled = True
                note = "Weekend service window (Fri–Sun)"
            elif wd >= 5 and schedule.weekend_service:
                # Sat/Sun are covered by Friday's weekend visit.
                scheduled = False
            else:
                scheduled = day_flags[wd]

            if scheduled:
                targets.append((day, note))

        if not targets:
            return cast(List[ServiceType], [])

        # Run the transactional DB work in a sync thread
        def _create_services_sync(
            account_address_id: UUID,
            targets_local: list[tuple[datetime.date, str | None]]
        ) -> List[Service]:
            with transaction.atomic():
                if Service.objects.filter(
                    account_address_id=account_address_id,
                    date__in=[svc_day for (svc_day, _) in targets_local]
                ).exists():
                    raise ValueError(
                        "One or more services already exist for the selected month; nothing was created."
                    )

                to_create = [
                    Service(
                        account_address_id=account_address_id,
                        date=svc_day,
                        notes=(svc_note or None),
                    )
                    for (svc_day, svc_note) in targets_local
                ]
                return Service.objects.bulk_create(to_create)

        created_instances: List[Service] = await sync_to_async(
            _create_services_sync,
            thread_sensitive=True,
        )(address.id, targets)

        for obj in created_instances:
            await pubsub.publish("service_created", obj.id)

        # Publish bulk generation event for notifications
        if created_instances:
            profile = getattr(info.context.request, 'profile', None)
            month_name = datetime.date(year, month_num, 1).strftime('%B %Y')
            await publish_services_bulk_generated(
                account_id=str(address.account_id),
                count=len(created_instances),
                month=month_name,
                triggered_by=profile
            )

        return cast(List[ServiceType], created_instances)
|