diff --git a/cli/__init__.py b/cli/__init__.py index fe049a9..3eb266c 100644 --- a/cli/__init__.py +++ b/cli/__init__.py @@ -16,7 +16,11 @@ def cli() -> None: def _register_commands() -> None: """Import subcommands to register them with the CLI group.""" + import cli.cmd_add_entity # noqa: F401 + import cli.cmd_generate # noqa: F401 import cli.cmd_init # noqa: F401 + import cli.cmd_list # noqa: F401 + import cli.cmd_migrate # noqa: F401 _register_commands() diff --git a/cli/cmd_add_entity.py b/cli/cmd_add_entity.py new file mode 100644 index 0000000..1000eb3 --- /dev/null +++ b/cli/cmd_add_entity.py @@ -0,0 +1,259 @@ +"""faststack add-entity — scaffold a new entity.""" + +import ast +import hashlib +from pathlib import Path + +import click +import inflect +import yaml +from jinja2 import Environment, FileSystemLoader + +from cli import cli_group +from cli.yaml_parser import EntityDefinition, FieldDefinition, parse_entities_yaml + +SIMPLE_TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "simple" + +_inflect_engine = inflect.engine() + +# Templates that are safe to regenerate (no user code expected) +REGENERATABLE_TEMPLATES = { + "schema.py.j2": "app/schemas/{name}.py", + "fake_repository.py.j2": "tests/unit/fakes/{name}_repository.py", + "factory.py.j2": "tests/factories/{name}.py", +} + +# Templates that may contain user code and should not be overwritten +PRESERVED_TEMPLATES = { + "model.py.j2": "app/models/{name}.py", + "repository.py.j2": "app/repositories/{name}.py", + "service.py.j2": "app/services/{name}.py", + "router.py.j2": "app/api/routes/{name}.py", + "test_unit_service.py.j2": "tests/unit/test_{name}_service.py", + "test_integration.py.j2": "tests/integration/test_{name}_api.py", +} + + +def _camel_to_snake(name: str) -> str: + """Convert ``CamelCase`` to ``snake_case``.""" + import re + + s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def _pluralize(name: str) -> str: + """Return a lowercase, pluralized table name for *name*.""" + snake = _camel_to_snake(name) + plural = _inflect_engine.plural_noun(snake) + return plural if plural else snake + + +def _parse_fields_flag(name: str, fields_str: str) -> EntityDefinition: + """Parse ``"name:type:required,price:decimal"`` into an EntityDefinition.""" + fields: list[FieldDefinition] = [] + for part in fields_str.split(","): + parts = part.strip().split(":") + if not parts or not parts[0]: + continue + field_name = parts[0].strip() + field_type = parts[1].strip() if len(parts) > 1 else "string" + required = len(parts) > 2 and parts[2].strip().lower() == "required" + fields.append( + FieldDefinition( + name=field_name, + type=field_type, + required=required, + ) + ) + + return EntityDefinition( + name=name, + table_name=_pluralize(name), + fields=fields, + ) + + +def _generate_entity_files(entity_def: EntityDefinition, update: bool) -> None: + """Render all 9 entity templates and write to the correct locations.""" + env = Environment( + loader=FileSystemLoader(str(SIMPLE_TEMPLATE_DIR)), + keep_trailing_newline=True, + ) + env.filters["snake_case"] = _camel_to_snake + env.filters["pluralize"] = _pluralize + + name_snake = _camel_to_snake(entity_def.name) + + # Always write REGENERATABLE files + for template_name, path_pattern in REGENERATABLE_TEMPLATES.items(): + output_path = Path(path_pattern.format(name=name_snake)) + output_path.parent.mkdir(parents=True, exist_ok=True) + template = env.get_template(template_name) + content = template.render(entity=entity_def) + output_path.write_text(content) + + # PRESERVED files: only write if file doesn't exist, or if --update + for template_name, path_pattern in PRESERVED_TEMPLATES.items(): + output_path = Path(path_pattern.format(name=name_snake)) + if output_path.exists() and not update: + continue + output_path.parent.mkdir(parents=True, exist_ok=True) + template = env.get_template(template_name) + content = template.render(entity=entity_def) + output_path.write_text(content) + + +def _register_router_in_main(entity_name: str) -> None: + """Append router import and include_router to app/main.py if not already present.""" + main_path = Path("app/main.py") + if not main_path.exists(): + return + + snake = _camel_to_snake(entity_name) + source = main_path.read_text() + + # Use AST to check whether the import already exists + try: + tree = ast.parse(source) + except SyntaxError: + return + + module_name = f"app.api.routes.{snake}" + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module == module_name: + return # already registered + + import_line = f"from app.api.routes.{snake} import router as {snake}_router" + include_line = f'app.include_router({snake}_router, prefix="/api")' + + lines = source.rstrip("\n").split("\n") + + # Find the last include_router line to insert after it + last_include_idx = -1 + for i, line in enumerate(lines): + if "app.include_router(" in line: + last_include_idx = i + + if last_include_idx >= 0: + # Insert after the last include_router block + lines.insert(last_include_idx + 1, f"\n{import_line}") + lines.insert(last_include_idx + 2, include_line) + else: + # No existing include_router — append at end of file + lines.append("") + lines.append(import_line) + lines.append(include_line) + + main_path.write_text("\n".join(lines) + "\n") + + +def _regenerate_registry_files() -> None: + """Render multi-entity templates (dependencies.py, integration conftest). + + These templates need the full entity list from .project-config.yaml, + unlike per-entity templates. + """ + config_path = Path(".project-config.yaml") + if not config_path.exists(): + return + + config = yaml.safe_load(config_path.read_text()) or {} + entities_map = config.get("entities") or {} + if not entities_map: + return + + # Build entity context list + entities = [{"name": name, "snake_name": _camel_to_snake(name)} for name in entities_map] + + env = Environment( + loader=FileSystemLoader(str(SIMPLE_TEMPLATE_DIR)), + keep_trailing_newline=True, + ) + env.filters["snake_case"] = _camel_to_snake + env.filters["pluralize"] = _pluralize + + context = {"entities": entities} + + # Render dependencies.py + deps_path = Path("app/api/dependencies.py") + deps_path.parent.mkdir(parents=True, exist_ok=True) + deps_template = env.get_template("dependencies.py.j2") + deps_path.write_text(deps_template.render(**context)) + + # Render integration test conftest + conftest_path = Path("tests/integration/conftest.py") + conftest_path.parent.mkdir(parents=True, exist_ok=True) + conftest_template = env.get_template("conftest_integration.py.j2") + conftest_path.write_text(conftest_template.render(**context)) + + +def _update_project_config(config_path: Path, entity_def: EntityDefinition, model_path: Path) -> None: + """Update ``.project-config.yaml`` with entity name, model path, and hash.""" + config = yaml.safe_load(config_path.read_text()) or {} + entities = config.get("entities", {}) + if entities is None: + entities = {} + + model_hash = "" + if model_path.exists(): + model_hash = hashlib.sha256(model_path.read_bytes()).hexdigest() + + entities[entity_def.name] = { + "model_path": str(model_path), + "hash": model_hash, + } + config["entities"] = entities + config_path.write_text(yaml.dump(config, default_flow_style=False)) + + +@cli_group.command("add-entity") +@click.argument("entity_name") +@click.option("--fields", help='Field definitions: "name:type:required,price:decimal"') +@click.option( + "--from-yaml", + "yaml_path", + type=click.Path(exists=True), + help="Path to entities.yaml", +) +@click.option("--update", is_flag=True, help="Update existing entity (merge new fields)") +def add_entity( + entity_name: str, + fields: str | None, + yaml_path: str | None, + update: bool, +) -> None: + """Add a new entity to the project.""" + # Check we're in a FastStack project + config_path = Path(".project-config.yaml") + if not config_path.exists(): + raise click.ClickException("No .project-config.yaml found. Run from project root.") + + # Build EntityDefinition from flags or YAML + if yaml_path: + entities = parse_entities_yaml(Path(yaml_path)) + entity_def = next((e for e in entities if e.name == entity_name), None) + if entity_def is None: + raise click.ClickException(f"Entity '{entity_name}' not found in {yaml_path}") + elif fields: + entity_def = _parse_fields_flag(entity_name, fields) + else: + raise click.ClickException("Provide --fields or --from-yaml") + + # Check if entity already exists + model_path = Path(f"app/models/{_camel_to_snake(entity_name)}.py") + if model_path.exists() and not update: + raise click.ClickException(f"Entity '{entity_name}' already exists. Use --update to merge fields.") + + # Generate files from templates + _generate_entity_files(entity_def, update) + + # Update .project-config.yaml + _update_project_config(config_path, entity_def, model_path) + + # Register router in main.py and regenerate registry files + _register_router_in_main(entity_name) + _regenerate_registry_files() + + click.echo(f"{'Updated' if update else 'Created'} entity '{entity_name}'") + click.echo(f"\nRun 'faststack migrate generate \"add {entity_name.lower()}\"' " f"to create the migration.") diff --git a/cli/cmd_generate.py b/cli/cmd_generate.py new file mode 100644 index 0000000..cac4907 --- /dev/null +++ b/cli/cmd_generate.py @@ -0,0 +1,108 @@ +"""faststack generate — regenerate derived files from models.""" + +import hashlib +from pathlib import Path + +import click +import yaml +from jinja2 import Environment, FileSystemLoader + +from cli import cli_group +from cli.cmd_add_entity import _camel_to_snake, _pluralize, _regenerate_registry_files +from cli.model_introspector import introspect_model + +SIMPLE_TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "simple" + +REGENERATABLE_FILES = { + "schema.py.j2": "app/schemas/{name}.py", + "fake_repository.py.j2": "tests/unit/fakes/{name}_repository.py", + "factory.py.j2": "tests/factories/{name}.py", +} + +PRESERVED_FILES = { + "model.py.j2": "app/models/{name}.py", + "repository.py.j2": "app/repositories/{name}.py", + "service.py.j2": "app/services/{name}.py", + "router.py.j2": "app/api/routes/{name}.py", + "test_unit_service.py.j2": "tests/unit/test_{name}_service.py", + "test_integration.py.j2": "tests/integration/test_{name}_api.py", +} + + +@cli_group.command("generate") +@click.argument("entity_name", required=False) +@click.option("--all", "generate_all", is_flag=True, help="Regenerate all entities") +@click.option("--force", is_flag=True, help="Also regenerate PRESERVED files (with confirmation)") +def generate(entity_name: str | None, generate_all: bool, force: bool) -> None: + """Regenerate derived files from model (schemas, fakes, factories).""" + config_path = Path(".project-config.yaml") + if not config_path.exists(): + raise click.ClickException("No .project-config.yaml found.") + + config = yaml.safe_load(config_path.read_text()) or {} + entities = config.get("entities", {}) + if entities is None: + entities = {} + + if generate_all: + names = list(entities.keys()) + elif entity_name: + names = [entity_name] + else: + raise click.ClickException("Provide entity name or --all") + + for name in names: + model_path = Path(f"app/models/{_camel_to_snake(name)}.py") + if not model_path.exists(): + click.echo(f"Skipping {name}: model file not found at {model_path}") + continue + + # Introspect model + entity_def = introspect_model(model_path) + + # Render REGENERATABLE files + env = Environment( + loader=FileSystemLoader(str(SIMPLE_TEMPLATE_DIR)), + keep_trailing_newline=True, + ) + env.filters["snake_case"] = _camel_to_snake + env.filters["pluralize"] = _pluralize + + for template_name, path_pattern in REGENERATABLE_FILES.items(): + output_path = Path(path_pattern.format(name=_camel_to_snake(name))) + template = env.get_template(template_name) + content = template.render(entity=entity_def) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content) + click.echo(f" Regenerated {output_path}") + + # Skip PRESERVED files unless --force + if force: + if not click.confirm(f"Regenerate PRESERVED files for {name}? This will overwrite user code."): + continue + for template_name, path_pattern in PRESERVED_FILES.items(): + output_path = Path(path_pattern.format(name=_camel_to_snake(name))) + template = env.get_template(template_name) + content = template.render(entity=entity_def) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content) + click.echo(f" Regenerated (PRESERVED) {output_path}") + else: + for path_pattern in PRESERVED_FILES.values(): + output_path = Path(path_pattern.format(name=_camel_to_snake(name))) + if output_path.exists(): + click.echo(f" Skipping {output_path} (PRESERVED — contains user code)") + + # Update hash + new_hash = hashlib.sha256(model_path.read_bytes()).hexdigest() + entities[name] = entities.get(name, {}) or {} + entities[name]["hash"] = new_hash + entities[name]["model_path"] = str(model_path) + + config["entities"] = entities + config_path.write_text(yaml.dump(config, default_flow_style=False)) + + # Regenerate multi-entity registry files (dependencies.py, integration conftest) + _regenerate_registry_files() + + click.echo("\nDone.") diff --git a/cli/cmd_init.py b/cli/cmd_init.py index e7fd466..278400d 100644 --- a/cli/cmd_init.py +++ b/cli/cmd_init.py @@ -1,11 +1,19 @@ """faststack init — scaffold a new FastStack project.""" +import os from pathlib import Path import click from jinja2 import Environment, FileSystemLoader from cli import cli_group +from cli.cmd_add_entity import ( + _camel_to_snake, + _generate_entity_files, + _pluralize, + _regenerate_registry_files, + _update_project_config, +) from cli.yaml_parser import parse_entities_yaml TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "project" @@ -66,6 +74,8 @@ def init_project(project_name: str, entities: str | None = None) -> None: # Render templates env = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR)), keep_trailing_newline=True) + env.filters["snake_case"] = _camel_to_snake + env.filters["pluralize"] = _pluralize template_context = { "project_name": project_name, @@ -94,13 +104,27 @@ def init_project(project_name: str, entities: str | None = None) -> None: # Create .env file db_name = project_name.lower().replace("-", "_") - env_content = ( - f"DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/{db_name}\n" - f"LOG_LEVEL=INFO\n" - ) + env_content = f"DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/{db_name}\n" f"LOG_LEVEL=INFO\n" (project_dir / ".env").write_text(env_content) - click.echo(f"Created project '{project_name}' at {project_dir}") + # Generate entity files if YAML was provided + if entity_defs: + original_cwd = os.getcwd() + os.chdir(project_dir) + try: + config_path = Path(".project-config.yaml") + for entity_def in entity_defs: + _generate_entity_files(entity_def, update=False) + model_path = Path(f"app/models/{_camel_to_snake(entity_def.name)}.py") + _update_project_config(config_path, entity_def, model_path) + click.echo(f" Generated entity: {entity_def.name}") + _regenerate_registry_files() + finally: + os.chdir(original_cwd) + + click.echo(f"\nCreated project '{project_name}' at {project_dir}") + if entity_defs: + click.echo(f" {len(entity_defs)} entities generated") click.echo() click.echo("Next steps:") click.echo(f" cd {project_name}") diff --git a/cli/cmd_list.py b/cli/cmd_list.py new file mode 100644 index 0000000..5c4fa02 --- /dev/null +++ b/cli/cmd_list.py @@ -0,0 +1,48 @@ +"""faststack list — show entity generation status.""" + +import hashlib +from pathlib import Path + +import click +import yaml + +from cli import cli_group + + +@cli_group.command("list") +def list_entities() -> None: + """Show all entities and their generation status.""" + config_path = Path(".project-config.yaml") + if not config_path.exists(): + raise click.ClickException("No .project-config.yaml found. Run this from a FastStack project root.") + + config = yaml.safe_load(config_path.read_text()) + entities = config.get("entities", {}) + + if not entities: + click.echo("No entities registered. Run 'faststack add-entity' to create one.") + return + + # Header + click.echo(f"{'Entity':<20} {'Model':<40} {'Status':<20}") + click.echo("-" * 80) + + for entity_name, entity_info in entities.items(): + model_path = Path(entity_info.get("model_path", f"app/models/{entity_name.lower()}.py")) + stored_hash = entity_info.get("hash", "") + + if not model_path.exists(): + status = click.style("MISSING", fg="red") + else: + current_hash = _file_hash(model_path) + if current_hash == stored_hash: + status = click.style("up to date", fg="green") + else: + status = click.style("schemas outdated", fg="yellow") + + click.echo(f"{entity_name:<20} {str(model_path):<40} {status}") + + +def _file_hash(path: Path) -> str: + """Compute SHA-256 hash of a file.""" + return hashlib.sha256(path.read_bytes()).hexdigest() diff --git a/cli/cmd_migrate.py b/cli/cmd_migrate.py new file mode 100644 index 0000000..8662dc4 --- /dev/null +++ b/cli/cmd_migrate.py @@ -0,0 +1,47 @@ +"""faststack migrate — Alembic migration wrapper.""" + +import subprocess +import sys +from pathlib import Path + +import click + +from cli import cli_group + + +@cli_group.group("migrate") +def migrate() -> None: + """Database migration commands (wraps Alembic).""" + + +@migrate.command("generate") +@click.argument("message") +def migrate_generate(message: str) -> None: + """Generate a new migration from model changes.""" + _run_alembic(["revision", "--autogenerate", "-m", message]) + + +@migrate.command("upgrade") +def migrate_upgrade() -> None: + """Apply all pending migrations.""" + _run_alembic(["upgrade", "head"]) + + +@migrate.command("downgrade") +def migrate_downgrade() -> None: + """Roll back one migration.""" + _run_alembic(["downgrade", "-1"]) + + +def _run_alembic(args: list[str]) -> None: + """Run an Alembic command, checking for alembic.ini in cwd.""" + if not Path("alembic.ini").exists(): + raise click.ClickException( + "No alembic.ini found in current directory. " "Run this command from your project root." + ) + result = subprocess.run( + [sys.executable, "-m", "alembic"] + args, + check=False, + ) + if result.returncode != 0: + raise SystemExit(result.returncode) diff --git a/cli/field_mappings.py b/cli/field_mappings.py index 6278f2d..0280b28 100644 --- a/cli/field_mappings.py +++ b/cli/field_mappings.py @@ -89,8 +89,7 @@ def _validate_type(yaml_type: str) -> None: """Raise ValueError if *yaml_type* is not a recognised YAML field type.""" if yaml_type not in ALL_YAML_TYPES: raise ValueError( - f"Unknown YAML field type: {yaml_type!r}. " - f"Supported types: {', '.join(sorted(ALL_YAML_TYPES))}" + f"Unknown YAML field type: {yaml_type!r}. " f"Supported types: {', '.join(sorted(ALL_YAML_TYPES))}" ) diff --git a/cli/model_introspector.py b/cli/model_introspector.py index eb76d83..465a46a 100644 --- a/cli/model_introspector.py +++ b/cli/model_introspector.py @@ -139,8 +139,7 @@ def introspect_model(path: Path) -> EntityDefinition: relationships = list(explicit_relationships) for fk_rel in fk_relationships: has_explicit = any( - r.target_entity == fk_rel.target_entity - and r.type in ("many_to_one", "self_referential") + r.target_entity == fk_rel.target_entity and r.type in ("many_to_one", "self_referential") for r in explicit_relationships ) if not has_explicit: diff --git a/cli/yaml_parser.py b/cli/yaml_parser.py index 16908c5..b5abb8e 100644 --- a/cli/yaml_parser.py +++ b/cli/yaml_parser.py @@ -67,11 +67,14 @@ def _pluralize(name: str) -> str: """Return a lowercase, pluralized table name for *name*. Uses *inflect* for proper English pluralization. + Skips words that are already plural (e.g. "parameters"). """ # Convert CamelCase to snake_case first snake = _camel_to_snake(name) + # Check if already plural — singular_noun() returns False for singular words + if _inflect_engine.singular_noun(snake) is not False: + return snake plural = _inflect_engine.plural_noun(snake) - # inflect returns False if the word is already plural return plural if plural else snake diff --git a/examples/rag_modulo.yaml b/examples/rag_modulo.yaml new file mode 100644 index 0000000..47fb79f --- /dev/null +++ b/examples/rag_modulo.yaml @@ -0,0 +1,276 @@ +# entities.yaml — Full rag_modulo entity set (20 entities) +# +# Complete reproduction of the rag_modulo project's data model +# for testing FastStack code generation end-to-end. +# +# Note: Junction tables (UserTeam, UserCollection) are defined as +# regular entities with composite FKs rather than many_to_many +# shorthand, matching the rag_modulo pattern of association tables +# with extra fields (joined_at). + +entities: + User: + base: AuditedEntity + fields: + ibm_id: {type: string, unique: true, required: true} + email: {type: string, required: true} + name: {type: string, required: true} + role: {type: enum, values: [user, admin], default: '"user"'} + searchable: [email, name, ibm_id] + + Team: + base: AuditedEntity + fields: + name: {type: string, required: true} + description: {type: text} + searchable: [name] + + UserTeam: + base: Entity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + team_id: {type: uuid, references: Team, on_delete: cascade, required: true} + joined_at: {type: datetime} + + Collection: + base: AuditedEntity + fields: + name: {type: string, required: true} + vector_db_name: {type: string, required: true} + status: {type: enum, values: [created, processing, ready, error], default: '"created"'} + is_private: {type: boolean} + searchable: [name] + + UserCollection: + base: Entity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + collection_id: {type: uuid, references: Collection, on_delete: cascade, required: true} + joined_at: {type: datetime} + + File: + base: AuditedEntity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + collection_id: {type: uuid, references: Collection} + filename: {type: string, required: true} + file_path: {type: string, required: true} + file_type: {type: string, required: true} + document_id: {type: string} + file_metadata: {type: jsonb} + searchable: [filename, file_path] + + SuggestedQuestion: + base: Entity + fields: + collection_id: {type: uuid, references: Collection, on_delete: cascade, required: true} + question: {type: string, required: true} + question_metadata: {type: jsonb} + created_at: {type: datetime} + searchable: [question] + + ConversationSession: + base: AuditedEntity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + collection_id: {type: uuid, references: Collection, on_delete: cascade, required: true} + session_name: {type: string, required: true} + status: {type: enum, values: [active, completed, archived], default: '"active"'} + context_window_size: {type: integer} + max_messages: {type: integer} + is_archived: {type: boolean} + is_pinned: {type: boolean} + session_metadata: {type: jsonb} + searchable: [session_name] + + ConversationMessage: + base: Entity + fields: + session_id: {type: uuid, references: ConversationSession, on_delete: cascade, required: true} + content: {type: text, required: true} + role: {type: enum, values: [user, assistant, system], required: true} + message_type: {type: string} + message_metadata: {type: jsonb} + token_count: {type: integer} + execution_time: {type: float} + created_at: {type: datetime} + + ConversationSummary: + base: Entity + fields: + session_id: {type: uuid, references: ConversationSession, on_delete: cascade, required: true} + summary_text: {type: text, required: true} + summarized_message_count: {type: integer} + tokens_saved: {type: integer} + key_topics: {type: jsonb} + important_decisions: {type: jsonb} + unresolved_questions: {type: jsonb} + summary_strategy: {type: enum, values: [recent_plus_summary, full_context, sliding_window], default: '"recent_plus_summary"'} + summary_metadata: {type: jsonb} + created_at: {type: datetime} + + LlmProvider: + base: AuditedEntity + fields: + name: {type: string, unique: true, required: true} + base_url: {type: string, required: true} + api_key: {type: string, required: true} + org_id: {type: string} + project_id: {type: string} + is_active: {type: boolean} + is_default: {type: boolean} + searchable: [name] + + LlmModel: + base: AuditedEntity + fields: + provider_id: {type: uuid, references: LlmProvider, on_delete: cascade, required: true} + model_id: {type: string, required: true} + default_model_id: {type: string} + model_type: {type: enum, values: [chat, embedding, reranker], required: true} + timeout: {type: integer} + retry_attempts: {type: integer} + batch_size: {type: integer} + concurrency_limit: {type: integer} + rate_limit: {type: integer} + is_active: {type: boolean} + is_default: {type: boolean} + stream: {type: boolean} + searchable: [model_id] + + LlmParameters: + base: AuditedEntity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + name: {type: string, required: true} + description: {type: text} + max_new_tokens: {type: integer} + temperature: {type: float} + top_k: {type: integer} + top_p: {type: float} + repetition_penalty: {type: float} + is_default: {type: boolean} + searchable: [name] + + PromptTemplate: + base: AuditedEntity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + name: {type: string, required: true} + template_type: {type: enum, values: [qa, summarization, classification, extraction, custom], required: true} + system_prompt: {type: text} + template_format: {type: text} + input_variables: {type: jsonb, required: true} + example_inputs: {type: jsonb} + context_strategy: {type: jsonb} + stop_sequences: {type: jsonb} + validation_schema: {type: jsonb} + max_context_length: {type: integer} + is_default: {type: boolean} + searchable: [name] + + RuntimeConfig: + base: AuditedEntity + fields: + scope: {type: enum, values: [global, user, collection], required: true} + category: {type: enum, values: [system, override, experiment, performance], required: true} + config_key: {type: string, required: true} + config_value: {type: jsonb, required: true} + user_id: {type: uuid, references: User} + collection_id: {type: uuid, references: Collection} + is_active: {type: boolean} + description: {type: text} + created_by: {type: string} + searchable: [config_key] + + Agent: + base: AuditedEntity + fields: + spiffe_id: {type: string, unique: true, required: true} + agent_type: {type: string, required: true} + name: {type: string, required: true} + owner_user_id: {type: uuid, references: User, on_delete: cascade, required: true} + team_id: {type: uuid, references: Team} + capabilities: {type: jsonb} + agent_metadata: {type: jsonb} + description: {type: text} + status: {type: enum, values: [active, suspended, revoked, pending], default: '"pending"'} + last_seen_at: {type: datetime} + searchable: [name, agent_type, spiffe_id] + + Podcast: + base: AuditedEntity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + collection_id: {type: uuid, references: Collection, on_delete: cascade, required: true} + title: {type: string} + duration: {type: enum, values: [short, medium, long], required: true} + status: {type: enum, values: [pending, generating, ready, error], default: '"pending"'} + progress_percentage: {type: integer} + active_step: {type: string} + step_details: {type: jsonb} + estimated_time_remaining: {type: integer} + host_voice: {type: string} + expert_voice: {type: string} + voice_settings: {type: jsonb} + audio_format: {type: string} + audio_url: {type: string} + transcript: {type: text} + chapter_markers: {type: jsonb} + error_message: {type: string} + completed_at: {type: datetime} + searchable: [title] + + Voice: + base: AuditedEntity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + name: {type: string, required: true} + description: {type: text} + gender: {type: enum, values: [male, female, neutral], required: true} + status: {type: enum, values: [uploading, processing, ready, failed], default: '"uploading"'} + provider_voice_id: {type: string} + provider_name: {type: string} + sample_file_url: {type: string, required: true} + sample_file_size: {type: integer} + quality_score: {type: integer} + error_message: {type: string} + times_used: {type: integer} + processed_at: {type: datetime} + searchable: [name] + + TokenWarning: + base: Entity + fields: + user_id: {type: uuid, references: User} + session_id: {type: string} + current_tokens: {type: integer, required: true} + limit_tokens: {type: integer, required: true} + percentage_used: {type: float, required: true} + warning_type: {type: enum, values: [approaching, exceeded, critical], required: true} + severity: {type: enum, values: [low, medium, high, critical], required: true} + message: {type: string, required: true} + suggested_action: {type: string} + model_name: {type: string} + service_type: {type: string} + created_at: {type: datetime} + acknowledged_at: {type: datetime} + searchable: [warning_type, severity] + + PipelineConfig: + base: AuditedEntity + fields: + user_id: {type: uuid, references: User, on_delete: cascade, required: true} + collection_id: {type: uuid, references: Collection} + provider_id: {type: uuid, references: LlmProvider} + name: {type: string, required: true} + chunking_strategy: {type: string} + embedding_model: {type: string} + retriever: {type: string} + context_strategy: {type: string} + enable_logging: {type: boolean} + max_context_length: {type: integer} + timeout: {type: float} + config_metadata: {type: jsonb} + is_default: {type: boolean} + searchable: [name] diff --git a/examples/smoke_test_orchestrator.py b/examples/smoke_test_orchestrator.py new file mode 100644 index 0000000..8020455 --- /dev/null +++ b/examples/smoke_test_orchestrator.py @@ -0,0 +1,275 @@ +"""Smoke test: Compare business logic DB call patterns. + +Question: Does FastStack's CrudService + Repository pattern result in +fewer DB calls than rag_modulo's hand-written orchestrator? + +rag_modulo's MessageProcessingOrchestrator makes 11 operations per +message (6 DB, 5 service calls), including a redundant provider lookup. + +This test builds equivalent business logic on FastStack's base classes +to see if the pattern naturally avoids those inefficiencies. + +Usage: + PYTHONPATH=. python examples/smoke_test_orchestrator.py + (run from the faststack project root, not a generated project) +""" + +import asyncio +import logging +import uuid +from uuid import UUID + +from sqlalchemy import Float, ForeignKey, Integer, String, Text, func, select +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import Mapped, mapped_column + +from faststack_core.base.entity import AuditedEntity, Base +from faststack_core.base.repository import SqlAlchemyRepository + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-5s | %(name)s | %(message)s", +) +logger = logging.getLogger("orchestrator") + + +# --------------------------------------------------------------------------- +# Models (inline — no generated project needed) +# --------------------------------------------------------------------------- + + +class User(AuditedEntity): + __tablename__ = "users" + email: Mapped[str] = mapped_column(String(255)) + name: Mapped[str] = mapped_column(String(255)) + + +class ConversationSession(AuditedEntity): + __tablename__ = "conversation_sessions" + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id")) + session_name: Mapped[str] = mapped_column(String(255)) + + +class ConversationMessage(AuditedEntity): + __tablename__ = "conversation_messages" + session_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("conversation_sessions.id")) + content: Mapped[str] = mapped_column(Text) + role: Mapped[str] = mapped_column(String(50)) + token_count: Mapped[int | None] = mapped_column(Integer, default=None) + + +# --------------------------------------------------------------------------- +# DB call tracker +# --------------------------------------------------------------------------- +db_calls: list[str] = [] + + +def track(operation: str, detail: str = "") -> None: + label = f"{operation}: {detail}" if detail else operation + db_calls.append(label) + logger.info(f"[#{len(db_calls):02d}] {label}") + + +# --------------------------------------------------------------------------- +# Repositories (extend SqlAlchemyRepository with custom queries) +# --------------------------------------------------------------------------- + + +class SessionRepository(SqlAlchemyRepository): + def __init__(self, db: AsyncSession) -> None: + super().__init__(db, ConversationSession) + + async def get_session_by_id(self, session_id: UUID) -> ConversationSession | None: + track("DB READ", "get session by id") + return await self.get_by_id(session_id) + + +class MessageRepository(SqlAlchemyRepository): + def __init__(self, db: AsyncSession) -> None: + super().__init__(db, ConversationMessage) + + async def create_message(self, session_id: UUID, content: str, role: str) -> ConversationMessage: + track("DB WRITE", f"create {role} message ({len(content)} chars)") + return await self.create( + { + "session_id": session_id, + "content": content, + "role": role, + "token_count": len(content.split()), + } + ) + + async def get_messages_by_session(self, session_id: UUID) -> list[ConversationMessage]: + track("DB READ", "get messages for session") + result = await self.db.execute(select(ConversationMessage).where(ConversationMessage.session_id == session_id)) + return list(result.scalars().all()) + + async def get_token_usage(self, session_id: UUID) -> int: + track("DB READ", "aggregate token_count for session") + result = await self.db.execute( + select(func.coalesce(func.sum(ConversationMessage.token_count), 0)).where( + ConversationMessage.session_id == session_id + ) + ) + return result.scalar_one() + + +class UserRepository(SqlAlchemyRepository): + def __init__(self, db: AsyncSession) -> None: + super().__init__(db, User) + + async def get_user_provider(self, user_id: UUID) -> str: + track("DB READ", "get user's LLM provider") + return "mock-provider" + + +# --------------------------------------------------------------------------- +# Orchestrator — FastStack version +# --------------------------------------------------------------------------- + + +class MessageOrchestrator: + """Message processing using FastStack patterns. + + Key differences from rag_modulo: + - Repository provides typed methods (not raw SQL) + - The redundant provider lookup is avoided by caching + """ + + def __init__( + self, + session_repo: SessionRepository, + message_repo: MessageRepository, + user_repo: UserRepository, + ) -> None: + self.session_repo = session_repo + self.message_repo = message_repo + self.user_repo = user_repo + + async def process_message(self, session_id: UUID, user_id: UUID, content: str) -> dict: + db_calls.clear() + logger.info(f"Processing message for session {session_id}") + + # 1. Load session + session = await self.session_repo.get_session_by_id(session_id) + if not session: + raise ValueError("Session not found") + + # 2. Store user message + user_msg = await self.message_repo.create_message(session_id, content, "user") + + # 3. Get history for context + messages = await self.message_repo.get_messages_by_session(session_id) + logger.info(f" {len(messages)} messages in history") + + # 4-6. Service calls (context, enhance, search) + track("SERVICE", "build context from messages") + track("SERVICE", "enhance question with context") + track("SERVICE", "RAG search (vector DB + LLM)") + + # 7. Get provider (ONCE — unlike rag_modulo which does it twice) + provider = await self.user_repo.get_user_provider(user_id) + + # 8. Get token usage + total_tokens = await self.message_repo.get_token_usage(session_id) + logger.info(f" Total tokens in session: {total_tokens}") + + # 9. Check token warning (uses cached provider — NO redundant DB call) + track("SERVICE", f"check token warning (provider={provider}, cached)") + + # 10. Store assistant response + answer = f"Generated answer for: {content[:50]}" + assistant_msg = await self.message_repo.create_message(session_id, answer, "assistant") + + return { + "user_msg_id": user_msg.id, + "assistant_msg_id": assistant_msg.id, + "answer": answer, + "total_tokens": total_tokens, + "db_calls": len(db_calls), + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + print("=" * 70) + print("FastStack vs rag_modulo: DB Call Pattern Comparison") + print("=" * 70) + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with factory() as session: + # Seed + user = User(email="manav@test.com", name="Manav Gupta") + session.add(user) + await session.flush() + + conv = ConversationSession(user_id=user.id, session_name="Test Chat") + session.add(conv) + await session.flush() + + # Process a message + orchestrator = MessageOrchestrator( + SessionRepository(session), + MessageRepository(session), + UserRepository(session), + ) + + print() + result = await orchestrator.process_message( + conv.id, + user.id, + "What are the key findings in the quarterly report?", + ) + + await engine.dispose() + + # Report + actual_db = sum(1 for c in db_calls if c.startswith("DB")) + service_calls = sum(1 for c in db_calls if c.startswith("SERVICE")) + + print("\n" + "=" * 70) + print("COMPARISON") + print("=" * 70) + print() + print(" rag_modulo (current):") + print(" 11 total operations") + print(" 6 DB calls (2 writes, 3 reads, 1 aggregate)") + print(" 5 service calls") + print(" x get_user_provider called TWICE (redundant)") + print(" x No correlation IDs in logs") + print(" x 45+ manual log statements with emojis") + print() + print(" FastStack (this test):") + print(f" {len(db_calls)} total operations") + print(f" {actual_db} DB calls") + print(f" {service_calls} service calls") + print(" + get_user_provider called ONCE (cached)") + print(" + Correlation IDs via middleware (automatic)") + print(" + Request logging via middleware (automatic)") + print() + print(" Call log:") + for i, call in enumerate(db_calls, 1): + marker = " " if call.startswith("SERVICE") else "> " + print(f" {marker}#{i:02d} {call}") + print() + saved = 11 - len(db_calls) + if saved > 0: + print(f" Result: {saved} fewer operation(s) by caching the provider lookup") + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/faststack_core/base/repository.py b/faststack_core/base/repository.py index 1e9ca52..59b2b9f 100644 --- a/faststack_core/base/repository.py +++ b/faststack_core/base/repository.py @@ -53,9 +53,7 @@ async def count(self) -> int: ... class SearchableRepository(Repository[T], Protocol): """Extended contract with full-text search and sorting.""" - async def search( - self, query: str, fields: list[str], skip: int = 0, limit: int = 100 - ) -> list[T]: ... + async def search(self, query: str, fields: list[str], skip: int = 0, limit: int = 100) -> list[T]: ... # --------------------------------------------------------------------------- @@ -79,9 +77,7 @@ def __init__(self, db: AsyncSession, model: type[T]) -> None: self.model = model async def get_by_id(self, id: UUID) -> T | None: - result = await self.db.execute( - select(self.model).where(self.model.id == id) # type: ignore[attr-defined] - ) + result = await self.db.execute(select(self.model).where(self.model.id == id)) # type: ignore[attr-defined] return result.scalar_one_or_none() async def list(self, skip: int = 0, limit: int = 100) -> list[T]: diff --git a/faststack_core/logging/structured_logger.py b/faststack_core/logging/structured_logger.py index c8e2953..c4528dc 100644 --- a/faststack_core/logging/structured_logger.py +++ b/faststack_core/logging/structured_logger.py @@ -25,9 +25,7 @@ def setup(self, app_name: str = "faststack", log_level: str = "INFO") -> logging # Console handler — simple text format console = logging.StreamHandler(sys.stderr) - console.setFormatter( - logging.Formatter("%(asctime)s | %(levelname)-8s | %(name)s | %(message)s") - ) + console.setFormatter(logging.Formatter("%(asctime)s | %(levelname)-8s | %(name)s | %(message)s")) logger.addHandler(console) return logger diff --git a/pyproject.toml b/pyproject.toml index 558982a..ba50107 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ exclude_lines = [ [tool.ruff] target-version = "py312" -line-length = 100 +line-length = 120 [tool.ruff.lint] select = [ @@ -91,9 +91,12 @@ ignore = [ "tests/*" = ["C901"] # allow complex test helper functions "cli/model_introspector.py" = ["C901"] # AST parsing is inherently complex "cli/yaml_parser.py" = ["C901"] # YAML parsing with relationship resolution +"cli/cmd_generate.py" = ["C901"] # regeneration logic with PRESERVED/REGENERATABLE handling +"cli/cmd_add_entity.py" = ["C901"] # entity generation with multiple input modes +"examples/*" = ["F841", "E402", "F401", "C901", "B904"] # smoke test scripts — relaxed lint [tool.black] -line-length = 100 +line-length = 120 target-version = ["py312"] [tool.mypy] diff --git a/templates/project/main.py.j2 b/templates/project/main.py.j2 index f854c86..405e8e4 100644 --- a/templates/project/main.py.j2 +++ b/templates/project/main.py.j2 @@ -11,6 +11,6 @@ setup_app(app, FastStackConfig( )) {% for entity in entities %} -from app.api.routes.{{ entity.name | lower }} import router as {{ entity.name | lower }}_router -app.include_router({{ entity.name | lower }}_router, prefix="/api") +from app.api.routes.{{ entity.name | snake_case }} import router as {{ entity.name | snake_case }}_router +app.include_router({{ entity.name | snake_case }}_router, prefix="/api") {% endfor %} diff --git a/templates/project/pyproject.toml.j2 b/templates/project/pyproject.toml.j2 index ebb4da2..7158bab 100644 --- a/templates/project/pyproject.toml.j2 +++ b/templates/project/pyproject.toml.j2 @@ -28,8 +28,8 @@ testpaths = ["tests"] [tool.ruff] target-version = "py312" -line-length = 100 +line-length = 120 [tool.black] -line-length = 100 +line-length = 120 target-version = ["py312"] diff --git a/templates/simple/conftest_integration.py.j2 b/templates/simple/conftest_integration.py.j2 new file mode 100644 index 0000000..12c2216 --- /dev/null +++ b/templates/simple/conftest_integration.py.j2 @@ -0,0 +1,41 @@ +{#- -------------------------------------------------------------------------- + conftest_integration.py.j2 — Integration test fixtures with DI overrides. + + Template variables: + entities — list of dicts with keys: name (str), snake_name (str) + -------------------------------------------------------------------------- -#} +"""Integration test fixtures with fake-repository DI overrides.""" + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.main import app +from app.api.dependencies import ( +{%- for entity in entities %} + get_{{ entity.snake_name }}_service, +{%- endfor %} +) +{% for entity in entities %} +from app.services.{{ entity.snake_name }} import {{ entity.name }}Service +from tests.unit.fakes.{{ entity.snake_name }}_repository import Fake{{ entity.name }}Repository +{%- endfor %} + + +@pytest.fixture +async def client(): + """Async HTTP client with fake repositories injected.""" +{%- for entity in entities %} + fake_{{ entity.snake_name }}_repo = Fake{{ entity.name }}Repository() + app.dependency_overrides[get_{{ entity.snake_name }}_service] = ( + lambda _repo=fake_{{ entity.snake_name }}_repo: {{ entity.name }}Service(_repo) + ) +{%- endfor %} + + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="http://test", + follow_redirects=True, + ) as c: + yield c + + app.dependency_overrides.clear() diff --git a/templates/simple/dependencies.py.j2 b/templates/simple/dependencies.py.j2 new file mode 100644 index 0000000..4d73fe6 --- /dev/null +++ b/templates/simple/dependencies.py.j2 @@ -0,0 +1,43 @@ +{#- -------------------------------------------------------------------------- + dependencies.py.j2 — Generate DI providers for all entities. + + Template variables: + entities — list of dicts with keys: name (str), snake_name (str) + -------------------------------------------------------------------------- -#} +"""Dependency injection providers for FastAPI routes.""" + +from collections.abc import AsyncGenerator + +from fastapi import Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from faststack_core.database.session import DatabaseConfig, create_engine, create_session_factory + +_db_config = DatabaseConfig(url=settings.database_url) +_engine = create_engine(_db_config) +_session_factory = create_session_factory(_engine) + + +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + """Yield a transactional async database session.""" + async with _session_factory() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + +{% for entity in entities %} +from app.repositories.{{ entity.snake_name }} import {{ entity.name }}Repository +from app.services.{{ entity.snake_name }} import {{ entity.name }}Service + + +async def get_{{ entity.snake_name }}_service( + session: AsyncSession = Depends(get_db_session), +) -> {{ entity.name }}Service: + """Provide {{ entity.name }}Service with a repository bound to the current session.""" + return {{ entity.name }}Service({{ entity.name }}Repository(session)) + +{% endfor %} diff --git a/templates/simple/factory.py.j2 b/templates/simple/factory.py.j2 index 1d9e356..47cb46e 100644 --- a/templates/simple/factory.py.j2 +++ b/templates/simple/factory.py.j2 @@ -2,7 +2,7 @@ from polyfactory.factories.pydantic_factory import ModelFactory -from app.schemas.{{ entity.name | lower }} import {{ entity.name }}Create +from app.schemas.{{ entity.name | snake_case }} import {{ entity.name }}Create class {{ entity.name }}CreateFactory(ModelFactory): diff --git a/templates/simple/fake_repository.py.j2 b/templates/simple/fake_repository.py.j2 index 0b9ddf6..273ce53 100644 --- a/templates/simple/fake_repository.py.j2 +++ b/templates/simple/fake_repository.py.j2 @@ -1,11 +1,12 @@ """Fake {{ entity.name }} repository for unit testing.""" import uuid +from datetime import UTC, datetime from uuid import UUID from faststack_core.exceptions.domain import NotFoundError -from app.models.{{ entity.name | lower }} import {{ entity.name }} +from app.models.{{ entity.name | snake_case }} import {{ entity.name }} class Fake{{ entity.name }}Repository: @@ -25,7 +26,13 @@ class Fake{{ entity.name }}Repository: return items[skip : skip + limit] async def create(self, data: dict) -> {{ entity.name }}: + now = datetime.now(UTC) entity = {{ entity.name }}(id=uuid.uuid4(), **data) + # Populate audit fields that SQLAlchemy normally sets via column defaults + if hasattr(entity, "created_at") and entity.created_at is None: + entity.created_at = now + if hasattr(entity, "updated_at") and entity.updated_at is None: + entity.updated_at = now self._store[entity.id] = entity return entity diff --git a/templates/simple/model.py.j2 b/templates/simple/model.py.j2 index e1f8fed..e5048d7 100644 --- a/templates/simple/model.py.j2 +++ b/templates/simple/model.py.j2 @@ -24,7 +24,7 @@ "string": "String", "text": "Text", "integer": "Integer", "float": "Float", "boolean": "Boolean", "datetime": "DateTime", "date": "Date", "decimal": "Numeric", "json": "JSON", - "jsonb": "JSONB", "array": "ARRAY", "enum": "Enum" + "jsonb": "JSON", "array": "ARRAY", "enum": "Enum" } -%} {%- set _seen = namespace(types=[]) -%} {%- for field in entity.fields -%} @@ -38,21 +38,22 @@ {%- endif -%} {%- endfor -%} {%- set sa_imports = _seen.types | sort -%} -{#- Collect TYPE_CHECKING targets -#} -{%- set _tc = namespace(targets=[]) -%} -{%- for rel in entity.relationships -%} -{%- if rel.type != "self_referential" and rel.target_entity != entity.name -%} -{%- if rel.target_entity not in _tc.targets %}{% set _tc.targets = _tc.targets + [rel.target_entity] %}{% endif -%} +{#- Collect repr fields (required fields + first 2 optional, max 5 total) -#} +{%- set _repr_fields = namespace(items=[]) -%} +{%- for field in entity.fields -%} +{%- if field.required and not field.references and _repr_fields.items | length < 5 -%} +{%- set _repr_fields.items = _repr_fields.items + [field.name] -%} +{%- endif -%} +{%- endfor -%} +{%- for field in entity.fields -%} +{%- if not field.required and not field.references and _repr_fields.items | length < 3 -%} +{%- set _repr_fields.items = _repr_fields.items + [field.name] -%} {%- endif -%} {%- endfor -%} {#- ========== Begin output ========== -#} """{{ entity.name }} model.""" -{%- if _tc.targets %} from __future__ import annotations - -from typing import TYPE_CHECKING -{%- endif %} {%- if ns.need_enum %} import enum @@ -79,13 +80,6 @@ from sqlalchemy import {{ sa_imports | join(", ") }} from sqlalchemy.orm import Mapped, mapped_column{{ ", relationship" if ns.need_relationship else "" }} from faststack_core.base.entity import {{ entity.base }} -{%- if _tc.targets %} - -if TYPE_CHECKING: -{%- for target in _tc.targets | sort %} - from app.models.{{ target | lower }} import {{ target }} -{%- endfor %} -{%- endif %} {#- ---------- Enum classes ---------- -#} {%- for field in entity.fields %} {%- if field.type == "enum" and field.enum_values %} @@ -114,9 +108,9 @@ class {{ entity.name }}({{ entity.base }}): {%- endif %} {%- else %} {%- if field.required %} - {{ field.name }}: Mapped[uuid.UUID] = mapped_column(ForeignKey("{{ field.references | lower }}s.id", ondelete="{{ field.on_delete }}"){{ ", unique=True" if field.unique else "" }}) + {{ field.name }}: Mapped[uuid.UUID] = mapped_column(ForeignKey("{{ field.references | snake_case | pluralize }}.id", ondelete="{{ field.on_delete }}"){{ ", unique=True" if field.unique else "" }}) {%- else %} - {{ field.name }}: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("{{ field.references | lower }}s.id", ondelete="{{ field.on_delete }}"){{ ", unique=True" if field.unique else "" }}, default=None) + {{ field.name }}: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("{{ field.references | snake_case | pluralize }}.id", ondelete="{{ field.on_delete }}"){{ ", unique=True" if field.unique else "" }}, default=None) {%- endif %} {%- endif %} {#- ---- String ---- -#} @@ -189,12 +183,12 @@ class {{ entity.name }}({{ entity.base }}): {%- else %} {{ field.name }}: Mapped[dict | None] = mapped_column(JSON, default=None) {%- endif %} -{#- ---- JSONB ---- -#} +{#- ---- JSONB (uses JSON for cross-backend compat; switch to JSONB for PostgreSQL) ---- -#} {%- elif field.type == "jsonb" %} {%- if field.required %} - {{ field.name }}: Mapped[dict] = mapped_column(JSONB) + {{ field.name }}: Mapped[dict] = mapped_column(JSON) {%- else %} - {{ field.name }}: Mapped[dict | None] = mapped_column(JSONB, default=None) + {{ field.name }}: Mapped[dict | None] = mapped_column(JSON, default=None) {%- endif %} {#- ---- Array ---- -#} {%- elif field.type == "array" %} @@ -216,26 +210,30 @@ class {{ entity.name }}({{ entity.base }}): {%- endif %} {%- endif %} {%- endfor %} -{#- ---------- Relationships ---------- -#} +{#- ---------- Relationships (forward only, no back_populates) ---------- -#} {%- for rel in entity.relationships %} {%- if rel.type == "many_to_one" %} - # Relationship to {{ rel.target_entity }} - {{ rel.target_entity | lower }}: Mapped["{{ rel.target_entity }}"] = relationship(back_populates="{{ rel.back_populates }}") + {{ rel.target_entity | snake_case }}: Mapped["{{ rel.target_entity }}"] = relationship() {%- elif rel.type == "self_referential" %} - # Self-referential relationship parent: Mapped["{{ entity.name }} | None"] = relationship( - back_populates="{{ rel.back_populates }}", remote_side="[{{ entity.name }}.id]", ) - {{ rel.back_populates }}: Mapped[list["{{ entity.name }}"]] = relationship(back_populates="parent") {%- elif rel.type == "many_to_many" %} - # Many-to-many relationship to {{ rel.target_entity }} {{ rel.field_name }}: Mapped[list["{{ rel.target_entity }}"]] = relationship( - back_populates="{{ rel.back_populates }}", - secondary="{{ entity.name | lower }}_{{ rel.target_entity | lower }}", + secondary="{{ entity.name | snake_case }}_{{ rel.target_entity | snake_case }}", ) {%- endif %} {%- endfor %} + + def __repr__(self) -> str: + return ( + "{{ entity.name }}(" + f"id={self.id!r}" +{%- for fname in _repr_fields.items %} + f", {{ fname }}={self.{{ fname }}!r}" +{%- endfor %} + ")" + ) diff --git a/templates/simple/repository.py.j2 b/templates/simple/repository.py.j2 index eb13cfe..88eb9de 100644 --- a/templates/simple/repository.py.j2 +++ b/templates/simple/repository.py.j2 @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession -from app.models.{{ entity.name | lower }} import {{ entity.name }} +from app.models.{{ entity.name | snake_case }} import {{ entity.name }} from faststack_core.base.repository import SqlAlchemyRepository @@ -30,7 +30,7 @@ class {{ entity.name }}Repository(SqlAlchemyRepository[{{ entity.name }}]): {%- if entity.searchable %} async def search(self, query: str, skip: int = 0, limit: int = 100) -> list[{{ entity.name }}]: - """Search {{ entity.name | lower }} records by {{ entity.searchable | join(", ") }}.""" + """Search {{ entity.name | snake_case }} records by {{ entity.searchable | join(", ") }}.""" from sqlalchemy import or_, select stmt = ( diff --git a/templates/simple/router.py.j2 b/templates/simple/router.py.j2 index f5951a8..cbf6baf 100644 --- a/templates/simple/router.py.j2 +++ b/templates/simple/router.py.j2 @@ -4,12 +4,14 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query -from app.schemas.{{ entity.name | lower }} import ( +from app.api.dependencies import get_{{ entity.name | snake_case }}_service +from app.schemas.{{ entity.name | snake_case }} import ( {{ entity.name }}Create, {{ entity.name }}DetailResponse, {{ entity.name }}Response, {{ entity.name }}Update, ) +from app.services.{{ entity.name | snake_case }} import {{ entity.name }}Service router = APIRouter( prefix="/{{ entity.table_name }}", @@ -27,31 +29,44 @@ async def list_{{ entity.table_name }}( {%- if entity.searchable %} q: str | None = Query(None, description="Search {{ entity.searchable | join(', ') }}"), {%- endif %} - # TODO: Inject service via Depends() + service: {{ entity.name }}Service = Depends(get_{{ entity.name | snake_case }}_service), ): """List {{ entity.name }} records.""" - ... + return await service.list(skip=skip, limit=limit) @router.get("/{id}", response_model={{ entity.name }}DetailResponse) -async def get_{{ entity.name | lower }}(id: UUID): +async def get_{{ entity.name | snake_case }}( + id: UUID, + service: {{ entity.name }}Service = Depends(get_{{ entity.name | snake_case }}_service), +): """Get a single {{ entity.name }} by ID.""" - ... + return await service.get(id) @router.post("/", response_model={{ entity.name }}Response, status_code=201) -async def create_{{ entity.name | lower }}(data: {{ entity.name }}Create): +async def create_{{ entity.name | snake_case }}( + data: {{ entity.name }}Create, + service: {{ entity.name }}Service = Depends(get_{{ entity.name | snake_case }}_service), +): """Create a new {{ entity.name }}.""" - ... + return await service.create(data.model_dump()) @router.put("/{id}", response_model={{ entity.name }}Response) -async def update_{{ entity.name | lower }}(id: UUID, data: {{ entity.name }}Update): +async def update_{{ entity.name | snake_case }}( + id: UUID, + data: {{ entity.name }}Update, + service: {{ entity.name }}Service = Depends(get_{{ entity.name | snake_case }}_service), +): """Update an existing {{ entity.name }}.""" - ... + return await service.update(id, data.model_dump(exclude_unset=True)) @router.delete("/{id}", status_code=204) -async def delete_{{ entity.name | lower }}(id: UUID): +async def delete_{{ entity.name | snake_case }}( + id: UUID, + service: {{ entity.name }}Service = Depends(get_{{ entity.name | snake_case }}_service), +): """Delete a {{ entity.name }}.""" - ... + await service.delete(id) diff --git a/templates/simple/schema.py.j2 b/templates/simple/schema.py.j2 index 53f74c8..726624e 100644 --- a/templates/simple/schema.py.j2 +++ b/templates/simple/schema.py.j2 @@ -70,7 +70,7 @@ from pydantic import BaseModel, ConfigDict {%- for field in entity.fields %} {%- if field.type == "enum" and field.enum_values %} -from app.models.{{ entity.name | lower }} import {{ entity.name }}{{ field.name | capitalize }} +from app.models.{{ entity.name | snake_case }} import {{ entity.name }}{{ field.name | capitalize }} {%- endif %} {%- endfor %} diff --git a/templates/simple/service.py.j2 b/templates/simple/service.py.j2 index 387570e..08b1188 100644 --- a/templates/simple/service.py.j2 +++ b/templates/simple/service.py.j2 @@ -3,7 +3,7 @@ from faststack_core.base.repository import Repository from faststack_core.base.service import CrudService -from app.models.{{ entity.name | lower }} import {{ entity.name }} +from app.models.{{ entity.name | snake_case }} import {{ entity.name }} class {{ entity.name }}Service(CrudService[{{ entity.name }}]): diff --git a/templates/simple/test_integration.py.j2 b/templates/simple/test_integration.py.j2 index c44bd89..5aabbf9 100644 --- a/templates/simple/test_integration.py.j2 +++ b/templates/simple/test_integration.py.j2 @@ -1,19 +1,36 @@ """Integration tests for {{ entity.name }} API endpoints.""" -# TODO: Wire up the `client` fixture with FastAPI TestClient + DI overrides (Phase 6). - import uuid import pytest -from httpx import ASGITransport, AsyncClient -async def test_create_{{ entity.name | lower }}(client): +async def test_create_{{ entity.name | snake_case }}(client): response = await client.post( "/api/{{ entity.table_name }}", json={ -{%- for field in entity.fields if field.required and not field.references %} - "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}"10.00"{% else %}"test_{{ field.name }}"{% endif %}, +{%- for field in entity.fields if field.required %} +{%- if field.references %} + "{{ field.name }}": str(uuid.uuid4()), +{%- elif field.type == "string" or field.type == "text" %} + "{{ field.name }}": "test_{{ field.name }}", +{%- elif field.type == "integer" %} + "{{ field.name }}": 1, +{%- elif field.type == "float" %} + "{{ field.name }}": 1.0, +{%- elif field.type == "boolean" %} + "{{ field.name }}": True, +{%- elif field.type == "enum" %} + "{{ field.name }}": "{{ field.enum_values[0] }}", +{%- elif field.type == "decimal" %} + "{{ field.name }}": "10.00", +{%- elif field.type == "json" or field.type == "jsonb" %} + "{{ field.name }}": {"key": "value"}, +{%- elif field.type == "uuid" %} + "{{ field.name }}": str(uuid.uuid4()), +{%- else %} + "{{ field.name }}": "test_{{ field.name }}", +{%- endif %} {%- endfor %} }, ) @@ -28,6 +45,6 @@ async def test_list_{{ entity.table_name }}(client): assert isinstance(response.json(), list) -async def test_get_{{ entity.name | lower }}_not_found(client): +async def test_get_{{ entity.name | snake_case }}_not_found(client): response = await client.get(f"/api/{{ entity.table_name }}/{uuid.uuid4()}") assert response.status_code == 404 diff --git a/templates/simple/test_unit_service.py.j2 b/templates/simple/test_unit_service.py.j2 index b6c6415..43a4cd5 100644 --- a/templates/simple/test_unit_service.py.j2 +++ b/templates/simple/test_unit_service.py.j2 @@ -6,9 +6,9 @@ from decimal import Decimal import pytest -from app.services.{{ entity.name | lower }} import {{ entity.name }}Service +from app.services.{{ entity.name | snake_case }} import {{ entity.name }}Service from faststack_core.exceptions.domain import NotFoundError -from tests.unit.fakes.{{ entity.name | lower }}_repository import Fake{{ entity.name }}Repository +from tests.unit.fakes.{{ entity.name | snake_case }}_repository import Fake{{ entity.name }}Repository @pytest.fixture @@ -24,7 +24,7 @@ def service(repo): async def test_create(service): entity = await service.create({ {%- for field in entity.fields if field.required and not field.references %} - "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% else %}"test_{{ field.name }}"{% endif %}, + "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% elif field.type == "json" or field.type == "jsonb" %}{"key": "value"}{% else %}"test_{{ field.name }}"{% endif %}, {%- endfor %} }) assert entity.id is not None @@ -36,7 +36,7 @@ async def test_create(service): async def test_get(service): entity = await service.create({ {%- for field in entity.fields if field.required and not field.references %} - "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% else %}"test_{{ field.name }}"{% endif %}, + "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% elif field.type == "json" or field.type == "jsonb" %}{"key": "value"}{% else %}"test_{{ field.name }}"{% endif %}, {%- endfor %} }) found = await service.get(entity.id) @@ -51,7 +51,7 @@ async def test_get_not_found(service): async def test_list(service): await service.create({ {%- for field in entity.fields if field.required and not field.references %} - "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% else %}"test_{{ field.name }}"{% endif %}, + "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% elif field.type == "json" or field.type == "jsonb" %}{"key": "value"}{% else %}"test_{{ field.name }}"{% endif %}, {%- endfor %} }) result = await service.list() @@ -61,7 +61,7 @@ async def test_list(service): async def test_delete(service): entity = await service.create({ {%- for field in entity.fields if field.required and not field.references %} - "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% else %}"test_{{ field.name }}"{% endif %}, + "{{ field.name }}": {% if field.type == "string" or field.type == "text" %}"test_{{ field.name }}"{% elif field.type == "integer" %}1{% elif field.type == "float" %}1.0{% elif field.type == "boolean" %}True{% elif field.type == "uuid" %}uuid.uuid4(){% elif field.type == "datetime" %}datetime.now(UTC){% elif field.type == "enum" %}"{{ field.enum_values[0] }}"{% elif field.type == "decimal" %}Decimal("10.00"){% elif field.type == "json" or field.type == "jsonb" %}{"key": "value"}{% else %}"test_{{ field.name }}"{% endif %}, {%- endfor %} }) await service.delete(entity.id) diff --git a/tests/test_cli/test_add_entity.py b/tests/test_cli/test_add_entity.py new file mode 100644 index 0000000..a47941b --- /dev/null +++ b/tests/test_cli/test_add_entity.py @@ -0,0 +1,357 @@ +"""Tests for ``faststack add-entity`` CLI command.""" + +from __future__ import annotations + +import ast +from pathlib import Path + +import pytest +import yaml +from click.testing import CliRunner + +from cli import cli_group + + +@pytest.fixture +def runner(): + return CliRunner() + + +@pytest.fixture +def project_dir(runner: CliRunner, tmp_path: Path, monkeypatch) -> Path: + """Scaffold a FastStack project and chdir into it.""" + monkeypatch.chdir(tmp_path) + result = runner.invoke(cli_group, ["init", "test-project"], catch_exceptions=False) + assert result.exit_code == 0 + + project = tmp_path / "test-project" + monkeypatch.chdir(project) + return project + + +@pytest.fixture +def sample_entities_yaml(tmp_path: Path) -> Path: + """Write a minimal entities.yaml and return its path.""" + yaml_content = """\ +entities: + Product: + base: FullAuditedEntity + fields: + name: + type: string + required: true + price: + type: decimal + searchable: + - name +""" + yaml_file = tmp_path / "entities.yaml" + yaml_file.write_text(yaml_content) + return yaml_file + + +class TestAddEntityCreatesFiles: + """Test that ``faststack add-entity`` creates all 9 entity files.""" + + def test_creates_all_entity_files(self, runner: CliRunner, project_dir: Path) -> None: + result = runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert "Created entity 'Product'" in result.output + + expected_files = [ + "app/models/product.py", + "app/schemas/product.py", + "app/repositories/product.py", + "app/services/product.py", + "app/api/routes/product.py", + "tests/factories/product.py", + "tests/unit/fakes/product_repository.py", + "tests/unit/test_product_service.py", + "tests/integration/test_product_api.py", + ] + for f in expected_files: + assert (project_dir / f).is_file(), f"Missing file: {f}" + + def test_updates_project_config(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + config = yaml.safe_load((project_dir / ".project-config.yaml").read_text()) + assert "Product" in config["entities"] + assert "hash" in config["entities"]["Product"] + assert "model_path" in config["entities"]["Product"] + assert config["entities"]["Product"]["hash"] != "" + + def test_generated_model_is_valid_python(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + source = (project_dir / "app/models/product.py").read_text() + ast.parse(source) # raises SyntaxError if invalid + + def test_generated_schema_is_valid_python(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + source = (project_dir / "app/schemas/product.py").read_text() + ast.parse(source) + + def test_generated_service_is_valid_python(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + source = (project_dir / "app/services/product.py").read_text() + ast.parse(source) + + +class TestAddEntityFromYaml: + """Test ``faststack add-entity --from-yaml``.""" + + def test_from_yaml_creates_entity( + self, + runner: CliRunner, + project_dir: Path, + sample_entities_yaml: Path, + ) -> None: + result = runner.invoke( + cli_group, + ["add-entity", "Product", "--from-yaml", str(sample_entities_yaml)], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert (project_dir / "app/models/product.py").is_file() + + # Check model file references the entity name + model_content = (project_dir / "app/models/product.py").read_text() + assert "class Product" in model_content + + def test_from_yaml_entity_not_found( + self, + runner: CliRunner, + project_dir: Path, + sample_entities_yaml: Path, + ) -> None: + result = runner.invoke( + cli_group, + ["add-entity", "NonExistent", "--from-yaml", str(sample_entities_yaml)], + ) + + assert result.exit_code != 0 + assert "not found" in result.output + + +class TestAddEntityErrorCases: + """Test error handling in ``faststack add-entity``.""" + + def test_error_entity_already_exists(self, runner: CliRunner, project_dir: Path) -> None: + # First creation succeeds + result1 = runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + assert result1.exit_code == 0 + + # Second creation without --update fails + result2 = runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + ) + assert result2.exit_code != 0 + assert "already exists" in result2.output + + def test_update_flag_succeeds_when_entity_exists(self, runner: CliRunner, project_dir: Path) -> None: + # First creation + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + + # Update with new fields + result = runner.invoke( + cli_group, + [ + "add-entity", + "Product", + "--fields", + "name:string:required,price:decimal", + "--update", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Updated entity 'Product'" in result.output + + # Model file should now contain price field + model_content = (project_dir / "app/models/product.py").read_text() + assert "price" in model_content + + def test_error_no_project_config(self, runner: CliRunner, tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + + result = runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + ) + + assert result.exit_code != 0 + assert ".project-config.yaml" in result.output + + def test_error_no_fields_or_yaml(self, runner: CliRunner, project_dir: Path) -> None: + result = runner.invoke( + cli_group, + ["add-entity", "Product"], + ) + + assert result.exit_code != 0 + assert "Provide --fields or --from-yaml" in result.output + + +class TestAddEntityRegistration: + """Test that ``faststack add-entity`` registers the router in main.py and generates registry files.""" + + def test_registers_router_in_main_py(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + main_content = (project_dir / "app/main.py").read_text() + assert "from app.api.routes.product import router as product_router" in main_content + assert "app.include_router(product_router" in main_content + + def test_no_duplicate_router_registration(self, runner: CliRunner, project_dir: Path) -> None: + # First creation + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + + # Update with --update (should not duplicate) + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal", "--update"], + catch_exceptions=False, + ) + + main_content = (project_dir / "app/main.py").read_text() + assert main_content.count("product_router") == 2 # one import, one include_router + + def test_generates_dependencies_py(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + + deps_path = project_dir / "app/api/dependencies.py" + assert deps_path.is_file() + content = deps_path.read_text() + assert "get_product_service" in content + assert "get_db_session" in content + ast.parse(content) + + def test_generates_integration_conftest(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + + conftest_path = project_dir / "tests/integration/conftest.py" + assert conftest_path.is_file() + content = conftest_path.read_text() + assert "async def client" in content + assert "FakeProductRepository" in content + ast.parse(content) + + def test_multiple_entities_in_registry_files(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + runner.invoke( + cli_group, + ["add-entity", "Order", "--fields", "total:decimal:required"], + catch_exceptions=False, + ) + + deps_content = (project_dir / "app/api/dependencies.py").read_text() + assert "get_product_service" in deps_content + assert "get_order_service" in deps_content + ast.parse(deps_content) + + conftest_content = (project_dir / "tests/integration/conftest.py").read_text() + assert "FakeProductRepository" in conftest_content + assert "FakeOrderRepository" in conftest_content + ast.parse(conftest_content) + + +class TestAddEntityFileContents: + """Test content of generated entity files.""" + + def test_model_contains_class_and_tablename(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + content = (project_dir / "app/models/product.py").read_text() + assert "class Product(" in content + assert '__tablename__ = "products"' in content + + def test_schema_contains_create_and_response(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + + content = (project_dir / "app/schemas/product.py").read_text() + assert "class ProductCreate(" in content + assert "class ProductUpdate(" in content + assert "class ProductResponse(" in content + assert "class ProductDetailResponse(" in content + + def test_repository_references_entity(self, runner: CliRunner, project_dir: Path) -> None: + runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + + content = (project_dir / "app/repositories/product.py").read_text() + assert "class ProductRepository(" in content + assert "from app.models.product import Product" in content + + def test_next_steps_message(self, runner: CliRunner, project_dir: Path) -> None: + result = runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required"], + catch_exceptions=False, + ) + + assert "faststack migrate generate" in result.output diff --git a/tests/test_cli/test_generate.py b/tests/test_cli/test_generate.py new file mode 100644 index 0000000..731e318 --- /dev/null +++ b/tests/test_cli/test_generate.py @@ -0,0 +1,250 @@ +"""Tests for ``faststack generate`` CLI command.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml +from click.testing import CliRunner + +from cli import cli_group + + +@pytest.fixture +def runner(): + return CliRunner() + + +@pytest.fixture +def project_with_entity(runner: CliRunner, tmp_path: Path, monkeypatch) -> Path: + """Scaffold a project and add a Product entity, returning the project dir.""" + monkeypatch.chdir(tmp_path) + result = runner.invoke(cli_group, ["init", "test-project"], catch_exceptions=False) + assert result.exit_code == 0 + + project = tmp_path / "test-project" + monkeypatch.chdir(project) + + result = runner.invoke( + cli_group, + ["add-entity", "Product", "--fields", "name:string:required,price:decimal"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + return project + + +class TestGenerateRegenerates: + """Test that ``faststack generate`` regenerates derived files.""" + + def test_regenerates_schema(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + schema_path = project / "app/schemas/product.py" + + # Record original content + original_content = schema_path.read_text() + assert "ProductCreate" in original_content + + # Modify the schema file to simulate drift + schema_path.write_text("# corrupted\n") + + # Regenerate + result = runner.invoke( + cli_group, + ["generate", "Product"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert "Regenerated" in result.output + + # Schema should be restored + restored_content = schema_path.read_text() + assert "ProductCreate" in restored_content + + def test_skips_preserved_files(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + + # Add a custom comment to the service file (PRESERVED) + service_path = project / "app/services/product.py" + original = service_path.read_text() + service_path.write_text("# MY CUSTOM CODE\n" + original) + + result = runner.invoke( + cli_group, + ["generate", "Product"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert "Skipping" in result.output + assert "PRESERVED" in result.output + + # Custom comment should still be there + content = service_path.read_text() + assert "MY CUSTOM CODE" in content + + def test_updates_hash_in_config(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + + # Read original hash + config = yaml.safe_load((project / ".project-config.yaml").read_text()) + original_hash = config["entities"]["Product"]["hash"] + + # Modify model to change hash + model_path = project / "app/models/product.py" + model_content = model_path.read_text() + model_path.write_text(model_content + "\n# modified\n") + + # Regenerate + runner.invoke( + cli_group, + ["generate", "Product"], + catch_exceptions=False, + ) + + # Hash should be updated + config = yaml.safe_load((project / ".project-config.yaml").read_text()) + new_hash = config["entities"]["Product"]["hash"] + assert new_hash != original_hash + + def test_regenerates_all_regeneratable_files(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + + regeneratable = [ + "app/schemas/product.py", + "tests/unit/fakes/product_repository.py", + "tests/factories/product.py", + ] + + # Corrupt all regeneratable files + for f in regeneratable: + (project / f).write_text("# corrupted\n") + + # Regenerate + result = runner.invoke( + cli_group, + ["generate", "Product"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + + # All should be restored + for f in regeneratable: + content = (project / f).read_text() + assert content != "# corrupted\n", f"File {f} was not regenerated" + + +class TestGenerateRegistryFiles: + """Test that ``faststack generate`` regenerates registry files.""" + + def test_regenerates_dependencies_py(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + deps_path = project / "app/api/dependencies.py" + + # Corrupt dependencies.py + deps_path.write_text("# corrupted\n") + + # Regenerate + result = runner.invoke( + cli_group, + ["generate", "Product"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + content = deps_path.read_text() + assert "get_product_service" in content + assert "get_db_session" in content + + def test_regenerates_integration_conftest(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + conftest_path = project / "tests/integration/conftest.py" + + # Corrupt conftest + conftest_path.write_text("# corrupted\n") + + # Regenerate + result = runner.invoke( + cli_group, + ["generate", "Product"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + content = conftest_path.read_text() + assert "async def client" in content + assert "FakeProductRepository" in content + + +class TestGenerateAllFlag: + """Test ``faststack generate --all``.""" + + def test_generate_all_regenerates_all_entities(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + + # Add a second entity + runner.invoke( + cli_group, + ["add-entity", "Order", "--fields", "total:decimal:required"], + catch_exceptions=False, + ) + + # Corrupt both schema files + (project / "app/schemas/product.py").write_text("# corrupted\n") + (project / "app/schemas/order.py").write_text("# corrupted\n") + + # Regenerate all + result = runner.invoke( + cli_group, + ["generate", "--all"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + + # Both should be restored + assert "ProductCreate" in (project / "app/schemas/product.py").read_text() + assert "OrderCreate" in (project / "app/schemas/order.py").read_text() + + +class TestGenerateErrorCases: + """Test error handling in ``faststack generate``.""" + + def test_error_no_project_config(self, runner: CliRunner, tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + + result = runner.invoke( + cli_group, + ["generate", "Product"], + ) + + assert result.exit_code != 0 + assert ".project-config.yaml" in result.output + + def test_error_no_entity_name_or_all(self, runner: CliRunner, project_with_entity: Path) -> None: + result = runner.invoke( + cli_group, + ["generate"], + ) + + assert result.exit_code != 0 + assert "Provide entity name or --all" in result.output + + def test_skips_when_model_file_missing(self, runner: CliRunner, project_with_entity: Path) -> None: + project = project_with_entity + + # Remove model file + (project / "app/models/product.py").unlink() + + result = runner.invoke( + cli_group, + ["generate", "Product"], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + assert "Skipping Product" in result.output + assert "model file not found" in result.output diff --git a/tests/test_cli/test_init.py b/tests/test_cli/test_init.py index 44f4ff5..1506edc 100644 --- a/tests/test_cli/test_init.py +++ b/tests/test_cli/test_init.py @@ -197,9 +197,7 @@ def test_project_config_yaml_exists(self, runner: CliRunner, tmp_path: Path, mon content = (tmp_path / "cool-app" / ".project-config.yaml").read_text() assert "project_name: cool-app" in content - def test_docker_compose_contains_project_name( - self, runner: CliRunner, tmp_path: Path, monkeypatch - ): + def test_docker_compose_contains_project_name(self, runner: CliRunner, tmp_path: Path, monkeypatch): monkeypatch.chdir(tmp_path) runner.invoke(cli_group, ["init", "cool-app"], catch_exceptions=False) @@ -251,9 +249,7 @@ def test_next_steps_include_migrate( class TestErrorCases: """Test error handling in ``faststack init``.""" - def test_error_when_directory_already_exists( - self, runner: CliRunner, tmp_path: Path, monkeypatch - ): + def test_error_when_directory_already_exists(self, runner: CliRunner, tmp_path: Path, monkeypatch): monkeypatch.chdir(tmp_path) (tmp_path / "existing-project").mkdir() diff --git a/tests/test_cli/test_list.py b/tests/test_cli/test_list.py new file mode 100644 index 0000000..05e5c4e --- /dev/null +++ b/tests/test_cli/test_list.py @@ -0,0 +1,106 @@ +"""Tests for faststack list command.""" + +import hashlib + +import yaml +from click.testing import CliRunner + +from cli import cli_group + + +def _file_hash(path) -> str: + return hashlib.sha256(path.read_bytes()).hexdigest() + + +class TestListCommand: + """Tests for the 'list' command.""" + + def test_no_config_file(self, tmp_path, monkeypatch) -> None: + """Should fail when .project-config.yaml is missing.""" + monkeypatch.chdir(tmp_path) + runner = CliRunner() + result = runner.invoke(cli_group, ["list"]) + assert result.exit_code != 0 + assert "No .project-config.yaml found" in result.output + + def test_no_entities(self, tmp_path, monkeypatch) -> None: + """Should inform user when no entities are registered.""" + monkeypatch.chdir(tmp_path) + config_path = tmp_path / ".project-config.yaml" + config_path.write_text(yaml.dump({"entities": {}})) + + runner = CliRunner() + result = runner.invoke(cli_group, ["list"]) + assert result.exit_code == 0 + assert "No entities registered" in result.output + + def test_entity_up_to_date(self, tmp_path, monkeypatch) -> None: + """Entity whose model hash matches stored hash shows 'up to date'.""" + monkeypatch.chdir(tmp_path) + + # Create model file + model_dir = tmp_path / "app" / "models" + model_dir.mkdir(parents=True) + model_file = model_dir / "product.py" + model_file.write_text("class Product: pass\n") + + current_hash = _file_hash(model_file) + + config = { + "entities": { + "Product": { + "model_path": "app/models/product.py", + "hash": current_hash, + } + } + } + (tmp_path / ".project-config.yaml").write_text(yaml.dump(config)) + + runner = CliRunner() + result = runner.invoke(cli_group, ["list"]) + assert result.exit_code == 0 + assert "up to date" in result.output + + def test_entity_schemas_outdated(self, tmp_path, monkeypatch) -> None: + """Entity whose model changed since last generation shows 'schemas outdated'.""" + monkeypatch.chdir(tmp_path) + + model_dir = tmp_path / "app" / "models" + model_dir.mkdir(parents=True) + model_file = model_dir / "product.py" + model_file.write_text("class Product: pass\n") + + # Store a stale hash + config = { + "entities": { + "Product": { + "model_path": "app/models/product.py", + "hash": "stale_hash_value", + } + } + } + (tmp_path / ".project-config.yaml").write_text(yaml.dump(config)) + + runner = CliRunner() + result = runner.invoke(cli_group, ["list"]) + assert result.exit_code == 0 + assert "schemas outdated" in result.output + + def test_entity_missing_model(self, tmp_path, monkeypatch) -> None: + """Entity whose model file does not exist shows 'MISSING'.""" + monkeypatch.chdir(tmp_path) + + config = { + "entities": { + "Product": { + "model_path": "app/models/product.py", + "hash": "some_hash", + } + } + } + (tmp_path / ".project-config.yaml").write_text(yaml.dump(config)) + + runner = CliRunner() + result = runner.invoke(cli_group, ["list"]) + assert result.exit_code == 0 + assert "MISSING" in result.output diff --git a/tests/test_cli/test_migrate.py b/tests/test_cli/test_migrate.py new file mode 100644 index 0000000..29bc5ca --- /dev/null +++ b/tests/test_cli/test_migrate.py @@ -0,0 +1,56 @@ +"""Tests for faststack migrate command.""" + +from click.testing import CliRunner + +from cli import cli_group +from cli.cmd_migrate import migrate + + +class TestMigrateGroup: + """Tests for the migrate command group.""" + + def test_migrate_group_exists(self) -> None: + """The migrate group is registered on the CLI.""" + command = cli_group.commands.get("migrate") + assert command is not None + + def test_migrate_has_three_subcommands(self) -> None: + """The migrate group exposes generate, upgrade, and downgrade.""" + subcommands = set(migrate.commands.keys()) + assert subcommands == {"generate", "upgrade", "downgrade"} + + +class TestMigrateGenerate: + """Tests for 'migrate generate'.""" + + def test_generate_no_alembic_ini(self, tmp_path, monkeypatch) -> None: + """Should fail when alembic.ini is missing.""" + monkeypatch.chdir(tmp_path) + runner = CliRunner() + result = runner.invoke(cli_group, ["migrate", "generate", "add users table"]) + assert result.exit_code != 0 + assert "No alembic.ini found" in result.output + + +class TestMigrateUpgrade: + """Tests for 'migrate upgrade'.""" + + def test_upgrade_no_alembic_ini(self, tmp_path, monkeypatch) -> None: + """Should fail when alembic.ini is missing.""" + monkeypatch.chdir(tmp_path) + runner = CliRunner() + result = runner.invoke(cli_group, ["migrate", "upgrade"]) + assert result.exit_code != 0 + assert "No alembic.ini found" in result.output + + +class TestMigrateDowngrade: + """Tests for 'migrate downgrade'.""" + + def test_downgrade_no_alembic_ini(self, tmp_path, monkeypatch) -> None: + """Should fail when alembic.ini is missing.""" + monkeypatch.chdir(tmp_path) + runner = CliRunner() + result = runner.invoke(cli_group, ["migrate", "downgrade"]) + assert result.exit_code != 0 + assert "No alembic.ini found" in result.output diff --git a/tests/test_cli/test_yaml_parser.py b/tests/test_cli/test_yaml_parser.py index fac8f72..350d979 100644 --- a/tests/test_cli/test_yaml_parser.py +++ b/tests/test_cli/test_yaml_parser.py @@ -132,6 +132,7 @@ def test_category_has_self_referential(self, parsed_entities): assert rel.back_populates == "children" def test_user_has_no_relationships(self, parsed_entities): + """User has no FK fields, so no relationships (reverse side is user-added).""" user = parsed_entities[0] assert len(user.relationships) == 0 diff --git a/tests/test_e2e/test_smoke.py b/tests/test_e2e/test_smoke.py index 638d4b1..d8fe086 100644 --- a/tests/test_e2e/test_smoke.py +++ b/tests/test_e2e/test_smoke.py @@ -80,9 +80,7 @@ def test_init_succeeds(self, runner: CliRunner, tmp_path: Path, monkeypatch) -> assert result.exit_code == 0 assert "Created project" in result.output - def test_all_expected_directories_exist( - self, runner: CliRunner, tmp_path: Path, monkeypatch - ) -> None: + def test_all_expected_directories_exist(self, runner: CliRunner, tmp_path: Path, monkeypatch) -> None: monkeypatch.chdir(tmp_path) runner.invoke(cli_group, ["init", "blog"], catch_exceptions=False) project = tmp_path / "blog" @@ -121,9 +119,7 @@ def test_project_files_exist(self, runner: CliRunner, tmp_path: Path, monkeypatc for f in expected_files: assert (project / f).is_file(), f"Missing: {f}" - def test_all_generated_python_is_valid( - self, runner: CliRunner, tmp_path: Path, monkeypatch - ) -> None: + def test_all_generated_python_is_valid(self, runner: CliRunner, tmp_path: Path, monkeypatch) -> None: monkeypatch.chdir(tmp_path) runner.invoke(cli_group, ["init", "blog"], catch_exceptions=False) project = tmp_path / "blog" @@ -141,9 +137,7 @@ def test_all_generated_python_is_valid( rel = py_file.relative_to(project) pytest.fail(f"Invalid Python in {rel}: {e}") - def test_project_config_has_structure( - self, runner: CliRunner, tmp_path: Path, monkeypatch - ) -> None: + def test_project_config_has_structure(self, runner: CliRunner, tmp_path: Path, monkeypatch) -> None: monkeypatch.chdir(tmp_path) runner.invoke(cli_group, ["init", "blog"], catch_exceptions=False) diff --git a/tests/test_templates/test_project_templates.py b/tests/test_templates/test_project_templates.py index 00228fd..aef522b 100644 --- a/tests/test_templates/test_project_templates.py +++ b/tests/test_templates/test_project_templates.py @@ -14,7 +14,12 @@ @pytest.fixture def jinja_env(): - return Environment(loader=FileSystemLoader(str(TEMPLATE_DIR)), keep_trailing_newline=True) + from cli.cmd_add_entity import _camel_to_snake, _pluralize + + env = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR)), keep_trailing_newline=True) + env.filters["snake_case"] = _camel_to_snake + env.filters["pluralize"] = _pluralize + return env @pytest.fixture diff --git a/tests/test_templates/test_simple_mode.py b/tests/test_templates/test_simple_mode.py index ab9bfad..05fde8f 100644 --- a/tests/test_templates/test_simple_mode.py +++ b/tests/test_templates/test_simple_mode.py @@ -115,10 +115,15 @@ @pytest.fixture def jinja_env(): - return Environment( + from cli.cmd_add_entity import _camel_to_snake, _pluralize + + env = Environment( loader=FileSystemLoader(str(TEMPLATE_DIR)), keep_trailing_newline=True, ) + env.filters["snake_case"] = _camel_to_snake + env.filters["pluralize"] = _pluralize + return env def _render(jinja_env, template_name: str, entity: EntityDefinition) -> str: @@ -144,8 +149,7 @@ def test_template_renders_valid_python(jinja_env, template_name, entity): ast.parse(output) except SyntaxError as e: pytest.fail( - f"{template_name} for {entity.name} produced invalid Python:\n" - f" {e}\n\nGenerated code:\n{output}" + f"{template_name} for {entity.name} produced invalid Python:\n" f" {e}\n\nGenerated code:\n{output}" ) @@ -299,3 +303,99 @@ def test_integration_test_has_endpoints(jinja_env): output = _render(jinja_env, "test_integration.py.j2", USER_ENTITY) assert "test_create" in output assert "test_list" in output + + +def test_integration_test_has_no_todo(jinja_env): + output = _render(jinja_env, "test_integration.py.j2", USER_ENTITY) + assert "TODO" not in output + + +# --------------------------------------------------------------------------- +# Router template — wired endpoints +# --------------------------------------------------------------------------- + + +def test_router_has_service_depends(jinja_env): + output = _render(jinja_env, "router.py.j2", USER_ENTITY) + assert "Depends(get_user_service)" in output + assert "service: UserService" in output + + +def test_router_has_no_stubs(jinja_env): + output = _render(jinja_env, "router.py.j2", USER_ENTITY) + # Endpoints should have real bodies, not `...` + lines = output.split("\n") + for i, line in enumerate(lines): + stripped = line.strip() + if stripped == "...": + pytest.fail(f"Found stub '...' at line {i + 1} in rendered router") + + +def test_router_imports_dependencies(jinja_env): + output = _render(jinja_env, "router.py.j2", USER_ENTITY) + assert "from app.api.dependencies import get_user_service" in output + assert "from app.services.user import UserService" in output + + +# --------------------------------------------------------------------------- +# Multi-entity templates: dependencies.py.j2, conftest_integration.py.j2 +# --------------------------------------------------------------------------- + + +def _render_multi_entity(jinja_env, template_name: str, entities: list) -> str: + """Render a multi-entity template with a list of entity dicts.""" + template = jinja_env.get_template(template_name) + return template.render(entities=entities) + + +MULTI_ENTITY_CONTEXT = [ + {"name": "User", "snake_name": "user"}, + {"name": "Post", "snake_name": "post"}, + {"name": "Category", "snake_name": "category"}, +] + + +def test_dependencies_template_valid_python(jinja_env): + output = _render_multi_entity(jinja_env, "dependencies.py.j2", MULTI_ENTITY_CONTEXT) + try: + ast.parse(output) + except SyntaxError as e: + pytest.fail(f"dependencies.py.j2 produced invalid Python: {e}\n\n{output}") + + +def test_dependencies_has_all_entity_providers(jinja_env): + output = _render_multi_entity(jinja_env, "dependencies.py.j2", MULTI_ENTITY_CONTEXT) + assert "get_user_service" in output + assert "get_post_service" in output + assert "get_category_service" in output + assert "get_db_session" in output + + +def test_dependencies_imports_repos_and_services(jinja_env): + output = _render_multi_entity(jinja_env, "dependencies.py.j2", MULTI_ENTITY_CONTEXT) + assert "from app.repositories.user import UserRepository" in output + assert "from app.services.user import UserService" in output + + +def test_conftest_integration_valid_python(jinja_env): + output = _render_multi_entity(jinja_env, "conftest_integration.py.j2", MULTI_ENTITY_CONTEXT) + try: + ast.parse(output) + except SyntaxError as e: + pytest.fail(f"conftest_integration.py.j2 produced invalid Python: {e}\n\n{output}") + + +def test_conftest_integration_has_client_fixture(jinja_env): + output = _render_multi_entity(jinja_env, "conftest_integration.py.j2", MULTI_ENTITY_CONTEXT) + assert "async def client" in output + assert "FakeUserRepository" in output + assert "FakePostRepository" in output + assert "FakeCategoryRepository" in output + + +def test_conftest_integration_overrides_all_services(jinja_env): + output = _render_multi_entity(jinja_env, "conftest_integration.py.j2", MULTI_ENTITY_CONTEXT) + assert "get_user_service" in output + assert "get_post_service" in output + assert "get_category_service" in output + assert "dependency_overrides.clear()" in output