305 lines
10 KiB
Python
305 lines
10 KiB
Python
import base64
|
|
import binascii
|
|
import calendar
|
|
import datetime
|
|
from functools import lru_cache
|
|
from typing import Union, Dict, Any, Iterable, List, Optional, Type, TypeVar, Awaitable
|
|
from channels.db import database_sync_to_async
|
|
from django.db.models import Model
|
|
from strawberry.relay import GlobalID
|
|
|
|
# Generic type variable for helpers that take a Django model class and hand
# back an instance of that same class.
ModelType = TypeVar("ModelType", bound=Model)
|
|
|
|
|
|
def _decode_global_id(gid: Any) -> Any:
    """
    Decode a Global ID to extract the node ID.

    Accepted inputs:
      * ``None`` -> ``None``.
      * a strawberry ``GlobalID`` -> its ``node_id``.
      * a base64 string/bytes whose decoded text is ``"<TypeName>:<node_id>"``
        -> everything after the first ``":"`` (so node ids may contain ``":"``).
      * anything else -> returned unchanged. This covers raw primary keys
        (``"42"``, ``42``, UUID objects), base64 text without a ``":"``, and
        strings that are not valid base64 / UTF-8 (including non-ASCII text).

    Malformed input never raises; it falls back to the original value so
    callers can pass either Global IDs or raw primary keys.
    """
    if gid is None:
        return None

    if not isinstance(gid, (str, bytes)):
        if isinstance(gid, GlobalID):
            return gid.node_id
        # Non-string primary keys (ints, UUIDs, ...) cannot be base64 Global
        # IDs, and b64decode would raise TypeError on them: pass through.
        return gid

    try:
        decoded = base64.b64decode(gid).decode("utf-8")
    except ValueError:
        # ValueError covers binascii.Error (bad padding/alphabet),
        # UnicodeDecodeError (payload is not UTF-8) and the error raised for
        # non-ASCII str input; all mean "not a Global ID".
        return gid

    if ":" in decoded:
        return decoded.split(":", 1)[1]
    return gid
|
|
|
|
|
|
def _decode_global_ids(ids: Optional[Iterable[str]]) -> Optional[List[str]]:
|
|
"""Decode a list of Global IDs."""
|
|
if ids is None:
|
|
return None
|
|
return [_decode_global_id(x) for x in ids]
|
|
|
|
|
|
def _to_dict(input_data: Union[Dict[str, Any], object]) -> Dict[str, Any]:
    """
    Normalize ``input_data`` into a fresh, mutable dictionary.

    * A ``dict`` is shallow-copied.
    * Any other object contributes its instance attributes (``vars()``),
      skipping names that start with an underscore.
    * Objects without a ``__dict__`` (``vars()`` raises ``TypeError``) yield
      an empty dict.
    """
    if isinstance(input_data, dict):
        return dict(input_data)

    try:
        attributes = vars(input_data)
    except TypeError:
        return {}

    return {
        name: value
        for name, value in attributes.items()
        if not name.startswith("_")
    }
|
|
|
|
|
|
def _decode_scalar_ids_inplace(data: Dict[str, Any]) -> None:
    """
    Rewrite scalar ID values in ``data`` through ``_decode_global_id``.

    Mutates ``data`` in place and returns ``None``. The ``"id"`` key is
    handled first, then every key ending in ``"_id"`` (in dict order).
    Entries whose value is ``None`` are left untouched.
    """
    primary = data.get("id")
    if primary is not None:
        data["id"] = _decode_global_id(primary)

    # Iterate over a snapshot because values are reassigned during the loop.
    snapshot = tuple(data.items())
    for key, value in snapshot:
        if not key.endswith("_id") or value is None:
            continue
        data[key] = _decode_global_id(value)
|
|
|
|
|
|
def _filter_write_fields(raw: Dict[str, Any], m2m_data: Optional[dict] = None) -> Dict[str, Any]:
    """
    Return the subset of ``raw`` that may be written straight onto a model.

    Dropped keys:
      * ``"id"`` (the primary key is never assigned directly);
      * every key ending in ``"_ids"`` (ID-list convenience fields);
      * every key present in ``m2m_data`` (many-to-many fields are applied
        separately after the row exists).

    ``raw`` itself is not modified; a new dict is returned.
    """
    skipped = {"id"}
    if m2m_data:
        skipped.update(m2m_data.keys())

    return {
        key: value
        for key, value in raw.items()
        if not key.endswith("_ids") and key not in skipped
    }
|
|
|
|
|
|
def _observed(date: datetime.date) -> datetime.date:
    """
    Map a fixed-date holiday onto the day it is observed.

    A Saturday holiday is observed on the preceding Friday and a Sunday
    holiday on the following Monday; any other day is returned as-is.
    """
    shift_days = {calendar.SATURDAY: -1, calendar.SUNDAY: 1}.get(date.weekday())
    if shift_days is None:
        return date
    return date + datetime.timedelta(days=shift_days)
|
|
|
|
|
|
def _nth_weekday_of_month(year: int, month: int, weekday: int, n: int) -> datetime.date:
    """
    Return the ``n``-th date in ``year``/``month`` falling on ``weekday``.

    ``weekday`` uses ``date.weekday()`` numbering (Monday=0 ... Sunday=6) and
    ``n`` counts from 1, so ``weekday=3, n=4`` is the fourth Thursday.

    Raises:
        ValueError: "Invalid nth weekday request" when fewer than ``n``
            matching days exist, or when ``year``/``month`` do not form a
            valid date (the scan stops at the first invalid day).
    """
    hits = 0
    for day_number in range(1, 32):
        try:
            candidate = datetime.date(year, month, day_number)
        except ValueError:
            # Ran past the last day of the month (or the month is invalid).
            break
        if candidate.weekday() != weekday:
            continue
        hits += 1
        if hits == n:
            return candidate
    raise ValueError("Invalid nth weekday request")
|
|
|
|
|
|
def _last_weekday_of_month(year: int, month: int, weekday: int) -> datetime.date:
    """
    Return the final date in ``year``/``month`` falling on ``weekday``.

    ``weekday`` uses Monday=0 ... Sunday=6; e.g. Memorial Day is the last
    Monday (0) of May.

    Raises:
        ValueError: "Invalid last weekday request" when no day of the month
            has the requested weekday.
    """
    days_in_month = calendar.monthrange(year, month)[1]
    # Walk backwards from the month's final day; the first match is the last.
    newest_first = (
        datetime.date(year, month, day_number)
        for day_number in reversed(range(1, days_in_month + 1))
    )
    found = next((d for d in newest_first if d.weekday() == weekday), None)
    if found is None:
        raise ValueError("Invalid last weekday request")
    return found
|
|
|
|
|
|
@lru_cache(maxsize=64)
def _holiday_set(year: int) -> frozenset[datetime.date]:
    """
    Return the observed dates of the holidays recognized here for ``year``.

    Six holidays are included: New Year's Day, Memorial Day, Independence
    Day, Labor Day, Thanksgiving Day and Christmas Day. Fixed-date holidays
    use their weekend-adjusted observed date (``_observed``).
    NOTE(review): other US federal holidays (MLK Day, Presidents' Day,
    Juneteenth, Columbus Day, Veterans Day) are not included — confirm this
    subset is intentional.

    Year-boundary handling: when next year's January 1 is a Saturday it is
    observed on December 31 of ``year``, so that date is included here
    (otherwise ``date in _holiday_set(date.year)`` would miss it). This
    year's own New Year observance is kept even when it lands on December 31
    of the previous year.

    The result is cached by ``lru_cache``, so it is returned as an immutable
    ``frozenset`` to keep callers from corrupting the cached value.
    """
    holidays = {
        _observed(datetime.date(year, 1, 1)),  # New Year's Day
        _last_weekday_of_month(year, 5, calendar.MONDAY),  # Memorial Day
        _observed(datetime.date(year, 7, 4)),  # Independence Day
        _nth_weekday_of_month(year, 9, calendar.MONDAY, 1),  # Labor Day
        _nth_weekday_of_month(year, 11, calendar.THURSDAY, 4),  # Thanksgiving
        _observed(datetime.date(year, 12, 25)),  # Christmas Day
    }

    # Next year's New Year's Day may be observed on Dec 31 of this year.
    if year < datetime.MAXYEAR:
        next_new_year = _observed(datetime.date(year + 1, 1, 1))
        if next_new_year.year == year:
            holidays.add(next_new_year)

    return frozenset(holidays)
|
|
|
|
|
|
def _is_holiday(date: datetime.date) -> bool:
    """Return True when ``date`` is in ``_holiday_set`` for ``date``'s year."""
    observed_holidays = _holiday_set(date.year)
    return date in observed_holidays
|
|
|
|
|
|
def _extract_id(payload: Union[dict, str, int]) -> str:
    """
    Return the identifier carried by ``payload`` as a string.

    For a dict, its ``"id"`` entry is used (a missing key becomes the string
    ``"None"``); any other value is stringified directly.
    """
    if isinstance(payload, dict):
        payload = payload.get("id")
    return str(payload)
|
|
|
|
|
|
# Internal synchronous implementations
|
|
def _create_object_sync(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> ModelType:
    """
    Synchronous implementation of create_object.

    Args:
        input_data: dict or attribute object (converted with ``_to_dict``).
            ``id`` / ``*_id`` values may be Global IDs and are decoded;
            ``id``, ``*_ids`` keys and m2m field names are not written.
        model_class: Django model class to create a row for.
        m2m_data: Optional mapping of many-to-many field name -> iterable of
            Global IDs / primary keys. Entries whose value is ``None`` are
            skipped.

    Returns:
        The created model instance.

    The insert and the many-to-many assignments run inside one transaction
    on the model's write database, so a failure while setting relations
    rolls back the insert instead of leaving a half-initialized row.
    """
    # Local import, matching this file's existing function-scope imports.
    from django.db import router, transaction

    raw = _to_dict(input_data)
    _decode_scalar_ids_inplace(raw)
    data = _filter_write_fields(raw, m2m_data)

    with transaction.atomic(using=router.db_for_write(model_class)):
        instance = model_class.objects.create(**data)

        # Many-to-many relations can only be set once the row has a pk.
        for field, values in (m2m_data or {}).items():
            if values is not None:
                getattr(instance, field).set(_decode_global_ids(values))

    return instance
|
|
|
|
|
|
def _update_object_sync(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> ModelType:
    """
    Synchronous implementation of update_object.

    Args:
        input_data: dict or attribute object (converted with ``_to_dict``)
            that must carry ``id`` (Global ID or raw primary key). Fields
            whose value is ``None`` are treated as "not provided" and left
            unchanged; ``id``, ``*_ids`` keys and m2m field names are not
            written as plain fields.
        model_class: Django model class of the row to update.
        m2m_data: Optional mapping of many-to-many field name -> iterable of
            Global IDs / primary keys. ``None`` leaves that relation
            unchanged; an empty list clears it.

    Returns:
        The updated model instance.

    Raises:
        ValueError: if no ``model_class`` row has the decoded ``id``.

    The field save and the many-to-many updates run in one transaction on
    the model's write database, so they are applied together or not at all.
    """
    # Local import, matching this file's existing function-scope imports.
    from django.db import router, transaction

    raw = _to_dict(input_data)
    _decode_scalar_ids_inplace(raw)
    object_id = raw.get("id")

    try:
        instance = model_class.objects.get(pk=object_id)
    except model_class.DoesNotExist as err:
        raise ValueError(f"{model_class.__name__} with ID {object_id} does not exist") from err

    data = _filter_write_fields(raw, m2m_data)
    # None means "not provided": only fields with a value are written.
    update_fields = [field for field, value in data.items() if value is not None]

    with transaction.atomic(using=router.db_for_write(model_class)):
        for field in update_fields:
            setattr(instance, field, data[field])

        if update_fields:
            instance.save(update_fields=update_fields)
        else:
            # NOTE(review): nothing changed, yet a full save() runs —
            # presumably to refresh auto_now fields / fire save signals; confirm.
            instance.save()

        # Many-to-many: None means "not provided" (leave unchanged);
        # pass an empty list to clear a relationship.
        for field, values in (m2m_data or {}).items():
            if values is not None:
                getattr(instance, field).set(_decode_global_ids(values))

    return instance
|
|
|
|
|
|
def _delete_object_sync(object_id, model_class: Type[ModelType]) -> Optional[ModelType]:
    """
    Delete the ``model_class`` row identified by ``object_id`` (synchronous).

    ``object_id`` may be a Global ID or a raw primary key; it is passed
    through ``_decode_global_id`` before the lookup.

    Returns:
        The deleted instance, or ``None`` when no row has that primary key.

    NOTE(review): Django's ``Model.delete()`` resets the instance's primary
    key to ``None``, so the returned object no longer carries its id —
    confirm callers do not rely on it.
    """
    primary_key = _decode_global_id(object_id)
    try:
        doomed = model_class.objects.get(pk=primary_key)
        doomed.delete()
    except model_class.DoesNotExist:
        return None
    return doomed
|
|
|
|
|
|
# Public async functions with explicit typing for IDE support
|
|
def create_object(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> Awaitable[ModelType]:
    """
    Create a new model instance asynchronously.

    Wraps ``_create_object_sync`` with channels' ``database_sync_to_async``,
    which runs the synchronous ORM code in a worker thread.

    Args:
        input_data: Input data (dict or object with attributes). ``*_id``
            values may be Global IDs; ``id`` is ignored for creation.
        model_class: Django model class
        m2m_data: Optional mapping of many-to-many field name -> IDs; a
            ``None`` value leaves that relation unset.

    Returns:
        Awaitable that resolves to the created model instance.
    """
    run_create = database_sync_to_async(_create_object_sync)
    return run_create(input_data, model_class, m2m_data)
|
|
|
|
|
|
def update_object(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> Awaitable[ModelType]:
    """
    Update an existing model instance asynchronously.

    Wraps ``_update_object_sync`` with channels' ``database_sync_to_async``,
    which runs the synchronous ORM code in a worker thread.

    Args:
        input_data: Input data (dict or object with attributes); must include
            ``id`` (Global ID or raw primary key). Fields whose value is
            ``None`` are treated as "not provided" and left unchanged.
        model_class: Django model class
        m2m_data: Optional mapping of many-to-many field name -> IDs. A
            ``None`` value leaves that relation unchanged; ``[]`` clears it.

    Returns:
        Awaitable that resolves to the updated model instance.

    Raises:
        ValueError: If an object with the given ID doesn't exist (raised
            when the awaitable is awaited).
    """
    run_update = database_sync_to_async(_update_object_sync)
    return run_update(input_data, model_class, m2m_data)
|
|
|
|
|
|
def delete_object(object_id, model_class: Type[ModelType]) -> Awaitable[Optional[ModelType]]:
    """
    Asynchronously delete one ``model_class`` row.

    Wraps ``_delete_object_sync`` with channels' ``database_sync_to_async``.

    Args:
        object_id: Global ID or raw primary key of the row to delete.
        model_class: Django model class to delete from.

    Returns:
        Awaitable resolving to the deleted instance, or ``None`` when no row
        matches ``object_id``.
    """
    run_delete = database_sync_to_async(_delete_object_sync)
    return run_delete(object_id, model_class)
|
|
|
|
|
|
def _get_conversations_for_entity_sync(entity_instance: Model) -> List:
    """
    Synchronous implementation to get conversations linked to an entity via GenericForeignKey.

    Conversations reference their entity through ``entity_content_type``
    plus ``entity_object_id``. The lookup matches the entity's model content
    type and primary key.

    Args:
        entity_instance: The model instance (e.g., Service, Project, Account, Customer)

    Returns:
        List of Conversation objects linked to this entity. ``list()`` forces
        the query to run here, presumably so no lazy queryset escapes the
        sync-to-async boundary.
    """
    # Function-scope imports kept from the original (presumably to avoid
    # import cycles / app-registry timing issues — confirm).
    from django.contrib.contenttypes.models import ContentType
    from core.models.messaging import Conversation

    content_type = ContentType.objects.get_for_model(type(entity_instance))
    # Use .pk, not .id: a GenericForeignKey stores the target's primary key,
    # and models with a custom primary key have no ``id`` attribute.
    return list(
        Conversation.objects.filter(
            entity_content_type=content_type,
            entity_object_id=entity_instance.pk,
        )
    )
|
|
|
|
|
|
def get_conversations_for_entity(entity_instance: Model) -> Awaitable[List]:
    """
    Asynchronously fetch every conversation linked to ``entity_instance``.

    Wraps ``_get_conversations_for_entity_sync`` with channels'
    ``database_sync_to_async``. Conversations are matched on the entity's
    content type and object id (the generic-relation pattern), which makes
    this convenient for exposing a ``conversations`` field on GraphQL types.

    Args:
        entity_instance: The model instance (e.g., Service, Project, Account, Customer)

    Returns:
        Awaitable that resolves to a list of Conversation objects

    Example usage in GraphQL type:
        @strawberry.field
        async def conversations(self) -> List[ConversationType]:
            return await get_conversations_for_entity(self)
    """
    fetch_conversations = database_sync_to_async(_get_conversations_for_entity_sync)
    return fetch_conversations(entity_instance)
|