Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
"password": "postgres"
}
],
"python.analysis.typeCheckingMode": "basic",
"python.analysis.typeCheckingMode": "strict",
"mypy-type-checker.args": [
"--config-file=mypy.ini"
],
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
Expand Down
70 changes: 70 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
[mypy]
python_version = 3.12
warn_return_any = True
warn_unused_configs = True
no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_no_return = True
check_untyped_defs = True
strict_equality = True

# Strict mode for already fixed modules
[mypy-powonline.model]
strict = True

[mypy-powonline.config]
strict = True

[mypy-powonline.exc]
strict = True

[mypy-powonline.util]
strict = True

[mypy-powonline.web]
disallow_untyped_defs = True
disallow_incomplete_defs = True

# Per-module options for third-party libraries
[mypy-flask.*]
ignore_missing_imports = True

[mypy-flask_restful.*]
ignore_missing_imports = True

[mypy-flask_sqlalchemy.*]
ignore_missing_imports = True

[mypy-flask_testing.*]
ignore_missing_imports = True

[mypy-werkzeug.*]
ignore_missing_imports = True

[mypy-bcrypt.*]
ignore_missing_imports = True

[mypy-gouge.*]
ignore_missing_imports = True

[mypy-imapclient.*]
ignore_missing_imports = True

[mypy-pusher.*]
ignore_missing_imports = True

[mypy-jwt.*]
ignore_missing_imports = True

[mypy-config_resolver.*]
ignore_missing_imports = True

[mypy-requests_oauthlib.*]
ignore_missing_imports = True

[mypy-oauthlib.*]
ignore_missing_imports = True

[mypy-PIL.*]
ignore_missing_imports = True
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dev = [
"alembic",
"blessings",
"gouge",
"types-Flask-SQLAlchemy",
"mypy>=1.4.1",
"types-Pillow",
"types-python-dateutil",
]
Expand Down
3 changes: 2 additions & 1 deletion src/powonline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ def default() -> ConfigParser:
"filename": "app.ini",
},
)
return lookup.config
config: ConfigParser = lookup.config
return config
2 changes: 1 addition & 1 deletion src/powonline/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ class PowonlineException(Exception):


class NoQuestionnaireForStation(PowonlineException):
def __init__(self, station, msg=""):
def __init__(self, station: str, msg: str = "") -> None:
super().__init__(msg or f"No questionnaire for station {station}")
self.station = station

Expand Down
100 changes: 65 additions & 35 deletions src/powonline/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from os import environ, urandom
from typing import Any
from uuid import UUID as PyUUID
from urllib.parse import urlparse, urlunparse

import sqlalchemy.types as types
Expand All @@ -22,13 +23,13 @@
func,
)
from sqlalchemy.dialects.postgresql import BYTEA, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship, scoped_session
from sqlalchemy.orm import Mapped, mapped_column, relationship, scoped_session, Session

LOG = logging.getLogger(__name__)
DB = SQLAlchemy()


def get_dsn():
def get_dsn() -> str:
dsn = environ.get("POWONLINE_DSN", "").strip()
if dsn:
parsed = urlparse(dsn)
Expand All @@ -51,38 +52,55 @@ class TeamState(Enum):


class TimestampMixin:
inserted = Column(
"""Base mixin for models with optional updated timestamp."""
inserted: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
FetchedValue(),
nullable=False,
server_default=func.now(),
)
updated = Column(DateTime(timezone=True), nullable=True)
updated: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)


# Note: TeamStation, Questionnaire, and TeamQuestionnaire don't use TimestampMixin
# because they require `updated` to be non-nullable with a default, which conflicts
# with the mixin's nullable type. They define their timestamp fields directly.

class TeamStateType(types.TypeDecorator):

class TeamStateType(types.TypeDecorator[TeamState]):
impl = types.Unicode
cache_ok = True

def process_bind_param(self, value, dialect):
def process_bind_param(
self, value: TeamState | None, dialect: Any
) -> str | None:
if value is None:
return None
return value.value

def process_result_value(self, value, dialect):
def process_result_value(
self, value: str | None, dialect: Any
) -> TeamState | None:
if value is None:
return None
return TeamState(value)


class Setting(DB.Model): # type: ignore
__tablename__ = "setting"

key = Column(Unicode, primary_key=True, nullable=False)
value = Column(Unicode)
description = Column(Unicode)
key: Mapped[str] = mapped_column(Unicode, primary_key=True, nullable=False)
value: Mapped[str | None] = mapped_column(Unicode)
description: Mapped[str | None] = mapped_column(Unicode)


class Message(DB.Model, TimestampMixin): # type: ignore
__tablename__ = "message"
id = Column(Integer, primary_key=True)
content = Column(Unicode)
user = Column(
id: Mapped[int] = mapped_column(Integer, primary_key=True)
content: Mapped[str | None] = mapped_column(Unicode)
user: Mapped[str | None] = mapped_column(
Unicode,
ForeignKey(
"user.name",
Expand All @@ -91,7 +109,7 @@ class Message(DB.Model, TimestampMixin): # type: ignore
ondelete="CASCADE",
),
)
team = Column(
team: Mapped[str | None] = mapped_column(
Unicode,
ForeignKey(
"team.name",
Expand Down Expand Up @@ -127,7 +145,7 @@ class Team(DB.Model, TimestampMixin): # type: ignore
route_name: Mapped[str | None] = mapped_column(
ForeignKey("route.name", onupdate="CASCADE", ondelete="SET NULL")
)
owner = Column(
owner: Mapped[str | None] = mapped_column(
Unicode,
ForeignKey(
"user.name",
Expand Down Expand Up @@ -275,12 +293,6 @@ class User(DB.Model, TimestampMixin): # type: ignore
oauth_connection: Mapped[list["OauthConnection"]] = relationship(
"OauthConnection", back_populates="user"
)
stations: Mapped[set["Station"]] = relationship(
"User",
secondary="user_station",
back_populates="users",
collection_class=set,
)
files: Mapped[list["Upload"]] = relationship(
"Upload", back_populates="user"
)
Expand All @@ -303,7 +315,7 @@ def avatar_url(self) -> str:
return ""

@staticmethod
def get_or_create(session: scoped_session, username: str) -> "User":
def get_or_create(session: "scoped_session[Session]", username: str) -> "User":
"""
Returns a user instance by name. Creates it if missing.
"""
Expand Down Expand Up @@ -360,7 +372,7 @@ def __init__(self) -> None:
self.name = "Example Station"

@staticmethod
def get_or_create(session: scoped_session, name: str) -> "Role":
def get_or_create(session: "scoped_session[Session]", name: str) -> "Role":
"""
Retrieves a role with name *name*.

Expand All @@ -369,15 +381,15 @@ def get_or_create(session: scoped_session, name: str) -> "Role":
query = session.query(Role).filter_by(name=name)
existing = query.one_or_none()
if not existing:
output = Role() # type: ignore
output = Role()
output.name = name
session.add(output)
else:
output = existing
return output # type: ignore
return output


class TeamStation(DB.Model, TimestampMixin): # type: ignore
class TeamStation(DB.Model): # type: ignore
__tablename__ = "team_station_state"

team_name: Mapped[str] = mapped_column(
Expand All @@ -388,10 +400,16 @@ class TeamStation(DB.Model, TimestampMixin): # type: ignore
ForeignKey("station.name", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True,
)
state: Mapped[TeamStateType | None] = mapped_column(
state: Mapped[TeamState | None] = mapped_column(
TeamStateType, default=TeamState.UNKNOWN
)
score: Mapped[int | None] = mapped_column(nullable=True, default=None)
inserted: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
FetchedValue(),
nullable=False,
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
Expand All @@ -415,19 +433,25 @@ def __init__(
self.state = state


class Questionnaire(DB.Model, TimestampMixin): # type: ignore
class Questionnaire(DB.Model): # type: ignore
__tablename__ = "questionnaire"

name: Mapped[str] = mapped_column(nullable=False, primary_key=True)
max_score: Mapped[int | None] = mapped_column()
order: Mapped[int | None] = mapped_column(server_default="0")
inserted: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
FetchedValue(),
nullable=False,
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=datetime.now(),
server_default=func.now(),
)
station_name: Mapped[str] = mapped_column(
station_name: Mapped[str | None] = mapped_column(
Unicode,
ForeignKey(
"station.name",
Expand Down Expand Up @@ -462,7 +486,7 @@ def __init__(
LOG.debug("Ignoring 'inserted' timestamp (%s)", inserted)


class TeamQuestionnaire(DB.Model, TimestampMixin): # type: ignore
class TeamQuestionnaire(DB.Model): # type: ignore
__tablename__ = "questionnaire_score"

team_name: Mapped[str] = mapped_column(
Expand All @@ -480,6 +504,12 @@ class TeamQuestionnaire(DB.Model, TimestampMixin): # type: ignore
score: Mapped[int | None] = mapped_column(
Integer, nullable=True, default=None
)
inserted: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
FetchedValue(),
nullable=False,
server_default=func.now(),
)
updated: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
Expand All @@ -506,8 +536,8 @@ class Upload(DB.Model): # type: ignore
ForeignKey("user.name", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True,
)
uuid: Mapped[UUID] = mapped_column(
UUID,
uuid: Mapped[PyUUID] = mapped_column(
UUID(as_uuid=True),
unique=True,
nullable=False,
name="id",
Expand All @@ -522,7 +552,7 @@ def __init__(self, relname: str, username: str) -> None:

@staticmethod
def get_or_create(
session: scoped_session, relname: str, username: str
session: "scoped_session[Session]", relname: str, username: str
) -> "Upload":
"""
Returns an upload entity. Create it if it is missing
Expand Down Expand Up @@ -633,6 +663,6 @@ def __init__(


class Job:
def __init__(self):
def __init__(self) -> None:
self.action = "example_action"
self.args = {}
self.args: dict[str, Any] = {}
Loading