2026-01-26 11:09:40 -05:00

328 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import calendar
import datetime
from typing import List, cast
from uuid import UUID
import strawberry
from strawberry.types import Info
from asgiref.sync import sync_to_async
from channels.db import database_sync_to_async
from django.db import transaction
from core.graphql.inputs.service import ServiceInput, ServiceUpdateInput, ServiceGenerationInput
from core.graphql.pubsub import pubsub
from core.graphql.types.service import ServiceType
from core.graphql.utils import create_object, update_object, delete_object, _is_holiday
from core.models.account import AccountAddress
from core.models.profile import TeamProfile
from core.models.schedule import Schedule
from core.models.service import Service
from core.services.events import (
publish_service_created, publish_service_deleted,
publish_service_status_changed, publish_service_completed, publish_service_cancelled,
publish_services_bulk_generated, publish_service_dispatched,
)
# Helper to get admin profile
async def _get_admin_profile():
    """Return the first ``TeamProfile`` whose role is ``'ADMIN'``.

    The ORM lookup is synchronous, so it is executed through
    ``sync_to_async``. The result is whatever ``.first()`` yields for the
    filtered queryset (``None`` when no row matches).
    """
    def _first_admin():
        return TeamProfile.objects.filter(role='ADMIN').first()

    return await sync_to_async(_first_admin)()
# Helper to check if admin is in team member IDs (handles GlobalID objects)
def _admin_in_team_members(admin_id, team_member_ids):
if not team_member_ids or not admin_id:
return False
# team_member_ids may be GlobalID objects with .node_id attribute
member_uuids = []
for mid in team_member_ids:
if hasattr(mid, 'node_id'):
member_uuids.append(str(mid.node_id))
else:
member_uuids.append(str(mid))
return str(admin_id) in member_uuids
# Helper to get old team member IDs from instance
async def _get_old_team_member_ids(instance):
    """Return the ids of ``instance.team_members`` as a set of strings.

    The related-manager query is synchronous, so it is executed through
    ``sync_to_async``.
    """
    def _collect_member_ids():
        return {str(member.id) for member in instance.team_members.all()}

    return await sync_to_async(_collect_member_ids)()
async def _load_account_address(account_address_id):
    """Load an ``AccountAddress`` with its ``account`` joined.

    Returns ``None`` without touching the database when
    ``account_address_id`` is falsy; otherwise propagates whatever
    ``.get()`` raises if the row cannot be fetched.
    """
    if not account_address_id:
        return None
    return await sync_to_async(
        lambda: AccountAddress.objects.select_related('account').get(id=account_address_id)
    )()


def _account_name_for(account_address):
    """Return the name of the address's account, or ``None`` if either is missing."""
    if account_address is None or not account_address.account:
        return None
    return account_address.account.name


async def _publish_dispatch_event(instance, profile, account_address):
    """Publish the "service dispatched" notification event for ``instance``.

    ``account_address`` is the already-loaded address of the service (or
    ``None``); it is only used to fill in ``account_name``.
    """
    account_address_id = None
    if instance.account_address_id:
        account_address_id = str(instance.account_address_id)
    await publish_service_dispatched(
        service_id=str(instance.id),
        triggered_by=profile,
        metadata={
            'service_id': str(instance.id),
            'account_address_id': account_address_id,
            'account_name': _account_name_for(account_address),
            'date': str(instance.date),
            'status': instance.status,
        },
    )


# GraphQL mutations for Service visits: create / update / delete a single
# service, and bulk-generate a month of services from a Schedule.
# (Kept as a comment rather than a class docstring so the schema is untouched.)
@strawberry.type
class Mutation:
    @strawberry.mutation(description="Create a new service visit")
    async def create_service(self, input: ServiceInput, info: Info) -> ServiceType:
        """Create a Service, then publish pubsub and notification events.

        Always publishes ``service_created``; additionally publishes a
        "dispatched" event when the admin profile is among the requested
        team members.
        """
        # Exclude m2m id fields from the model constructor payload.
        payload = {k: v for k, v in input.__dict__.items() if k != "team_member_ids"}
        m2m_data = {"team_members": input.team_member_ids}
        instance = await create_object(payload, Service, m2m_data)
        await pubsub.publish("service_created", instance.id)

        profile = getattr(info.context.request, 'profile', None)

        # Load the address once; both events below derive their metadata from it.
        account_address = await _load_account_address(instance.account_address_id)
        account_id = None
        if account_address is not None and account_address.account_id:
            account_id = str(account_address.account_id)

        await publish_service_created(
            service_id=str(instance.id),
            triggered_by=profile,
            metadata={
                'account_id': account_id,
                'date': str(instance.date),
                'status': instance.status,
            },
        )

        # A service counts as dispatched when the admin is one of its team members.
        admin = await _get_admin_profile()
        if admin and _admin_in_team_members(admin.id, input.team_member_ids):
            await _publish_dispatch_event(instance, profile, account_address)

        return cast(ServiceType, instance)

    @strawberry.mutation(description="Update an existing service visit")
    async def update_service(self, input: ServiceUpdateInput, info: Info) -> ServiceType:
        """Update a Service and publish events for what changed.

        Publishes ``service_updated`` on pubsub, a completed / cancelled /
        status-changed notification when the status differs from the stored
        one, and a "dispatched" notification when the admin profile was
        newly added to the team members.
        """
        # Snapshot the pre-update state needed for change detection.
        old_service = await database_sync_to_async(Service.objects.get)(pk=input.id.node_id)
        old_status = old_service.status
        old_team_member_ids = await _get_old_team_member_ids(old_service)

        # Read the m2m ids once, defensively, and reuse for update and dispatch check.
        new_team_member_ids = getattr(input, "team_member_ids", None)

        # Keep id and non-m2m fields; drop m2m *_ids from the update payload.
        payload = {k: v for k, v in input.__dict__.items() if k != "team_member_ids"}
        instance = await update_object(payload, Service, {"team_members": new_team_member_ids})
        await pubsub.publish("service_updated", instance.id)

        profile = getattr(info.context.request, 'profile', None)
        account_address = None  # loaded lazily, at most once per request

        requested_status = getattr(input, 'status', None)
        if requested_status and requested_status != old_status:
            if instance.status in ('COMPLETED', 'CANCELLED'):
                # Only these two events carry the account name.
                account_address = await _load_account_address(instance.account_address_id)
                if instance.status == 'COMPLETED':
                    publish_terminal = publish_service_completed
                else:
                    publish_terminal = publish_service_cancelled
                await publish_terminal(
                    service_id=str(instance.id),
                    triggered_by=profile,
                    metadata={
                        'date': str(instance.date),
                        'account_name': _account_name_for(account_address),
                    },
                )
            else:
                await publish_service_status_changed(
                    service_id=str(instance.id),
                    old_status=old_status,
                    new_status=instance.status,
                    triggered_by=profile,
                )

        # Dispatched == admin was not a member before but is one now.
        if new_team_member_ids is not None:
            admin = await _get_admin_profile()
            if admin:
                admin_was_in_old = str(admin.id) in old_team_member_ids
                admin_in_new = _admin_in_team_members(admin.id, new_team_member_ids)
                if admin_in_new and not admin_was_in_old:
                    if account_address is None:
                        account_address = await _load_account_address(instance.account_address_id)
                    await _publish_dispatch_event(instance, profile, account_address)

        return cast(ServiceType, instance)

    @strawberry.mutation(description="Delete an existing service visit")
    async def delete_service(self, id: strawberry.ID, info: Info) -> strawberry.ID:
        """Delete a Service and publish the matching events.

        Raises:
            ValueError: when ``delete_object`` returns nothing for ``id``.
        """
        instance = await delete_object(id, Service)
        if not instance:
            raise ValueError(f"Service with ID {id} does not exist")

        await pubsub.publish("service_deleted", id)
        profile = getattr(info.context.request, 'profile', None)
        await publish_service_deleted(
            service_id=str(id),
            triggered_by=profile,
            metadata={'date': str(instance.date)},
        )
        return id

    @strawberry.mutation(description="Generate service visits for a given month (all-or-nothing)")
    async def generate_services_by_month(self, input: ServiceGenerationInput, info: Info) -> List[ServiceType]:
        """Create one Service per scheduled day of ``input.month``.

        Days outside the schedule's start/end dates and holidays are skipped.
        With ``weekend_service`` set, Friday is always scheduled (with a note)
        and Saturday/Sunday are never scheduled; otherwise each weekday
        follows its own ``<weekday>_service`` flag.

        Returns:
            The created services, or an empty list when no day qualifies.

        Raises:
            ValueError: for a month outside 1..12, a schedule that does not
                belong to the address, or when a service already exists on
                any target date (in which case nothing is created).
        """
        if not 1 <= input.month <= 12:
            raise ValueError("month must be in range 1..12")
        year = input.year
        month_num = input.month

        # Fetch the AccountAddress and Schedule by their IDs.
        address = await AccountAddress.objects.aget(id=input.account_address_id.node_id)
        schedule = await Schedule.objects.aget(id=input.schedule_id.node_id)

        # Ensure the schedule belongs to this address.
        if getattr(schedule, "account_address_id", None) != address.id:
            raise ValueError("Schedule does not belong to the provided account address")

        cal = calendar.Calendar(firstweekday=calendar.MONDAY)
        # itermonthdates pads with neighbouring-month days; keep only this month.
        days_in_month = [d for d in cal.itermonthdates(year, month_num) if d.month == month_num]

        # Per-weekday flags indexed by date.weekday() (Mon=0 ... Sun=6).
        weekday_flags = (
            schedule.monday_service,
            schedule.tuesday_service,
            schedule.wednesday_service,
            schedule.thursday_service,
            schedule.friday_service,
            schedule.saturday_service,
            schedule.sunday_service,
        )

        targets: list[tuple[datetime.date, str | None]] = []
        for day in days_in_month:
            if day < schedule.start_date:
                continue
            if schedule.end_date and day > schedule.end_date:
                continue
            if _is_holiday(day):
                continue

            wd = day.weekday()
            note: str | None = None
            if schedule.weekend_service and wd >= 4:
                # One Friday visit stands in for the whole Fri-Sun window.
                if wd != 4:
                    continue
                # En dash restored: the separator had been lost ("FriSun").
                note = "Weekend service window (Fri–Sun)"
            elif not weekday_flags[wd]:
                continue
            targets.append((day, note))

        if not targets:
            return cast(List[ServiceType], [])

        # Run the transactional DB work in a sync thread.
        def _create_services_sync(
            account_address_id: UUID,
            targets_local: list[tuple[datetime.date, str | None]],
        ) -> List[Service]:
            with transaction.atomic():
                target_dates = [svc_day for (svc_day, _) in targets_local]
                # All-or-nothing: refuse if any target date is already taken.
                if Service.objects.filter(
                    account_address_id=account_address_id,
                    date__in=target_dates,
                ).exists():
                    raise ValueError(
                        "One or more services already exist for the selected month; nothing was created."
                    )
                to_create = [
                    Service(
                        account_address_id=account_address_id,
                        date=svc_day,
                        notes=(svc_note or None),
                    )
                    for (svc_day, svc_note) in targets_local
                ]
                return Service.objects.bulk_create(to_create)

        created_instances: List[Service] = await sync_to_async(
            _create_services_sync,
            thread_sensitive=True,
        )(address.id, targets)

        for obj in created_instances:
            await pubsub.publish("service_created", obj.id)

        # Publish a single bulk-generation event for notifications.
        if created_instances:
            profile = getattr(info.context.request, 'profile', None)
            month_name = datetime.date(year, month_num, 1).strftime('%B %Y')
            await publish_services_bulk_generated(
                account_id=str(address.account_id),
                count=len(created_instances),
                month=month_name,
                triggered_by=profile,
            )

        return cast(List[ServiceType], created_instances)