Skip to content
Binary file modified submit-api/.DS_Store
Binary file not shown.
22 changes: 11 additions & 11 deletions submit-api/migrations/versions/0eabfcf062e3_.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Create Date: 2024-08-22 12:20:35.863982

"""
from datetime import datetime
from datetime import datetime, UTC

import sqlalchemy as sa
from alembic import op
Expand All @@ -30,48 +30,48 @@ def upgrade():
sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('name', sa.String),
sa.Column('created_date', sa.DateTime, default=datetime.utcnow),
sa.Column('created_date', sa.DateTime, default=datetime.now(UTC)),
sa.Column('created_by', sa.String, default='system'),
)
item_types = sa.Table(
'item_types',
sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('name', sa.String),
sa.Column('created_date', sa.DateTime, default=datetime.utcnow),
sa.Column('created_date', sa.DateTime, default=datetime.now(UTC)),
sa.Column('created_by', sa.String, default='system'),
)
package_item_types = sa.Table(
'package_item_types',
sa.MetaData(),
sa.Column('package_type_id', sa.Integer),
sa.Column('item_type_id', sa.Integer),
sa.Column('created_date', sa.DateTime, default=datetime.utcnow),
sa.Column('created_date', sa.DateTime, default=datetime.now(UTC)),
sa.Column('created_by', sa.String, default='system'),
)

# Insert package types and retrieve IDs
op.bulk_insert(package_types, [
{'name': 'Management Plan', 'created_date': datetime.utcnow()},
{'name': 'Management Plan', 'created_date': datetime.now(UTC)},
])
conn = op.get_bind()

package_type_id = conn.execute(package_types.select().where(package_types.c.name == management_plan)).fetchone()[0]

# Insert item types and retrieve IDs
op.bulk_insert(item_types, [
{'name': management_plan_form, 'created_date': datetime.utcnow()},
{'name': consultation_records, 'created_date': datetime.utcnow()},
{'name': management_plan_submission, 'created_date': datetime.utcnow()},
{'name': management_plan_form, 'created_date': datetime.now(UTC)},
{'name': consultation_records, 'created_date': datetime.now(UTC)},
{'name': management_plan_submission, 'created_date': datetime.now(UTC)},
])

item_type_ids = conn.execute(item_types.select().where(item_types.c.name.in_([management_plan_form, consultation_records, management_plan_submission]))).fetchall()

# Insert package item types using retrieved IDs
op.bulk_insert(package_item_types, [
{'package_type_id': package_type_id, 'item_type_id': item_type_ids[0][0], 'created_date': datetime.utcnow()},
{'package_type_id': package_type_id, 'item_type_id': item_type_ids[1][0], 'created_date': datetime.utcnow()},
{'package_type_id': package_type_id, 'item_type_id': item_type_ids[2][0], 'created_date': datetime.utcnow()},
{'package_type_id': package_type_id, 'item_type_id': item_type_ids[0][0], 'created_date': datetime.now(UTC)},
{'package_type_id': package_type_id, 'item_type_id': item_type_ids[1][0], 'created_date': datetime.now(UTC)},
{'package_type_id': package_type_id, 'item_type_id': item_type_ids[2][0], 'created_date': datetime.now(UTC)},
])
# ### end Alembic commands ###

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import os
from alembic import op
from datetime import datetime
from datetime import datetime, UTC
import sqlalchemy as sa
from sqlalchemy import text
from sqlalchemy.dialects import postgresql
Expand Down Expand Up @@ -58,7 +58,7 @@ def upgrade():
version=1,
content=html_content,
active=True,
created_date=datetime.utcnow()
created_date=datetime.now(UTC)
)
)
# ### end Alembic commands ###
Expand Down
3 changes: 1 addition & 2 deletions submit-api/migrations/versions/7cadbb980cf7_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, declarative_base

from submit_api.models import Package, PackageVersion, UserRole

Expand Down
6 changes: 3 additions & 3 deletions submit-api/migrations/versions/97de805275ec_.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Create Date: 2024-07-30 10:49:57.735159

"""
from datetime import datetime
from datetime import datetime, UTC

import sqlalchemy as sa
from alembic import op
Expand All @@ -22,7 +22,7 @@
sa.Column('id', sa.Integer, primary_key=True, autoincrement=True),
sa.Column('role_name', sa.String(50), nullable=False),
sa.Column('description', sa.Text, nullable=False),
sa.Column('created_date', sa.DateTime, nullable=False, default=datetime.utcnow),
sa.Column('created_date', sa.DateTime, nullable=False, default=datetime.now(UTC)),
sa.Column('updated_date', sa.DateTime, nullable=True),
sa.Column('created_by', sa.String(50), nullable=True),
sa.Column('updated_by', sa.String(50), nullable=True)
Expand All @@ -37,7 +37,7 @@ def upgrade():
{
'role_name': 'ACCOUNT_PRIMARY_ADMIN',
'description': 'Administrator role',
'created_date': datetime.utcnow(),
'created_date': datetime.now(UTC),
'created_by': 'system' # Or replace with the actual creator
}
]
Expand Down
18 changes: 9 additions & 9 deletions submit-api/migrations/versions/df4613866744_add_IEM_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Create Date: 2025-04-30 12:41:23.463283

"""
from datetime import datetime
from datetime import datetime, UTC
from sqlalchemy.sql import bindparam

from alembic import op
Expand All @@ -32,15 +32,15 @@ def upgrade():
sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('name', sa.String),
sa.Column('created_date', sa.DateTime, default=datetime.utcnow),
sa.Column('created_date', sa.DateTime, default=datetime.now(UTC)),
sa.Column('created_by', sa.String, default='system'),
)
item_types = sa.Table(
'item_types',
sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('name', sa.String),
sa.Column('created_date', sa.DateTime, default=datetime.utcnow),
sa.Column('created_date', sa.DateTime, default=datetime.now(UTC)),
sa.Column('submission_method', sa.Enum('FORM_SUBMISSION', 'DOCUMENT_UPLOAD', name='submissionmethod')),
sa.Column('created_by', sa.String, default='system'),
)
Expand All @@ -49,14 +49,14 @@ def upgrade():
sa.MetaData(),
sa.Column('package_type_id', sa.Integer),
sa.Column('item_type_id', sa.Integer),
sa.Column('created_date', sa.DateTime, default=datetime.utcnow),
sa.Column('created_date', sa.DateTime, default=datetime.now(UTC)),
sa.Column('created_by', sa.String, default='system'),
sa.Column("sort_order", sa.Integer(), nullable=True, server_default="0"),
)

# Insert new package type
op.bulk_insert(package_types, [
{'name': PACKAGE_TYPE_IEM, 'created_date': datetime.utcnow()},
{'name': PACKAGE_TYPE_IEM, 'created_date': datetime.now(UTC)},
])
conn = op.get_bind()

Expand All @@ -68,7 +68,7 @@ def upgrade():

# Insert new item types
op.bulk_insert(item_types, [
{'name': ITEM_TYPE_IEM_TERMS, 'created_date': datetime.utcnow(), 'submission_method': DOCUMENT_UPLOAD},
{'name': ITEM_TYPE_IEM_TERMS, 'created_date': datetime.now(UTC), 'submission_method': DOCUMENT_UPLOAD},
])

# Retrieve item type IDs using parameterized query
Expand All @@ -83,11 +83,11 @@ def upgrade():
# Insert package item types
op.bulk_insert(package_item_types, [
{'package_type_id': iem_package_type_id, 'item_type_id': item_type_ids[ITEM_TYPE_IEM_TERMS], 'sort_order': 2,
'created_date': datetime.utcnow()},
'created_date': datetime.now(UTC)},
{'package_type_id': iem_package_type_id, 'item_type_id': item_type_ids[ITEM_TYPE_CONSULTATION_RECORDS], 'sort_order': 1,
'created_date': datetime.utcnow()},
'created_date': datetime.now(UTC)},
{'package_type_id': iem_package_type_id, 'item_type_id': item_type_ids[ITEM_TYPE_CONTACT_INFORMATION_FORM], 'sort_order': 0,
'created_date': datetime.utcnow()},
'created_date': datetime.now(UTC)},
])


Expand Down
10 changes: 5 additions & 5 deletions submit-api/migrations/versions/dfd81c02e843_add_user_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

"""
from alembic import op
from datetime import datetime
from datetime import datetime, UTC
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

Expand Down Expand Up @@ -38,10 +38,10 @@ def upgrade():
sa.column('updated_date', sa.DateTime))

op.bulk_insert(user_status, [
{'id': 1, 'status_name': 'ACTIVE', 'description': 'Active User', 'created_date': datetime.utcnow(),
'updated_date': datetime.utcnow()},
{'id': 2, 'status_name': 'INACTIVE', 'description': 'Inactive User', 'created_date': datetime.utcnow(),
'updated_date': datetime.utcnow()}
{'id': 1, 'status_name': 'ACTIVE', 'description': 'Active User', 'created_date': datetime.now(UTC),
'updated_date': datetime.now(UTC)},
{'id': 2, 'status_name': 'INACTIVE', 'description': 'Inactive User', 'created_date': datetime.now(UTC),
'updated_date': datetime.now(UTC)}
])

with op.batch_alter_table('user_roles', schema=None) as batch_op:
Expand Down
12 changes: 7 additions & 5 deletions submit-api/src/submit_api/models/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def create_account(cls, account_data, session=None) -> Account:
account = Account(
proponent_id=account_data.get('proponent_id', None),
)
if session:
account.flush()
else:
account.save()
return account
return account.persist(session)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use this pattern quite frequently: if session - flush, else - save. I moved this into a common method called persist that does the same thing, but avoids duplicate code.


@classmethod
def get_ids_by_proponent_id(cls, proponent_id: int) -> list[int]:
"""Get account ids for a given proponent id."""
results = cls.query.with_entities(cls.id).filter_by(proponent_id=proponent_id).all()
return [account_id for (account_id,) in results]
23 changes: 16 additions & 7 deletions submit-api/src/submit_api/models/account_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def get_all_in_project_ids(cls, ids):
"""Get all projects in the given IDs."""
return cls.query.filter(cls.project_id.in_(ids)).all()

@classmethod
def get_all_in_account_ids(cls, account_ids: list[int]):
"""Get all projects for the given account ids."""
return cls.query.filter(cls.account_id.in_(account_ids)).all()

@classmethod
def get_by_account_id(cls, account_id: int) -> AccountProject | None:
"""Return the AccountProject object for the given account_id."""
Expand All @@ -84,8 +89,8 @@ def get_by_project_id(cls, project_id: int) -> AccountProject | None:
return cls.query.filter_by(project_id=project_id).first()

@classmethod
def create_account_project(cls, account_id, project_id, session=None) -> AccountProject:
"""Create account project."""
def get_or_create(cls, account_id, project_id, session=None) -> AccountProject:
"""Get or create account project."""
existing_account_project = cls.query.filter_by(
account_id=account_id,
project_id=project_id
Expand All @@ -96,8 +101,12 @@ def create_account_project(cls, account_id, project_id, session=None) -> Account
account_id=account_id,
project_id=project_id
)
if session:
session.add(account_project)
else:
account_project.save()
return account_project
return account_project.persist(session)

@classmethod
def get_project_ids_by_ids(cls, account_project_ids: list) -> list[int]:
"""Get project ids for the given account project ids."""
results = cls.query.filter(
cls.id.in_(account_project_ids)
).with_entities(cls.project_id).all()
return [pid for (pid,) in results]
14 changes: 3 additions & 11 deletions submit-api/src/submit_api/models/account_project_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def find_by_work_id(cls, work_id: int):
return cls.query.filter_by(work_id=work_id, is_active=True).all()

@classmethod
def create_or_get(cls, account_project_id: int, work_id: int, session=None):
def get_or_create(cls, account_project_id: int, work_id: int, session=None):
"""Create or get account project work.

Args:
Expand All @@ -86,10 +86,7 @@ def create_or_get(cls, account_project_id: int, work_id: int, session=None):
if existing:
if not existing.is_active:
existing.is_active = True
if session:
session.add(existing)
else:
existing.save()
existing.persist(session)
return existing

new_instance = cls(
Expand All @@ -98,9 +95,4 @@ def create_or_get(cls, account_project_id: int, work_id: int, session=None):
is_active=True
)

if session:
session.add(new_instance)
else:
new_instance.save()

return new_instance
return new_instance.persist(session)
18 changes: 6 additions & 12 deletions submit-api/src/submit_api/models/account_terms_of_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ class TermsOfService(BaseModel):
@classmethod
def create_terms_of_service(cls, data, session=None) -> TermsOfService:
"""Create a new Terms of service record."""
# Deactivate all existing records
if session:
session.query(cls).filter_by(active=True).update({"active": False})
else:
cls.query.filter_by(active=True).update({"active": False})
db.session.flush()
_session = session or db.session

# Deactivate all existing active records
_session.query(cls).filter_by(active=True).update({"active": False})
_session.flush()

active = data.get("active")
terms_of_service = TermsOfService(
Expand All @@ -35,12 +34,7 @@ def create_terms_of_service(cls, data, session=None) -> TermsOfService:
rich_content=data.get("rich_content"),
active=True if active is None else active,
)
if session:
session.add(terms_of_service)
session.flush()
else:
terms_of_service.save()
return terms_of_service
return terms_of_service.persist(session)

@classmethod
def get_active_terms_of_service(cls) -> TermsOfService | None:
Expand Down
27 changes: 19 additions & 8 deletions submit-api/src/submit_api/models/account_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
from __future__ import annotations

from datetime import datetime
from datetime import datetime, UTC

from sqlalchemy import Column, ForeignKey
from sqlalchemy.ext.hybrid import hybrid_property
Expand All @@ -13,6 +13,7 @@
from .base_model import BaseModel
from .db import db
from .user import User as UserModel
from .user_role import UserRole


class AccountUser(BaseModel):
Expand All @@ -39,7 +40,7 @@ def role(self):
"""Return the first role for backward compatibility."""
return self.roles[0] if self.roles else None
terms_of_service_version_id = Column(db.Integer, db.ForeignKey('account_terms_of_service.version'), nullable=True)
terms_of_service_accepted_date = db.Column(db.DateTime, default=datetime.utcnow, nullable=True)
terms_of_service_accepted_date = db.Column(db.DateTime, default=datetime.now(UTC), nullable=True)
company_name = Column(db.String(255), nullable=True)

terms_of_service = db.relationship(
Expand Down Expand Up @@ -86,12 +87,7 @@ def create_account_user(cls, data, session=None) -> AccountUser:
extension_number=data.get('extension_number', None),
terms_of_service_version_id=data.get('terms_of_service_version_id')
)
if session:
session.add(account_user)
session.flush()
else:
account_user.save()
return account_user
return account_user.persist(session)

@classmethod
def get_by_guid(cls, _guid):
Expand All @@ -104,7 +100,22 @@ def get_users_by_account_id(cls, account_id):
"""Get all users for a given account."""
return cls.query.filter(cls.account_id == account_id).all()

@classmethod
def get_all_in_account_ids(cls, account_ids: list[int]):
"""Get all users for the given account ids."""
return cls.query.filter(cls.account_id.in_(account_ids)).all()

@classmethod
def get_users_by_account_user_id(cls, account_user_id):
"""Get the user for a given account."""
return cls.query.filter(cls.id == account_user_id).first()

@classmethod
def get_filtered_by_account_id(cls, account_id: int, account_project_ids: list = None):
"""Get account users by account id, optionally filtered by project ids."""
query = cls.query.filter(cls.account_id == account_id)
if account_project_ids:
query = query.join(UserRole).filter(
UserRole.account_project_id.in_(account_project_ids)
)
return query.all()
Loading
Loading