import base64
import binascii
import calendar
import datetime
from functools import lru_cache
from typing import Union, Dict, Any, Iterable, List, Optional, Type, TypeVar, Awaitable

from channels.db import database_sync_to_async
from django.db.models import Model
from strawberry.relay import GlobalID

# Bound to Model so the CRUD helpers are typed as returning the same model class they receive.
ModelType = TypeVar('ModelType', bound=Model)


def _decode_global_id(gid: Union[str, int, GlobalID, None]) -> Optional[Union[str, int]]:
    """
    Decode a Global ID to extract the node ID.

    Handles:
    - ``None`` -> ``None``
    - a ``GlobalID`` object -> its ``node_id``
    - a base64 string encoding ``"TypeName:node_id"`` -> ``"node_id"``
    - anything else (raw primary-key strings, int primary keys, ...) -> returned unchanged

    Args:
        gid: Value to decode.

    Returns:
        The decoded node ID, or ``gid`` itself when it is not a decodable Global ID.
    """
    if gid is None:
        return None
    if isinstance(gid, GlobalID):
        return gid.node_id
    if not isinstance(gid, (str, bytes)):
        # Raw non-string primary keys (e.g. ints) cannot be base64; b64decode would
        # raise TypeError on them, so pass them through untouched.
        return gid
    try:
        # validate=True rejects characters outside the base64 alphabet instead of
        # silently discarding them, so plain IDs (e.g. hyphenated UUIDs) are never
        # mangled into a bogus decode.
        decoded = base64.b64decode(gid, validate=True).decode("utf-8")
    except ValueError:
        # ValueError covers binascii.Error (bad base64), UnicodeDecodeError (not UTF-8)
        # and the ValueError b64decode raises for non-ASCII str input.
        return gid
    if ":" in decoded:
        return decoded.split(":", 1)[1]
    return gid


def _decode_global_ids(ids: Optional[Iterable[Any]]) -> Optional[List[Any]]:
    """
    Decode every Global ID in ``ids`` with ``_decode_global_id``.

    Returns ``None`` when ``ids`` is ``None``, otherwise a new list of decoded IDs.
    """
    if ids is None:
        return None
    return [_decode_global_id(x) for x in ids]


def _to_dict(input_data: Union[Dict[str, Any], object]) -> Dict[str, Any]:
    """
    Convert input data to a dictionary.

    A dict is shallow-copied. Any other object contributes its public instance
    attributes (``vars()`` entries not starting with ``_``). Objects without a
    ``__dict__`` (``vars()`` raises TypeError) yield an empty dict.
    """
    if isinstance(input_data, dict):
        return dict(input_data)
    try:
        return {k: v for k, v in vars(input_data).items() if not k.startswith("_")}
    except TypeError:
        return {}


def _decode_scalar_ids_inplace(data: Dict[str, Any]) -> None:
    """
    Decode Global IDs in-place for scalar ID fields.

    Rewrites the ``id`` key and every key ending in ``_id`` whose value is not ``None``.
    ``*_ids`` list fields are not touched here.
    """
    # Iterate over a snapshot because values are reassigned during the loop.
    for key, value in list(data.items()):
        if value is not None and (key == "id" or key.endswith("_id")):
            data[key] = _decode_global_id(value)


def _filter_write_fields(raw: Dict[str, Any], m2m_data: Optional[dict] = None) -> Dict[str, Any]:
    """
    Return a copy of ``raw`` without fields that must not be written directly to the model.

    Removed keys:
    - many-to-many fields handled separately (the keys of ``m2m_data``)
    - ``*_ids`` convenience arrays that are processed elsewhere
    - ``id`` (primary key)
    """
    m2m_fields = list(m2m_data.keys()) if m2m_data else []
    exclude_keys = set(m2m_fields) | {k for k in raw.keys() if k.endswith("_ids")} | {"id"}
    return {k: v for k, v in raw.items() if k not in exclude_keys}


def _observed(date: datetime.date) -> datetime.date:
    """
    Calculate the observed date for a holiday.

    A Saturday holiday is observed on the preceding Friday, and a Sunday holiday on
    the following Monday. Note that the result may fall in a different year, e.g.
    a Saturday Jan 1 is observed on Dec 31 of the previous year.
    """
    wd = date.weekday()
    if wd == 5:  # Saturday
        return date - datetime.timedelta(days=1)
    if wd == 6:  # Sunday
        return date + datetime.timedelta(days=1)
    return date


def _nth_weekday_of_month(year: int, month: int, weekday: int, n: int) -> datetime.date:
    """
    Find the nth occurrence of a weekday in a month.

    Args:
        year: Calendar year.
        month: Month number, 1-12.
        weekday: Mon=0 ... Sun=6 (same as ``datetime.date.weekday()``).
        n: Which occurrence, 1..5 (e.g. 4 for the 4th Thursday).

    Raises:
        ValueError: If the month has fewer than ``n`` occurrences of ``weekday``.
    """
    count = 0
    for day in range(1, 32):
        try:
            d = datetime.date(year, month, day)
        except ValueError:
            # Ran past the last day of the month.
            break
        if d.weekday() == weekday:
            count += 1
            if count == n:
                return d
    raise ValueError("Invalid nth weekday request")


def _last_weekday_of_month(year: int, month: int, weekday: int) -> datetime.date:
    """Find the last occurrence of a weekday in a month (e.g. last Monday in May)."""
    last_day = calendar.monthrange(year, month)[1]
    for day in range(last_day, 0, -1):
        d = datetime.date(year, month, day)
        if d.weekday() == weekday:
            return d
    raise ValueError("Invalid last weekday request")


@lru_cache(maxsize=64)
def _holiday_set(year: int) -> set[datetime.date]:
    """
    Generate the set of (observed) holiday dates for ``year``.

    Covers New Year's Day, Memorial Day, Independence Day, Labor Day, Thanksgiving
    and Christmas only — a subset of US federal holidays. Because New Year's Day
    on a Saturday is observed on Dec 31, the set for ``year`` may contain a date in
    ``year - 1``. The result is cached, so callers must not mutate it.
    """
    holidays: set[datetime.date] = set()
    # New Year's Day (observed)
    holidays.add(_observed(datetime.date(year, 1, 1)))
    # Memorial Day (last Monday in May)
    holidays.add(_last_weekday_of_month(year, 5, calendar.MONDAY))
    # Independence Day (observed)
    holidays.add(_observed(datetime.date(year, 7, 4)))
    # Labor Day (first Monday in September)
    holidays.add(_nth_weekday_of_month(year, 9, calendar.MONDAY, 1))
    # Thanksgiving Day (4th Thursday in November)
    holidays.add(_nth_weekday_of_month(year, 11, calendar.THURSDAY, 4))
    # Christmas Day (observed)
    holidays.add(_observed(datetime.date(year, 12, 25)))
    return holidays


def _is_holiday(date: datetime.date) -> bool:
    """Check if a date is one of the (observed) holidays produced by ``_holiday_set``."""
    if date in _holiday_set(date.year):
        return True
    # When Jan 1 falls on a Saturday it is observed on Dec 31 of the previous year;
    # that date is stored in the *following* year's holiday set.
    return date.month == 12 and date.day == 31 and date in _holiday_set(date.year + 1)


def _extract_id(payload: Union[dict, str, int]) -> str:
    """
    Extract an ID from a dict payload (its ``"id"`` key) or a bare str/int, as a string.

    NOTE(review): a dict without ``"id"`` yields the string ``"None"``; confirm callers expect that.
    """
    return str(payload.get("id")) if isinstance(payload, dict) else str(payload)


def _apply_m2m(instance: Model, m2m_data: Optional[dict]) -> None:
    """
    Replace each provided many-to-many relation on ``instance``.

    ``m2m_data`` maps relation field names to iterables of (Global) IDs. A value of
    ``None`` means "not provided" and leaves the relation unchanged; to clear a
    relation, pass an empty list ``[]``.
    """
    if not m2m_data:
        return
    for field, values in m2m_data.items():
        if values is not None:
            getattr(instance, field).set(_decode_global_ids(values))


# Internal synchronous implementations

def _create_object_sync(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> ModelType:
    """
    Synchronous implementation of ``create_object``.

    Converts ``input_data`` to a dict and decodes scalar Global IDs. It then creates
    the row from the writable fields (no ``id``, no ``*_ids``, no m2m keys) and
    finally applies the many-to-many data.
    """
    raw = _to_dict(input_data)
    _decode_scalar_ids_inplace(raw)
    data = _filter_write_fields(raw, m2m_data)
    instance = model_class.objects.create(**data)
    _apply_m2m(instance, m2m_data)
    return instance


def _update_object_sync(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> ModelType:
    """
    Synchronous implementation of ``update_object``.

    Only fields whose value is not ``None`` are written, and ``save`` is restricted
    to those fields. If no field was provided, a plain ``save()`` is still issued.
    Many-to-many data is applied afterwards (``None`` = leave unchanged, ``[]`` = clear).

    Raises:
        ValueError: If no ``model_class`` row exists with the (decoded) ``id``.
    """
    raw = _to_dict(input_data)
    _decode_scalar_ids_inplace(raw)
    object_id = raw.get("id")
    # Keep the try narrow so only a failed lookup is reported as "does not exist".
    try:
        instance = model_class.objects.get(pk=object_id)
    except model_class.DoesNotExist as err:
        raise ValueError(f"{model_class.__name__} with ID {object_id} does not exist") from err

    data = _filter_write_fields(raw, m2m_data)

    # Update only provided fields
    update_fields = []
    for field, value in data.items():
        if value is not None:
            setattr(instance, field, value)
            update_fields.append(field)

    if update_fields:
        instance.save(update_fields=update_fields)
    else:
        instance.save()

    _apply_m2m(instance, m2m_data)
    return instance


def _delete_object_sync(object_id, model_class: Type[ModelType]) -> Optional[ModelType]:
    """
    Synchronous implementation of ``delete_object``.

    Returns the deleted instance, or ``None`` if no row matched the decoded ID.
    """
    pk = _decode_global_id(object_id)
    try:
        instance = model_class.objects.get(pk=pk)
    except model_class.DoesNotExist:
        return None
    instance.delete()
    return instance


# Public async functions with explicit typing for IDE support

def create_object(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> Awaitable[ModelType]:
    """
    Create a new model instance asynchronously.

    Args:
        input_data: Input data (dict or object with attributes)
        model_class: Django model class
        m2m_data: Optional dictionary of many-to-many field data

    Returns:
        Awaitable that resolves to created model instance
    """
    return database_sync_to_async(_create_object_sync)(input_data, model_class, m2m_data)


def update_object(input_data, model_class: Type[ModelType], m2m_data: Optional[dict] = None) -> Awaitable[ModelType]:
    """
    Update an existing model instance asynchronously.

    Args:
        input_data: Input data (dict or object with attributes) - must include 'id'
        model_class: Django model class
        m2m_data: Optional dictionary of many-to-many field data

    Returns:
        Awaitable that resolves to updated model instance

    Raises:
        ValueError: If an object with the given ID doesn't exist
    """
    return database_sync_to_async(_update_object_sync)(input_data, model_class, m2m_data)


def delete_object(object_id, model_class: Type[ModelType]) -> Awaitable[Optional[ModelType]]:
    """
    Delete a model instance asynchronously.

    Args:
        object_id: Global ID or primary key of the object to delete
        model_class: Django model class

    Returns:
        Awaitable that resolves to deleted model instance if found, None if not found
    """
    return database_sync_to_async(_delete_object_sync)(object_id, model_class)


def _get_conversations_for_entity_sync(entity_instance: Model) -> List:
    """
    Synchronous implementation to get conversations linked to an entity via GenericForeignKey.

    Args:
        entity_instance: The model instance (e.g., Service, Project, Account, Customer)

    Returns:
        List of Conversation objects whose ``entity_content_type`` / ``entity_object_id``
        point at ``entity_instance``
    """
    # Imported lazily — presumably to avoid an import cycle with core.models; TODO confirm.
    from django.contrib.contenttypes.models import ContentType
    from core.models.messaging import Conversation

    content_type = ContentType.objects.get_for_model(type(entity_instance))
    # A generic foreign key stores the target's primary key, so filter on .pk
    # (not .id) to also support models with a custom primary key.
    return list(Conversation.objects.filter(
        entity_content_type=content_type,
        entity_object_id=entity_instance.pk,
    ))


def get_conversations_for_entity(entity_instance: Model) -> Awaitable[List]:
    """
    Get all conversations linked to an entity asynchronously.

    This helper handles the GenericForeignKey relationship pattern for conversations.
    Use this in your GraphQL types to easily add a conversations field.

    Args:
        entity_instance: The model instance (e.g., Service, Project, Account, Customer)

    Returns:
        Awaitable that resolves to list of Conversation objects

    Example usage in GraphQL type:
        @strawberry.field
        async def conversations(self) -> List[ConversationType]:
            return await get_conversations_for_entity(self)
    """
    return database_sync_to_async(_get_conversations_for_entity_sync)(entity_instance)