diff --git a/apps/announcements/models.py b/apps/announcements/models.py index f271e5de..ba4cf784 100644 --- a/apps/announcements/models.py +++ b/apps/announcements/models.py @@ -78,16 +78,5 @@ def get_related_project(self) -> Optional["Project"]: """Return the project related to this model.""" return self.project - def duplicate(self, project: "Project") -> "Announcement": - return Announcement.objects.create( - project=project, - description=self.description, - title=self.title, - type=self.type, - status=self.status, - deadline=self.deadline, - is_remunerated=self.is_remunerated, - ) - def __str__(self): return str(self.title) diff --git a/apps/commons/mixins.py b/apps/commons/mixins.py index 84ef04d5..93f94635 100644 --- a/apps/commons/mixins.py +++ b/apps/commons/mixins.py @@ -1,4 +1,7 @@ -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from collections.abc import Iterable +from contextlib import suppress +from copy import copy +from typing import TYPE_CHECKING, Any, Optional, Self from django.contrib.auth.models import Group, Permission from django.contrib.contenttypes.models import ContentType @@ -35,7 +38,7 @@ def organization_query(cls, key: str, value: Any) -> Q: return Q(**{cls.organization_query_string: value}) return Q(**{key: value}) - def get_related_organizations(self) -> List["Organization"]: + def get_related_organizations(self) -> list["Organization"]: """Return the organizations related to this model.""" raise NotImplementedError() @@ -91,7 +94,7 @@ def get_related_project(self) -> Optional["Project"]: """Return the projects related to this model.""" raise NotImplementedError() - def get_related_organizations(self) -> List["Organization"]: + def get_related_organizations(self) -> list["Organization"]: """Return the organizations related to this model.""" raise NotImplementedError() @@ -184,7 +187,7 @@ def setup_permissions(self, user: Optional["ProjectUser"] = None): @classmethod def batch_reassign_permissions( - cls, roles_permissions: Tuple[str, Iterable[Permission]] + cls, roles_permissions: tuple[str, Iterable[Permission]] ): """ Reassign permissions for all instances of the model. @@ -268,8 +271,24 @@ class DuplicableModel: A model that can be duplicated. """ - def duplicate(self, *args, **kwargs) -> "DuplicableModel": - raise NotImplementedError() + def duplicate(self, **fields) -> type[Self]: + """duplicate models elements, set new fields + + :return: new models + """ + + instance_copy = copy(self) + instance_copy.pk = None + + for name, value in fields.items(): + setattr(instance_copy, name, value) + + # remove prefetch m2m + with suppress(AttributeError): + del instance_copy._prefetched_objects_cache + + instance_copy.save() + return instance_copy class HasMultipleIDs: @@ -320,9 +339,9 @@ def get_id_field_name(cls, object_id: Any) -> str: The outdated slugs of the object. They are kept for url retro-compatibility. """ - _original_slug_fields_value: Dict[str, str] = {} - slugified_fields: List[str] = [] - reserved_slugs: List[str] = [] + _original_slug_fields_value: dict[str, str] = {} + slugified_fields: list[str] = [] + reserved_slugs: list[str] = [] slug_prefix: str = "" def __init__(self, *args, **kwargs): @@ -371,8 +390,8 @@ def get_main_id(cls, object_id: Any, returned_field: str = "id") -> Any: @classmethod def get_main_ids( - cls, objects_ids: List[Any], returned_field: str = "id" - ) -> List[Any]: + cls, objects_ids: list[Any], returned_field: str = "id" + ) -> list[Any]: """Get the main IDs from a list of secondary IDs.""" return [cls.get_main_id(object_id, returned_field) for object_id in objects_ids] diff --git a/apps/files/models.py b/apps/files/models.py index b67010f8..adbf4963 100644 --- a/apps/files/models.py +++ b/apps/files/models.py @@ -1,7 +1,7 @@ import datetime import uuid from contextlib import suppress -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Self from azure.core.exceptions import ResourceNotFoundError from django.apps import apps @@ -127,18 +127,6 @@ def get_related_project(self) -> Optional["Project"]: """Return the project related to this model.""" return self.project - def duplicate(self, project: "Project") -> "AttachmentLink": - return AttachmentLink.objects.create( - project=project, - attachment_type=self.attachment_type, - category=self.category, - description=self.description, - preview_image_url=self.preview_image_url, - site_name=self.site_name, - site_url=self.site_url, - title=self.title, - ) - class OrganizationAttachmentFile( HasAutoTranslatedFields, OrganizationRelated, models.Model @@ -226,15 +214,7 @@ def duplicate(self, project: "Project") -> Optional["AttachmentFile"]: content=self.file.read(), content_type=f"application/{file_extension}", ) - return AttachmentFile.objects.create( - project=project, - attachment_type=self.attachment_type, - file=new_file, - mime=self.mime, - title=self.title, - description=self.description, - hashcode=self.hashcode, - ) + return super().duplicate(project=project, file=new_file) return None @@ -400,9 +380,7 @@ def get_related_project(self) -> Optional["Project"]: return queryset.first() return None - def duplicate( - self, owner: Optional["ProjectUser"] = None, upload_to: str = "" - ) -> Optional["Image"]: + def duplicate(self, upload_to: str = "", **fields) -> None | type[Self]: with suppress(ResourceNotFoundError): file_path = self.file.name.split("/") file_name = file_path.pop() @@ -416,21 +394,8 @@ def duplicate( content=self.file.read(), content_type=f"image/{file_extension}", ) - image = Image( - name=self.name, - file=new_file, - height=self.height, - width=self.width, - natural_ratio=self.natural_ratio, - scale_x=self.scale_x, - scale_y=self.scale_y, - left=self.left, - top=self.top, - owner=owner or self.owner, - ) - image._upload_to = lambda instance, filename: upload_to - image.save() - return image + _upload_to = lambda instance, filename: upload_to # noqa: E731 + return super().duplicate(_upload_to=_upload_to, file=new_file, **fields) return None diff --git a/apps/projects/models.py b/apps/projects/models.py index 46053853..b4cf08c9 100644 --- a/apps/projects/models.py +++ b/apps/projects/models.py @@ -553,48 +553,49 @@ def calculate_score(self) -> "ProjectScore": @transaction.atomic def duplicate(self, owner: Optional["ProjectUser"] = None) -> "Project": - header = self.header_image.duplicate(owner) if self.header_image else None - project = Project.objects.create( - title=self.title, + header = self.header_image.duplicate(owner=owner) if self.header_image else None + project = super().duplicate( + slug=None, + outdated_slugs=[], header_image=header, - description=self.description, - purpose=self.purpose, - is_locked=self.is_locked, - is_shareable=self.is_shareable, publication_status=Project.PublicationStatus.PRIVATE, - life_status=self.life_status, - language=self.language, - sdgs=self.sdgs, - template=self.template, + # TODO(remi): add this id (or fk) directly in DuplicateMixins duplicated_from=self.id, ) - project.setup_permissions(user=owner) + project.categories.set(self.categories.all()) project.organizations.set(self.organizations.all()) project.tags.set(self.tags.all()) + project.setup_permissions(user=owner) + + images_to_set = [] for image in self.images.all(): - new_image = image.duplicate(owner) + new_image = image.duplicate(owner=owner) if new_image is not None: - project.images.add(new_image) + images_to_set.append(new_image) for identifier in [self.pk, self.slug]: project.description = project.description.replace( f"/v1/project/{identifier}/image/{image.pk}/", f"/v1/project/{project.pk}/image/{new_image.pk}/", ) - project.save() + project.images.set(images_to_set) + for blog_entry in self.blog_entries.all(): - blog_entry.duplicate(project, self, owner) + blog_entry.duplicate(project=project, initial_project=self, owner=owner) for announcement in self.announcements.all(): - announcement.duplicate(project) + announcement.duplicate(project=project) for location in self.locations.all(): - location.duplicate(project) + location.duplicate(project=project) for goal in self.goals.all(): - goal.duplicate(project) + goal.duplicate(project=project) for link in self.links.all(): - link.duplicate(project) + link.duplicate(project=project) for file in self.files.all(): - file.duplicate(project) + file.duplicate(project=project) + Stat.objects.create(project=project) + + project.save() return project @@ -768,21 +769,18 @@ def duplicate( initial_project: Optional["Project"] = None, owner: Optional["ProjectUser"] = None, ) -> "BlogEntry": - blog_entry = BlogEntry.objects.create( - project=project, - title=self.title, - content=self.content, - ) + blog_entry = super().duplicate(project=project) + images_to_set = [] for image in self.images.all(): - new_image = image.duplicate(owner) + new_image = image.duplicate(owner=owner) if new_image is not None: - blog_entry.images.add(new_image) + images_to_set.append(new_image) for identifier in [initial_project.pk, initial_project.slug]: blog_entry.content = blog_entry.content.replace( f"/v1/project/{identifier}/blog-entry-image/{image.pk}/", f"/v1/project/{project.pk}/blog-entry-image/{new_image.pk}/", ) - blog_entry.created_at = self.created_at + blog_entry.images.set(images_to_set) blog_entry.save() return blog_entry @@ -851,15 +849,6 @@ def get_related_project(self) -> Optional["Project"]: """Return the project related to this model.""" return self.project - def duplicate(self, project: "Project") -> "Goal": - return Goal.objects.create( - project=project, - title=self.title, - description=self.description, - deadline_at=self.deadline_at, - status=self.status, - ) - class Location( HasAutoTranslatedFields, @@ -916,16 +905,6 @@ def get_related_organizations(self) -> List["Organization"]: """Return the organizations related to this model.""" return self.project.get_related_organizations() - def duplicate(self, project: "Project") -> "Location": - return Location.objects.create( - project=project, - title=self.title, - description=self.description, - lat=self.lat, - lng=self.lng, - type=self.type, - ) - class ProjectMessage( HasAutoTranslatedFields, diff --git a/apps/projects/tests/views/test_project.py b/apps/projects/tests/views/test_project.py index b9b0eb36..a458ae9e 100644 --- a/apps/projects/tests/views/test_project.py +++ b/apps/projects/tests/views/test_project.py @@ -497,6 +497,12 @@ def check_duplicated_project(self, duplicated_project: Dict, initial_project: Di self.assertEqual( duplicated_project["publication_status"], Project.PublicationStatus.PRIVATE ) + self.assertNotEqual( + duplicated_project["created_at"], initial_project["created_at"] + ) + self.assertNotEqual( + duplicated_project["updated_at"], initial_project["updated_at"] + ) for field in fields: self.assertEqual(duplicated_project[field], initial_project[field])