2026-01-26 11:09:40 -05:00

305 lines
10 KiB
Python

import base64
import binascii
import calendar
import datetime
from functools import lru_cache
from typing import Union, Dict, Any, Iterable, List, Optional, Type, TypeVar, Awaitable
from channels.db import database_sync_to_async
from django.db.models import Model
from strawberry.relay import GlobalID
# Type variable bound to Django models: lets the CRUD helpers declare that they
# return an instance of the same model class they were given.
ModelType = TypeVar("ModelType", bound=Model)
def _decode_global_id(gid: Union[str, GlobalID, None]) -> Optional[str]:
    """
    Decode a Relay Global ID to the underlying node ID.

    Accepted inputs:
      * ``None``: returns ``None``.
      * a strawberry ``GlobalID``: returns its ``node_id``.
      * a base64 string encoding ``"<TypeName>:<node_id>"``: returns
        everything after the first colon.
      * anything else, such as a raw primary-key string or a non-string value
        like an int: returned unchanged. Non-string values therefore do not
        become ``str`` despite the annotation.

    Decoding is strict (``validate=True``). A string containing characters
    outside the standard base64 alphabet, e.g. a hyphenated UUID, is treated
    as a raw ID. It is not decoded after those characters are silently
    discarded.
    """
    if gid is None:
        return None
    if isinstance(gid, GlobalID):
        return gid.node_id
    try:
        decoded = base64.b64decode(gid, validate=True).decode("utf-8")
    except (ValueError, TypeError):
        # binascii.Error (bad padding/alphabet), UnicodeDecodeError and the
        # "non-ASCII string" error are all ValueError subclasses. TypeError
        # covers non-string inputs such as int primary keys.
        return gid
    if ":" in decoded:
        return decoded.split(":", 1)[1]
    # Valid base64 without a "Type:" prefix: not a Global ID, keep as given.
    return gid
def _decode_global_ids(ids: Optional[Iterable[str]]) -> Optional[List[str]]:
    """
    Decode every Global ID in ``ids`` with ``_decode_global_id``.

    Returns ``None`` when ``ids`` is ``None``. Otherwise returns a new list with
    one decoded entry per input item, in the same order.
    """
    if ids is None:
        return None
    return list(map(_decode_global_id, ids))
def _to_dict(input_data: Union[Dict[str, Any], object]) -> Dict[str, Any]:
"""
Convert input data to a dictionary.
Handles both dict objects and objects with attributes.
"""
if isinstance(input_data, dict):
return dict(input_data)
try:
return {k: v for k, v in vars(input_data).items() if not k.startswith("_")}
except TypeError:
return {}
def _decode_scalar_ids_inplace(data: Dict[str, Any]) -> None:
    """
    Replace Global IDs with plain node IDs directly inside ``data``.

    The ``"id"`` key is decoded first, then every key ending in ``"_id"``.
    ``None`` values are left alone. The dict is mutated; nothing is returned.
    """
    if data.get("id") is not None:
        data["id"] = _decode_global_id(data["id"])

    # Iterate over a snapshot because values are reassigned during the loop.
    for key, value in tuple(data.items()):
        if not key.endswith("_id") or value is None:
            continue
        data[key] = _decode_global_id(value)
def _filter_write_fields(raw: Dict[str, Any], m2m_data: Optional[dict] = None) -> Dict[str, Any]:
"""
Remove fields that shouldn't be written directly to the model:
- many-to-many fields handled separately (m2m_data keys)
- *_ids convenience arrays that are processed elsewhere
- id (primary key)
"""
m2m_fields = list(m2m_data.keys()) if m2m_data else []
exclude_keys = set(m2m_fields) | {k for k in raw.keys() if k.endswith("_ids")} | {"id"}
return {k: v for k, v in raw.items() if k not in exclude_keys}
def _observed(date: datetime.date) -> datetime.date:
"""
Calculate the observed date for holidays.
If holiday is Saturday -> observe Friday; if Sunday -> observe Monday
"""
wd = date.weekday()
if wd == 5: # Saturday
return date - datetime.timedelta(days=1)
if wd == 6: # Sunday
return date + datetime.timedelta(days=1)
return date
def _nth_weekday_of_month(year: int, month: int, weekday: int, n: int) -> datetime.date:
"""
Find the nth occurrence of a weekday in a month.
weekday: Mon=0...Sun=6, n: 1..5 (e.g., 4th Thursday)
"""
count = 0
for day in range(1, 32):
try:
d = datetime.date(year, month, day)
except ValueError:
break
if d.weekday() == weekday:
count += 1
if count == n:
return d
raise ValueError("Invalid nth weekday request")
def _last_weekday_of_month(year: int, month: int, weekday: int) -> datetime.date:
"""Find the last occurrence of a weekday in a month (e.g., last Monday in May)."""
last_day = calendar.monthrange(year, month)[1]
for day in range(last_day, 0, -1):
d = datetime.date(year, month, day)
if d.weekday() == weekday:
return d
raise ValueError("Invalid last weekday request")
@lru_cache(maxsize=64)
def _holiday_set(year: int) -> frozenset[datetime.date]:
    """
    Return the observed dates of six major US federal holidays for ``year``.

    Only New Year's Day, Memorial Day, Independence Day, Labor Day,
    Thanksgiving and Christmas are included. MLK Day, Presidents' Day,
    Juneteenth, Columbus Day and Veterans Day are not.

    Fixed-date holidays use their weekend-observed date (see ``_observed``).
    This can cross a year boundary:
      * When Jan 1 of ``year + 1`` is a Saturday, it is observed on Dec 31 of
        ``year``, and that date is included here so membership checks for
        Dec 31 work.
      * When Jan 1 of ``year`` is a Saturday, its observed Dec 31 of
        ``year - 1`` is also still included, as before.

    The result is a frozenset because ``lru_cache`` hands the same object to
    every caller. A mutable set could be modified by one caller and corrupt
    every later lookup for that year.
    """
    holidays = {
        # New Year's Day (observed)
        _observed(datetime.date(year, 1, 1)),
        # Memorial Day (last Monday in May)
        _last_weekday_of_month(year, 5, calendar.MONDAY),
        # Independence Day (observed)
        _observed(datetime.date(year, 7, 4)),
        # Labor Day (first Monday in September)
        _nth_weekday_of_month(year, 9, calendar.MONDAY, 1),
        # Thanksgiving Day (4th Thursday in November)
        _nth_weekday_of_month(year, 11, calendar.THURSDAY, 4),
        # Christmas Day (observed)
        _observed(datetime.date(year, 12, 25)),
    }
    # Next year's New Year's Day may be observed on Dec 31 of this year.
    if year < datetime.MAXYEAR:
        next_new_year = _observed(datetime.date(year + 1, 1, 1))
        if next_new_year.year == year:
            holidays.add(next_new_year)
    return frozenset(holidays)
def _is_holiday(date: datetime.date) -> bool:
    """Return True when ``date`` is in ``_holiday_set`` for its calendar year."""
    year_holidays = _holiday_set(date.year)
    return date in year_holidays
def _extract_id(payload: Union[dict, str, int]) -> str:
"""Extract ID from various payload formats."""
return str(payload.get("id")) if isinstance(payload, dict) else str(payload)
# Internal synchronous implementations
def _create_object_sync(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> ModelType:
    """
    Synchronous implementation of create_object.

    Args:
        input_data: Dict or attribute object (see ``_to_dict``). ``id`` and
            ``*_id`` values are decoded from Global IDs; ``id``, ``*_ids`` and
            many-to-many keys are not passed to ``create()``.
        model_class: Django model class to instantiate.
        m2m_data: Optional mapping of many-to-many field name -> list of
            (Global) IDs. ``None`` values are skipped.

    Returns:
        The created model instance.

    The row insert and the many-to-many writes run in a single
    ``transaction.atomic()`` block. If any ``.set()`` call fails, the new row
    is rolled back instead of being left half-initialised.
    """
    from django.db import transaction

    raw = _to_dict(input_data)
    _decode_scalar_ids_inplace(raw)
    data = _filter_write_fields(raw, m2m_data)

    with transaction.atomic():
        instance = model_class.objects.create(**data)
        for field, values in (m2m_data or {}).items():
            if values is not None:
                getattr(instance, field).set(_decode_global_ids(values))
    return instance
def _update_object_sync(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> ModelType:
    """
    Synchronous implementation of update_object.

    Args:
        input_data: Dict or attribute object that must include ``id``. Fields
            whose value is ``None`` are treated as "not provided" and left
            unchanged.
        model_class: Django model class to update.
        m2m_data: Optional mapping of many-to-many field name -> list of
            (Global) IDs. ``None`` leaves a relation unchanged; pass ``[]`` to
            clear it.

    Returns:
        The updated model instance.

    Raises:
        ValueError: If no ``model_class`` row has the given ID.

    The lookup, field save and many-to-many writes run in one
    ``transaction.atomic()`` block so a failure part-way leaves no partial
    update. Only the lookup is guarded for ``DoesNotExist``, so errors raised
    while saving are not misreported as "does not exist".
    """
    from django.db import transaction

    raw = _to_dict(input_data)
    _decode_scalar_ids_inplace(raw)
    object_id = raw.get("id")
    data = _filter_write_fields(raw, m2m_data)
    # Only non-None values count as provided changes.
    changes = {field: value for field, value in data.items() if value is not None}

    with transaction.atomic():
        try:
            instance = model_class.objects.get(pk=object_id)
        except model_class.DoesNotExist as exc:
            raise ValueError(f"{model_class.__name__} with ID {object_id} does not exist") from exc

        for field, value in changes.items():
            setattr(instance, field, value)
        if changes:
            # NOTE(review): fields not listed here (e.g. an auto_now
            # "updated_at", if the model has one) may not be persisted by a
            # save(update_fields=...) call; confirm this is intended.
            instance.save(update_fields=list(changes))
        else:
            instance.save()

        for field, values in (m2m_data or {}).items():
            if values is not None:
                getattr(instance, field).set(_decode_global_ids(values))
    return instance
def _delete_object_sync(object_id, model_class: Type[ModelType]) -> Optional[ModelType]:
    """
    Synchronous implementation of delete_object.

    Args:
        object_id: Global ID or raw primary key of the row to delete.
        model_class: Django model class.

    Returns:
        The instance that was deleted, or ``None`` if no row matched.

    Only the lookup is guarded for ``DoesNotExist``. An exception raised
    while deleting propagates instead of being reported as "not found".
    NOTE(review): Django typically clears the instance's pk after delete();
    callers serialising the returned object's id should confirm they get the
    value they expect.
    """
    pk = _decode_global_id(object_id)
    try:
        instance = model_class.objects.get(pk=pk)
    except model_class.DoesNotExist:
        return None
    instance.delete()
    return instance
# Public async functions with explicit typing for IDE support
def create_object(input_data, model_class: Type[ModelType], m2m_data: dict = None) -> Awaitable[ModelType]:
    """
    Create a model instance without blocking the event loop.

    Args:
        input_data: Field values (dict or object with attributes)
        model_class: Django model class to create
        m2m_data: Optional mapping of many-to-many field name -> IDs

    Returns:
        Awaitable resolving to the newly created model instance
    """
    run_create = database_sync_to_async(_create_object_sync)
    return run_create(input_data, model_class, m2m_data)
def update_object(input_data, model_class: Type[ModelType], m2m_data: dict = None) -> Awaitable[ModelType]:
    """
    Update an existing model instance without blocking the event loop.

    Args:
        input_data: Field values (dict or object with attributes); must carry 'id'
        model_class: Django model class to update
        m2m_data: Optional mapping of many-to-many field name -> IDs

    Returns:
        Awaitable resolving to the updated model instance

    Raises:
        ValueError: When awaited, if no object with the given ID exists
    """
    run_update = database_sync_to_async(_update_object_sync)
    return run_update(input_data, model_class, m2m_data)
def delete_object(object_id, model_class: Type[ModelType]) -> Awaitable[Optional[ModelType]]:
    """
    Delete a model instance without blocking the event loop.

    Args:
        object_id: Global ID or raw primary key of the object to delete
        model_class: Django model class

    Returns:
        Awaitable resolving to the deleted instance, or None when nothing matched
    """
    run_delete = database_sync_to_async(_delete_object_sync)
    return run_delete(object_id, model_class)
def _get_conversations_for_entity_sync(entity_instance: Model) -> List:
    """
    Synchronously fetch the conversations linked to an entity via GenericForeignKey.

    Args:
        entity_instance: The model instance (e.g., Service, Project, Account, Customer)

    Returns:
        List of Conversation objects whose ``entity_content_type`` is the
        instance's model and whose ``entity_object_id`` is the instance's
        primary key.
    """
    from django.contrib.contenttypes.models import ContentType
    from core.models.messaging import Conversation

    content_type = ContentType.objects.get_for_model(type(entity_instance))
    return list(
        Conversation.objects.filter(
            entity_content_type=content_type,
            # .pk rather than .id: also works for models whose primary key
            # field is not named "id".
            entity_object_id=entity_instance.pk,
        )
    )
def get_conversations_for_entity(entity_instance: Model) -> Awaitable[List]:
    """
    Fetch the conversations attached to an entity without blocking the event loop.

    Wraps the GenericForeignKey lookup so GraphQL types can expose a
    ``conversations`` field with a single await.

    Args:
        entity_instance: The model instance (e.g., Service, Project, Account, Customer)

    Returns:
        Awaitable resolving to a list of Conversation objects

    Example usage in GraphQL type:
        @strawberry.field
        async def conversations(self) -> List[ConversationType]:
            return await get_conversations_for_entity(self)
    """
    fetch_conversations = database_sync_to_async(_get_conversations_for_entity_sync)
    return fetch_conversations(entity_instance)