diff --git a/.flake8 b/.flake8
index 715f6c8..4ee027a 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
[flake8]
-extend-ignore = E203, E501
+extend-ignore = E203, E501, E731
exclude = .github,__pycache__,docs/source/conf.py,old,build,dist,venv,
max-complexity = 10
\ No newline at end of file
diff --git a/README.md b/README.md
index afe672d..faff506 100644
--- a/README.md
+++ b/README.md
@@ -10,29 +10,56 @@ Ststeroids was designed to supercharge the development of complex multi-page app
The main concepts of Ststeroids are:
- Reusable Components
-- Logics Flows
+- Logic Flows
- Declarative Layouts
-- Routers
+- A Store
In addition, StSteroids provides an easy way to load style sheets into your Streamlit application and offers a wrapper around `st.session_state` to separate states into stores. This wrapper is also used within components to store the component and its state in the session state.
#### Components
Components are at the core of StSteroids. A component represents a specific visual element of your application along with its rendering logic. Examples include a login dialog or a person details component.
-Each component contains only the logic necessary for its functionality, such as basic input validation or button interactions that trigger a [flow](#flows). Components and their state are stored in the ComponentStore.
+Each component contains only the logic necessary for its functionality, such as basic input validation or button interactions that trigger a [flow](#flows). Components and their attributes are stored in the ComponentStore which is a special instance of a Store.
+
+Component concepts:
+
+- components never decide on domain logic, so there is no domain error handling for example
+- a component contains interaction elements, unless the component is still meaningful and usable without the interaction element → split the element out
+- components don't navigate pages
+- should have methods for updating its attributes (explicit state changes so that the flow doesn't need to all the attributes)
+
+For example, a metric component that can be reused for multiple purposes.
#### Flows
-Flows contain the business logic of the application, handling its core functionality and, in some cases, linking components to backend services.
-For example, a login flow might call an authentication service, validate the response, extract the access token, and store it in the session store.
+Flows encapsulate the application’s interaction and orchestration logic.
+They handle user-initiated actions, coordinate state changes across components, and invoke domain services to perform business operations.
+
+Flow concepts:
+
+- flows act as handlers for user and system interactions (e.g. button clicks, page entry, form submission)
+- flows orchestrate application behavior, calling services and updating component state
+- flows coordinate multiple components and stores as part of a single interaction
+- flows determine navigation and control flow between layouts or pages
+- flows own error handling and recovery logic for the interactions they manage
+- flows may contain light business rules, but core domain logic should live in services
+
+For example, a login flow might call an authentication service, evaluate the result, store relevant session data, and update one or more components to reflect the outcome.
+
+When multiple flows share orchestration resources—such as access to the same components, stores, or helper logic—it is recommended to introduce a shared base flow to centralize this responsibility and avoid duplication.
#### Layouts
-Layouts bring components together to create a multi-page application. Each layout functions as a page, rendering one or more components and defining their arrangement.
-For example, a layout might define multiple Streamlit columns and place components within them.
+Layouts bring components together to create a multi-page application. Each layout functions as a page, rendering one or more components and defining their arrangement and rendering.
+
+Layout concepts:
-#### Routers
-Routers enable multi-page applications by defining routes and linking them to layouts. These routes are internal, meaning they cannot be accessed directly via a URL (due to current Streamlit limitations) and should be triggered through user interactions.
+- layouts are responsible for initializing and wiring components
+- layouts are responsible for the visual arrangement of components
+- layouts are responsible for conditional rendering based on application state or context (for example, authorisation)
+- layout shouldn't handle domain errors
+
+For example, a layout might define multiple Streamlit columns and place components within them.
### Installation
@@ -40,9 +67,9 @@ Routers enable multi-page applications by defining routes and linking them to la
pip install ststeroids
```
-### Usage
+### Getting started
-StSteroids allows you to define components, layouts, and flows, then connect everything in `app.py` using a router. See the `example` folder in this repository.
+StSteroids allows you to define components, layouts, and flows, then connect everything in a `main.py` by creating a StSteroids app. See the `example` folder in this repository.
To run the example app, execute the following commands from the project root:
@@ -59,72 +86,130 @@ pip install -r requirements-dev.txt
pytest
```
+#### The basics
+
+To create an application using StSteroids, follow these steps:
+
+1. Create components – Define the individual UI elements of your application, such as dialogs, tables, or metrics, using the Component base class.
+2. Create flows – Implement the business or orchestration logic that interacts with components, services, and session state.
+3. Create layouts – Group and initialize components, arrange them visually. Layouts define how your pages are structured.
+4. Register event handlers.
+5. Create the StSteroids app – Instantiate the app, register routes for each layout, and define a default route if needed.
+
+This sequence ensures a clear separation of concerns and keeps your app modular, testable, and easy to maintain.
+
+#### StSteroids App and routes
+
+Example of creating a StSteroids application.
+
+```python
+app = StSteroids()
+
+# Register a layout as a route
+app.route("dashboard").to(DashboardLayout).register()
+
+# Set a default route (optional)
+app.default_route(DashboardLayout)
+
+# Run the app (optionally specify an entry route)
+app.run()
+```
+
+##### API reference
+
+`app.route(name).to(layout).register()`
+
+Registers a route that maps the route name to a layout class.
+The layout is rendered when the route becomes active.
+
+The full route builder API is as follows.
+
+`app.route(name).to(layout).when(callable).on_enter(flow).register()`
+
+- `when` sets up a condition by specifying a callable. The route is only registered if the callable evaluates to True
+- `on_enter` registers a flow for the on enter event. The flow is dispatched once when the route becomes active, before the layout is rendered. Note! that an `on_enter` event flow should not switch page as it will break the routing concept
+
+`app.on_app_run_once(flow)`
+
+Registers an `on_app_run_once` event handler flow. You can use this to have an initial flow that runs once at the start of the application. Note! that an `on_app_run_once`.
+
+`app.default_route(layout)`
+
+Sets a default layout to display if no route is specified.
+
+`app.run(entry_route)`
+
+Starts the app and navigates to entry_route if provided; otherwise, uses the default route.
+
+
+
#### Components
-Defining a new component.
+Example of defining a new component.
+
```python
from ststeroids import Component
-class YourXComponent(Component):
- def __init__(self, component_id: str):
- super().__init__(component_id) # This line is important to initialize the base class.
+class MetricComponent(Component):
+ def __init__(
+ self,
+ header: str,
+ ):
+ self.header = header
+ self.value = 0
- def render(self):
- # Your render logic
+ def display(self):
+ st.metric(self.header, self.value)
+
+ def set_value(self, value: int):
+ self.value = value
```
-Additionaly an initial state (dict) can be passed as a second paramters while initing the base class.
+The header attribute and set_value method are specific to this example. They illustrate how components can have instance-bound attributes and provide an explicit API for updating their state. Components should own their state and expose such methods rather than allowing external code to directly mutate their attributes.
##### API Reference
`id`
-Holds the component id
+Holds the component id, is automatically added from the base component.
-`state`
+`visible`
-Manages the component state. Although technically an instance of the StSteroids `State` class, it functions like a dictionary, allowing properties to be accessed using getters and setters.
+Controls if the component is visible. Defaults to `True` Control using the `show` and `hide` methods.
-When outside the component:
-```python
-myvalue = yourcomponent.state.yourproperty
-yourcomponent.state.yourproperty = "yourvalue"
-```
-
-When inside the component:
-```python
-myvalue = self.state.yourproperty
-self.state.yourproperty = "yourvalue"
-```
-
-`render()`
+`show()`
-This method needs to be implemented by the subclass. To call it in a layout, use `execute_render()`
+Sets the `visible` property of the component to `True`
-`execute_render(render_as: Literal["normal", "dialog", "fragment"]="normal", options:dict={})`
+`hide()`
-Executes the render method of an instance of a component. Additionaly provide the `render_as` parameter with the `options` parameter.
+Sets the `visible` property of the component to `False`
-Dialog options:
+`create(cls, component_id: str, *args, **kwargs)`
+`create(cls, component_id: str, title:str ,*args, **kwargs)` (Dialog only)
+`create(cls, component_id: str, refresh_interval:str ,*args, **kwargs)` (Fragment only)
-**title**
+Creates a new component instance with the given `component_id` and stores it in the `ComponentStore`.
+This is typically called in layouts to initialize components. Additional arguments are passed to the component's constructor.
-The dialog title.
-Fragment options:
+`get(cls, component_id: str)`
-**refresh_flow**
+Retrieves an existing component instance from the `ComponentStore` by its `component_id`.
+`create()` must have been called first; otherwise, an error will be raised.
+This is typically used in flows that need to interact with a component after it has been initialized.
-A refresh flow that should be called post rendering the component, you can use this to refresh the applications state for the next view.
+`display()`
-**refresh_interval**
+This method needs to be implemented by the subclass. To call it in a layout, use `render()`
-The refresh interval, for example: `2s`.
+`render()`
+Executes the display method of an instance of a component.
`register_element(element_name: str)`
-Registers a Streamlit element onto the component by generating component bound key. Use this function when setting a key for an element within the component.
+Registers a Streamlit element onto the component by generating component-bound key. Use this method when setting a key for an element within the component. For more information about using keys, please refer to the official Streamlit documentation.
Usage:
@@ -150,99 +235,157 @@ Usage:
Sets the value of a registered element.
+`on(event_name: str, callback: Flow)`
+
+Registers a flow as an event handler for the given event name on the component. The flow will be dispatched when the event is triggered.
+
+`on_refresh(self, callback: Flow)`
+
+Registers a flow as an event handler for the refresh event of a Fragment (Fragment only)
+
+`trigger(event_name: str)`
+
+Triggers the specified event and dispatches the flow registered for it.
+
#### Flows
-Defining a new flow.
+Example of defining a new flow:
+
```python
from ststeroids import Flow
-class YourXFlow(Flow):
- def __init__(self):
- super().__init__() # This line is important to initialize the base class.
+class AddDocumentFlow(Flow):
+ def __init__(self, session_store: Store):
+ self.session_store = session_store
+
+ @property
+ def cp_document_table(self):
+ return TableComponent.get(ComponentIDs.documents)
+
+ def run(self, ctx: FlowContext) -> None:
+ # Flow logic for adding a document
+```
- def run(self):
- # Your flow logic
+Now imagine your application supports multiple document-related actions (for example: add, delete, or update documents).
+These actions often share the same orchestration context, such as access to the session store or a document table component.
+
+To avoid duplicating this setup in every action flow, it is recommended to introduce a base flow that provides shared orchestration resources.
+
+First, rename the flow above to a base flow:
+
+```python
+class DocumentActionBaseFlow(Flow):
+ def __init__(self, session_store: Store):
+ self.session_store = session_store
+
+ @property
+ def cp_document_table(self):
+ return TableComponent.get(ComponentIDs.documents)
+```
+
+Then, create a dedicated flow for each document action:
+
+```python
+class AddDocumentFlow(DocumentActionBaseFlow):
+ def run(self, ctx: FlowContext):
+ # Flow logic for adding a documentd
```
+In this example, AddDocumentFlow represents a single user action, while DocumentActionBaseFlow provides shared orchestration context.
+This keeps flows focused, avoids duplication, and clearly separates reusable setup from action-specific logic.
+
##### API Reference
-`run()`
+`run(ctx: FlowContext)`
-This method needs to be implemented by the subclass. To call it, use `execute_run()`
+This method must be implemented by subclasses. It contains the logic that should run when the flow is triggered.
-`execute_run()`
+To execute a flow, register it with an event handler. When the event occurs, the framework calls `run()`.
-Executes the run method implemented in the subclass.
+The `FlowContext` object provides information about the event that triggered the flow and utilities for scheduling follow-up actions.
+
+**Attributes**
-`component_store`
+`identifier`
+ Identifier of the event that triggered the flow.
-The component store containing the instances of components and their states.
+`type`
+ Type of event that triggered the flow.
-Use `component_store.get_component(component_id: str)` to retrieve an instance of a component.
+**Methods**
+
+`schedule(function_to_schedule, args=None, kwargs=None)`
+Schedules a function to run **after the next rerun**.
+This is typically used when component state must be updated before executing additional logic.
+
+`schedule_and_rerun(function_to_schedule, args=None, kwargs=None)`
+Schedules a function to run **after the next rerun** and immediately triggers a rerun. Use schedule in combination with user interactions to avoid the `calling st.rerun() within a callback is a no-op` warning.
```python
-from components import YourXComponent
+class ApproveLabelsFlow(Flow):
+
+ def run(self, ctx: FlowContext):
+ # mark selected rows as approved
+ self.table.update_rows(approved=True)
-your_x_component_instance: YourXComponent = self.component_store.get_component("your_x_component_id")
+ # perform backend call after rerun
+ ctx.schedule_and_rerun(self.store_labels)
+
+ def store_labels(self):
+ self.backend.store(self.table.selected_rows)
```
-Notice the `: YourXComponent` this tells your IDE what type of component you are getting and helps the autocomplete.
+`dispatch()`
+
+Executes the run method implemented in the subclass.
#### Layouts
-Defining a new layout.
+Example of defining a new layout.
+
```python
from ststeroids import Layout
-class YourXLayout(Layout):
+class ManageDataLayout(Layout):
def __init__(self):
+ self.data_viewer = DataViewerComponent.create(
+ ComponentIDs.data_viewer, "Movies"
+ )
def render(self):
- # Your layout render logic
+ self.data_viewer.render()
```
-An instance of a layout can be rendered by calling either the `render()` function or by calling the instance of the layout.
-
-Calling the instance
-```python
-my_x_layout = YourXLayout()
-my_x_layout()
-```
-##### API Reference
+Layouts are responsible for creating and rendering components.
+They must not contain business logic, checks, or flow control.
- `render()`
+Component creation should always happen in the layout constructor using
+`Component.create(...)`.
-This method needs to be implemented by the subclass. To call it in the application, use `execute_render()`
+##### Rendering a layout
-`execute_render()`
-
-Executes the render method of an instance of a layout.
-
-#### Routers
-Intializing a router
+A layout instance can be rendered by calling its `render()` method.
```python
-from ststeroids import Router
-router = Router()
+my_x_layout = YourXLayout.create()
+my_x_layout.render()
```
-##### API Reference
-
-`run`
-
-Runs the currently active route
-
-`route(route_name: str)`
+Calling `render()` on a layout is restricted to the router.
-Changes the currently active to the given route name
+This ensures:
+- a single, predictable render entry point
+- consistent routing behavior
+- a clear separation of concerns
-`register_routes(routes: dict[str, Layout])`
+Layouts describe what is rendered.
+The router decides when it is rendered.
-Registers a dictionary of routes where keys are route names and values are layouts.
+##### API Reference
-`get_current_route`
+ `render()`
-Returns the currently active route. Useful for creating a navigation breadcrumbs.
+This method needs to be implemented by the subclass.
#### Store
@@ -251,7 +394,7 @@ A wrapper around `st.session_state` to separate states into stores.
Usage:
```python
-session_store = Store("yourstore")
+session_store = Store.create("yourstore")
```
##### API reference
@@ -287,9 +430,35 @@ app_style.apply_style()
### Release notes
+1.0.0
+
+Partially rewritten the framework to reduce its footprint and make object creation more intuitive. Editor and debugger support has been improved, making development smoother and more productive. The router system has also been greatly enhanced, now supporting conditional routes directly within the framework, giving you more control over navigation and layout rendering.
+
+**Note** this version is considered to be a breaking change. Make sure to adapt your code base so that it works with this new version. A small migration guide:
+
+- Update the `__init__` of your components to match the new style
+ - component_id is no longer needed
+ - the `super().__init__()` no longer needs to be called
+- accessing `.state` is no longer possible, you can directly access the attributes on a component instead
+- Rename `render` in your components to `display`
+- Remove any `show` and `hide` methods from your components as well as the `visible` property. They are now controlled by the framework
+- Use a component’s `on` method to register flows as event handlers for specific component events, typically when setting up the app. Call `trigger` on the component to emit events and dispatch the registered flows.
+- In Flows use `YourComponent.get(component_id)` instead of `self.component_store.get_component(component_id)
+- Remove `Router` from your Flows, use `st.switch_page` instead if you didn't already
+- Move the initialization of the sidebar to layouts instead of the `main` of the app
+- When rendering a component call `render` instead of `execute_render`
+- When creating instances of StSteroids classes use `create` instead of calling `ClassName()`. This does not apply to the `Style` class
+- The flow's `run` method can no longer take parameters. Access a components state instead to aquire the require parameters
+- When calling a flow, use `dispatch()` instead of `run()`
+- If you previously implemented your own logic for using the `router` class. Please consider using the new Steroids app style, by doing so you can also utilize
+ - The on app run once event, for initial set up
+ - The router on enter event, for initial route setup. For example refresh data before rendering the page
+- There are two new component types, `Fragment` and `Dialog`, they replace the `render_as` parameter. Please update your components and render calls accordingly
+- Added the `schedule` functionality that allows for scheduling a run after the first rerun.
+
0.1.17
-- Improved execute_render function by adding an error handler
+- Improved execute_render method by adding an error handler
- Default refresh_interval for a fragment is now `None` to avoid unintended refreshes
0.1.16
@@ -307,8 +476,8 @@ app_style.apply_style()
0.1.13
-- Adds a function to set a registered element's value.
-- Adds a function for rendering a component as a fragment.
+- Adds a method to set a registered element's value.
+- Adds a method for rendering a component as a fragment.
0.1.12
@@ -326,12 +495,4 @@ Beta releases
### Todo
-- Improve IDE/autocomplete for state managed variables
-- Ambition: directly link element values to component states
-- Describe component store
-- Layout and flow class singletons
-
-## Ideas
-
-- Something for RBAC
-- Something for running longtime requests
\ No newline at end of file
+* the default route can only be one of the registered routes
\ No newline at end of file
diff --git a/example/src/app.py b/example/src/app.py
index 208ed99..f0424c0 100644
--- a/example/src/app.py
+++ b/example/src/app.py
@@ -2,4 +2,4 @@
app = MainApp()
-app.run()
+app.app.run()
diff --git a/example/src/assets/style.css b/example/src/assets/style.css
index bdf13fa..6b79dbf 100644
--- a/example/src/assets/style.css
+++ b/example/src/assets/style.css
@@ -1,3 +1,47 @@
html, body, * {
font-style: italic
-}
\ No newline at end of file
+}
+
+.status-icon {
+ width: 20px;
+ height: 20px;
+ border-radius: 50%;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ font-size: 12px;
+ font-weight: bold;
+ color: white;
+}
+
+/* RUNNING (spinner) */
+.running {
+ border: 3px solid rgba(0,0,0,0.1);
+ border-left-color: #4CAF50;
+ animation: spin 1s linear infinite;
+}
+
+/* STATIC STATES */
+.success {
+ background-color: #4CAF50;
+}
+
+.error {
+ background-color: #F44336;
+}
+
+.info {
+ background-color: #2196F3;
+}
+
+/* Spin animation */
+@keyframes spin {
+ 0% { transform: rotate(0deg);}
+ 100% { transform: rotate(360deg);}
+}
+
+.status-row {
+ display: flex;
+ align-items: center;
+ gap: 10px;
+}
diff --git a/example/src/components/__init__.py b/example/src/components/__init__.py
index 28c4079..f7923f7 100644
--- a/example/src/components/__init__.py
+++ b/example/src/components/__init__.py
@@ -2,6 +2,16 @@
from .sidebar import SidebarComponent
from .data_viewer import DataViewerComponent
from .metric import MetricComponent
+from .toast import ToastComponent
+from .button import ButtonComponent
+from .status import StatusComponent
-
-__all__ = [LoginDialogComponent, SidebarComponent, DataViewerComponent, MetricComponent]
+__all__ = [
+ LoginDialogComponent,
+ SidebarComponent,
+ DataViewerComponent,
+ MetricComponent,
+ ToastComponent,
+ ButtonComponent,
+ StatusComponent
+]
diff --git a/example/src/components/button.py b/example/src/components/button.py
new file mode 100644
index 0000000..f385a30
--- /dev/null
+++ b/example/src/components/button.py
@@ -0,0 +1,22 @@
+import streamlit as st
+from ststeroids import Component, Flow
+
+
+class ButtonComponent(Component):
+
+ EVENT_ClICK = "click"
+
+ def __init__(self, button_text: str):
+ self.button_text = button_text
+
+ def _handle_click(self):
+ self.trigger(self.EVENT_ClICK)
+
+ def display(self):
+ st.button(self.button_text, on_click=self._handle_click)
+
+ def on_click(self, flow: Flow) -> None:
+ """
+ Register a flow to be executed when the user clicks the button.
+ """
+ self.on(self.EVENT_ClICK, flow)
diff --git a/example/src/components/data_viewer.py b/example/src/components/data_viewer.py
index b0f7cfe..06d122a 100644
--- a/example/src/components/data_viewer.py
+++ b/example/src/components/data_viewer.py
@@ -1,4 +1,3 @@
-import uuid
import streamlit as st
from ststeroids import Component
@@ -6,24 +5,23 @@
class DataViewerComponent(Component):
def __init__(
self,
- component_id: str,
header: str,
column_config: dict = {},
column_order: list = [],
):
- super().__init__(component_id, {"data": None, "dek": uuid.uuid4()})
self.header = header
self.column_config = column_config
self.column_order = column_order
+ self.data = None
- def render(self):
+ def display(self):
st.subheader(self.header)
st.dataframe(
- self.state.data,
+ self.data,
hide_index=True,
column_config=self.column_config,
column_order=self.column_order,
)
def set_data(self, data):
- self.state.data = data
+ self.data = data
diff --git a/example/src/components/login_dialog.py b/example/src/components/login_dialog.py
index 6f48ab0..4a28825 100644
--- a/example/src/components/login_dialog.py
+++ b/example/src/components/login_dialog.py
@@ -1,37 +1,31 @@
import streamlit as st
-from ststeroids import Component, Flow
+from ststeroids import Dialog, Flow
-class LoginDialogComponent(Component):
+class LoginDialogComponent(Dialog):
+
+ EVENT_LOGIN = "login"
+
def __init__(
self,
- component_id: str,
- login_flow: Flow,
- login_success_flow: Flow,
- header: str = "Enter username/password",
):
- super().__init__(component_id, {"visible": False})
- self.header = header
- self.login_flow = login_flow
- self.login_success_flow = login_success_flow
+ self.error_message = None
+ self.hide()
- def render(self):
- if self.state.visible:
- username = st.text_input("Username")
- password = st.text_input("Password", type="password")
- if st.button("Login", use_container_width=True):
- login_succes = self.login_flow.execute_run(username, password)
- if login_succes:
- self.login_success_flow.execute_run()
- else:
- st.error("Login failed, please check your username and password.")
+ def display(self):
+ self.username = st.text_input("Username")
+ self.password = st.text_input("Password", type="password")
+ if st.button("Login", use_container_width=True):
+ self.trigger(self.EVENT_LOGIN)
+ if self.error_message:
+ st.error(self.error_message)
+ self.error_message = None
- def show(self):
- if self.state.visible is False:
- self.state.visible = True
- st.rerun()
+ def on_login(self, flow: Flow) -> None:
+ """
+ Register a flow to be executed when the user clicks the login button.
+ """
+ self.on(self.EVENT_LOGIN, flow)
- def hide(self):
- if self.state.visible is True:
- self.state.visible = False
- st.rerun()
+ def set_error(self, message: str):
+ self.error_message = message
diff --git a/example/src/components/metric.py b/example/src/components/metric.py
index 87b992f..0dcf5b0 100644
--- a/example/src/components/metric.py
+++ b/example/src/components/metric.py
@@ -1,18 +1,17 @@
import streamlit as st
-from ststeroids import Component
+from ststeroids import Fragment
-class MetricComponent(Component):
+class MetricComponent(Fragment):
def __init__(
self,
- component_id: str,
header: str,
):
- super().__init__(component_id, {"value": None})
self.header = header
+ self.value = 0
- def render(self):
- st.metric(self.header, self.state.value)
+ def display(self):
+ st.metric(self.header, self.value)
def set_value(self, value: int):
- self.state.value = value
+ self.value = value
diff --git a/example/src/components/sidebar.py b/example/src/components/sidebar.py
index 51b0291..0a186a9 100644
--- a/example/src/components/sidebar.py
+++ b/example/src/components/sidebar.py
@@ -1,15 +1,14 @@
import streamlit as st
-from ststeroids import Component, Router
+from ststeroids import Component
class SidebarComponent(Component):
- def __init__(self, component_id: str, router: Router):
- super().__init__(component_id)
- self.router = router
-
- def render(self):
+ def display(self):
with st.sidebar:
- st.page_link("pages/dashboard.py", icon=":material/search:", label="Dashboard")
- st.page_link("pages/manage.py", icon=":material/bar_chart:", label="Manage data")
-
+ st.page_link(
+ "pages/dashboard.py", icon=":material/search:", label="Dashboard"
+ )
+ st.page_link(
+ "pages/manage.py", icon=":material/bar_chart:", label="Manage data"
+ )
diff --git a/example/src/components/status.py b/example/src/components/status.py
new file mode 100644
index 0000000..2e32703
--- /dev/null
+++ b/example/src/components/status.py
@@ -0,0 +1,32 @@
+from typing import Literal
+
+import streamlit as st
+from ststeroids import Component
+
+
+class StatusComponent(Component):
+
+ def __init__(
+ self,
+ message: str = None,
+ type: Literal["running", "info", "error", "success"] = "info",
+ ):
+ self.message = message
+ self.type = type
+
+ def display(self):
+ if self.message:
+ st.markdown(f"
{self._status_icon(self.type)}
{self.message}
", unsafe_allow_html=True)
+
+ def set_status(self, message: str, type: Literal["running", "info", "error", "success"] = "info"):
+ self.message = message
+ self.type = type
+
+ def clear(self):
+ self.message = None
+ self.type = None
+
+ def _status_icon(self, state: str):
+ icons = {"success": "✓", "error": "✕", "info": "i", "running": ""}
+
+ return f"{icons.get(state, '')}
"
diff --git a/example/src/components/toast.py b/example/src/components/toast.py
new file mode 100644
index 0000000..1910d85
--- /dev/null
+++ b/example/src/components/toast.py
@@ -0,0 +1,18 @@
+import streamlit as st
+from ststeroids import Component
+
+
+class ToastComponent(Component):
+ def __init__(
+ self,
+ ):
+ self.message = None
+ self.hide()
+
+ def display(self):
+ st.toast(self.message)
+ self.hide()
+
+ def set_message(self, message: str):
+ self.message = message
+ self.show()
diff --git a/example/src/flows/__init__.py b/example/src/flows/__init__.py
index e2b5e51..4b57a44 100644
--- a/example/src/flows/__init__.py
+++ b/example/src/flows/__init__.py
@@ -1,5 +1,7 @@
from .login import LoginFlow
-from .login_succes import LoginSuccessFlow
from .refresh import RefreshFlow
+from .app_setup import SetupFlow
+from .logout import LogoutFlow
+from .long_running import LongRunningFlow
-__all__ = [LoginFlow, LoginSuccessFlow, RefreshFlow]
+__all__ = [LoginFlow, RefreshFlow, SetupFlow, LogoutFlow, LongRunningFlow]
diff --git a/example/src/flows/app_setup.py b/example/src/flows/app_setup.py
new file mode 100644
index 0000000..adc440e
--- /dev/null
+++ b/example/src/flows/app_setup.py
@@ -0,0 +1,6 @@
+from ststeroids import Flow, FlowContext
+
+
+class SetupFlow(Flow):
+ def run(self, _ctx: FlowContext):
+ print("I'm a flow setting up the app per user")
diff --git a/example/src/flows/login.py b/example/src/flows/login.py
index 85af513..599a693 100644
--- a/example/src/flows/login.py
+++ b/example/src/flows/login.py
@@ -1,17 +1,62 @@
from service import MockBackendService
-from ststeroids import Flow, Store
+from ststeroids import Flow, Store, FlowContext
+from components import (
+ LoginDialogComponent,
+ DataViewerComponent,
+ MetricComponent,
+ ToastComponent,
+)
+from shared import ComponentIDs
+import streamlit as st
class LoginFlow(Flow):
def __init__(self, session_store: Store, backend_service: MockBackendService):
- super().__init__()
self.session_store = session_store
self.backend_service = backend_service
- def run(self, username: str, password: str):
- response = self.backend_service.authenticate(username, password)
+ @property
+ def cp_login_dialog(self):
+ return LoginDialogComponent.get(ComponentIDs.dialog_login)
+
+ @property
+ def cp_data_viewer(self):
+ return DataViewerComponent.get(ComponentIDs.data_viewer)
+
+ @property
+ def cp_total_movies(self):
+ return MetricComponent.get(ComponentIDs.total_movies)
+
+ @property
+ def cp_toast(self):
+ return ToastComponent.get(ComponentIDs.toast)
+
+ def run(self, _ctx: FlowContext):
+ response = self.backend_service.authenticate(
+ self.cp_login_dialog.username, self.cp_login_dialog.password
+ )
if response.ok:
- token_data = response.json()
- self.session_store.set_property("access_token", token_data["access_token"])
- return True
- return False
+ self._login_success(response)
+ else:
+ self._login_failed()
+
+ def _login_success(self, response):
+ token_data = response.json()
+ self.session_store.set_property("access_token", token_data["access_token"])
+ self.cp_login_dialog.hide()
+ response = self.backend_service.get_movies()
+ # enable the line below for example of an error scenario
+ # response.ok = False
+ if response.ok:
+ data = response.json()
+ self.session_store.set_property(
+ "data", data
+ ) # Store the data in the session_store for later use in more complex applications
+ self.cp_total_movies.set_value(len(data))
+ self.cp_data_viewer.set_data(data)
+ else:
+ self.cp_toast.set_message("error")
+ st.switch_page("pages/dashboard.py")
+
+ def _login_failed(self):
+ self.cp_login_dialog.set_error("Login failed, check your username and password")
diff --git a/example/src/flows/login_succes.py b/example/src/flows/login_succes.py
deleted file mode 100644
index c192528..0000000
--- a/example/src/flows/login_succes.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from service import MockBackendService
-from shared import ComponentIDs
-from ststeroids import Flow, Router, Store
-from components import LoginDialogComponent, DataViewerComponent, MetricComponent
-
-
-class LoginSuccessFlow(Flow):
- def __init__(
- self, router: Router, session_store: Store, backend_service: MockBackendService
- ):
- super().__init__()
- self.session_store = session_store
- self.backend_service = backend_service
- self.router = router
-
- def run(self):
- cp_login_dialog: LoginDialogComponent = self.component_store.get_component(
- ComponentIDs.dialog_login
- )
- cp_data_viewer: DataViewerComponent = self.component_store.get_component(
- ComponentIDs.data_viewer
- )
- cp_total_movies: MetricComponent = self.component_store.get_component(
- ComponentIDs.total_movies
- )
- response = self.backend_service.get_movies()
- if response.ok:
- data = response.json()
- self.session_store.set_property(
- "data", data
- ) # Store the data in the session_store for later use in more complex applications
- cp_total_movies.set_value(len(data))
- cp_data_viewer.set_data(data)
- self.router.route("dashboard")
- cp_login_dialog.hide()
diff --git a/example/src/flows/logout.py b/example/src/flows/logout.py
new file mode 100644
index 0000000..4e8964a
--- /dev/null
+++ b/example/src/flows/logout.py
@@ -0,0 +1,9 @@
+from ststeroids import Flow, Store, FlowContext
+
+
+class LogoutFlow(Flow):
+ def __init__(self, session_store: Store):
+ self.session_store = session_store
+
+ def run(self, _ctx: FlowContext):
+ self.session_store.del_property("access_token")
diff --git a/example/src/flows/long_running.py b/example/src/flows/long_running.py
new file mode 100644
index 0000000..f409ab1
--- /dev/null
+++ b/example/src/flows/long_running.py
@@ -0,0 +1,18 @@
+from ststeroids import Flow, FlowContext
+import time
+from components import StatusComponent
+from shared import ComponentIDs
+
+class LongRunningFlow(Flow):
+ @property
+ def cp_spinner(self):
+ return StatusComponent.get(ComponentIDs.spinner)
+
+ def run(self, ctx: FlowContext):
+ self.cp_spinner.set_status("Long running call", "running")
+ ctx.schedule(self._long_running_method)
+
+
+ def _long_running_method(self):
+ time.sleep(5)
+ self.cp_spinner.clear()
diff --git a/example/src/flows/refresh.py b/example/src/flows/refresh.py
index 4c93e95..23f3e32 100644
--- a/example/src/flows/refresh.py
+++ b/example/src/flows/refresh.py
@@ -1,6 +1,6 @@
from service import MockBackendService
from shared import ComponentIDs
-from ststeroids import Flow, Store
+from ststeroids import Flow, Store, FlowContext
from components import MetricComponent
@@ -10,23 +10,20 @@ def __init__(
session_store: Store,
backend_service: MockBackendService,
):
- super().__init__()
self.session_store = session_store
self.backend_service = backend_service
- def run(self):
- cp_avg_rating: MetricComponent = self.component_store.get_component(
- ComponentIDs.avg_rating
- )
+ def run(self, _ctx: FlowContext):
+ cp_avg_rating = MetricComponent.get(ComponentIDs.avg_rating)
response = self.backend_service.get_movies()
if response.ok:
data = response.json()
self.session_store.set_property(
"data", data
) # Store the data in the session_store for later use in more complex applications
- avg_rating = self.avg_rating(data, "rating")
+ avg_rating = self._avg_rating(data, "rating")
cp_avg_rating.set_value(avg_rating)
- def avg_rating(self, data, key):
+ def _avg_rating(self, data, key):
values = [d[key] for d in data if key in d and isinstance(d[key], (int, float))]
return round(sum(values) / len(values)) if values else 0
diff --git a/example/src/layouts/dashboard.py b/example/src/layouts/dashboard.py
index 0991354..4cff1c2 100644
--- a/example/src/layouts/dashboard.py
+++ b/example/src/layouts/dashboard.py
@@ -1,5 +1,11 @@
import streamlit as st
-from components import MetricComponent
+from components import (
+ MetricComponent,
+ SidebarComponent,
+ ToastComponent,
+ ButtonComponent,
+ StatusComponent
+)
from shared import ComponentIDs
from ststeroids import Layout, Flow
@@ -7,15 +13,27 @@
class DashboardLayout(Layout):
def __init__(self, refresh_flow: Flow):
self.refresh_flow = refresh_flow
- self.total_movies = MetricComponent(ComponentIDs.total_movies, "Total movies")
- self.avg_rating = MetricComponent(ComponentIDs.avg_rating, "Avg. Rating")
+ self.sidebar = SidebarComponent.create(ComponentIDs.sidebar)
+ self.status = StatusComponent.create(ComponentIDs.spinner)
+ self.toast = ToastComponent.create(ComponentIDs.toast)
+ self.total_movies = MetricComponent.create(
+ ComponentIDs.total_movies, None, "Total movies"
+ )
+ self.avg_rating = MetricComponent.create(
+ ComponentIDs.avg_rating, "2s", "Avg. Rating"
+ )
+ self.logout_button = ButtonComponent.create(ComponentIDs.logout, "Logout")
+ self.long_running_button = ButtonComponent.create(ComponentIDs.long_running, "Long running call")
def render(self):
+ self.sidebar.render()
+ self.toast.render()
left, right = st.columns([1, 1])
with left:
- self.total_movies.execute_render()
+ self.total_movies.render()
with right:
- self.avg_rating.execute_render(
- "fragment",
- {"refresh_flow": self.refresh_flow, "refresh_interval": "2s"},
- )
+ self.avg_rating.render()
+ self.status.render()
+ st.divider()
+ self.long_running_button.render()
+ self.logout_button.render()
diff --git a/example/src/layouts/login.py b/example/src/layouts/login.py
index 8f61ffa..02dea24 100644
--- a/example/src/layouts/login.py
+++ b/example/src/layouts/login.py
@@ -1,22 +1,25 @@
import streamlit as st
-from components import LoginDialogComponent
+from components import LoginDialogComponent, SidebarComponent
from shared import ComponentIDs
-from ststeroids import Flow, Layout
+from ststeroids import Layout, Store
class LoginLayout(Layout):
def __init__(
self,
+ session_store: Store,
login_header: str,
- login_flow: Flow,
- login_success_flow: Flow,
):
+ self.session_store = session_store
self.login_header = login_header
- self.login_dialog = LoginDialogComponent(
- ComponentIDs.dialog_login, login_flow, login_success_flow
+ self.sidebar = SidebarComponent.create(ComponentIDs.sidebar)
+ self.login_dialog = LoginDialogComponent.create(
+ ComponentIDs.dialog_login, self.login_header
)
def render(self):
- self.login_dialog.execute_render("dialog", {"title": self.login_header})
+ self.sidebar.render()
+ if not self.session_store.has_property("access_token"):
+ self.login_dialog.show()
+ self.login_dialog.render()
st.write("Not logged in. Please refresh or use the menu on the left.")
- self.login_dialog.show()
diff --git a/example/src/layouts/manage_data.py b/example/src/layouts/manage_data.py
index b62dbd5..06448c8 100644
--- a/example/src/layouts/manage_data.py
+++ b/example/src/layouts/manage_data.py
@@ -1,11 +1,15 @@
-from components import DataViewerComponent
+from components import DataViewerComponent, SidebarComponent
from shared import ComponentIDs
from ststeroids import Layout
class ManageDataLayout(Layout):
def __init__(self):
- self.data_viewer = DataViewerComponent(ComponentIDs.data_viewer, "Movies")
+ self.sidebar = SidebarComponent.create(ComponentIDs.sidebar)
+ self.data_viewer = DataViewerComponent.create(
+ ComponentIDs.data_viewer, "Movies"
+ )
def render(self):
+ self.sidebar.render()
self.data_viewer.render()
diff --git a/example/src/main.py b/example/src/main.py
index 7e6efa7..d3edcd7 100644
--- a/example/src/main.py
+++ b/example/src/main.py
@@ -1,51 +1,43 @@
-from collections import defaultdict
import streamlit as st
-from components import SidebarComponent
-from flows import LoginFlow, LoginSuccessFlow, RefreshFlow
+from flows import LoginFlow, RefreshFlow, SetupFlow, LogoutFlow, LongRunningFlow
from layouts import LoginLayout, DashboardLayout, ManageDataLayout
from service import MockBackendService
-from ststeroids import Router, Store, Style
+from ststeroids import Store, Style, StSteroids
+
class MainApp:
def __init__(self):
- self.session_store = Store("store")
- self.router = Router("login")
+ self.session_store = Store.create("store")
self.backend_service = MockBackendService("./example/test_data.json")
- self.login_flow = LoginFlow(self.session_store, self.backend_service)
- self.login_success_flow = LoginSuccessFlow(self.router, self.session_store, self.backend_service)
- self.refresh_flow = RefreshFlow(self.session_store, self.backend_service)
+ self.setup_flow = SetupFlow.create()
+ self.login_flow = LoginFlow.create(self.session_store, self.backend_service)
+ self.logout_flow = LogoutFlow.create(self.session_store)
+ self.refresh_flow = RefreshFlow.create(self.session_store, self.backend_service)
+ self.long_running_flow = LongRunningFlow.create()
st.set_page_config(page_title="StSteroids Example app", layout="wide")
app_style = Style("./example/src/assets/style.css")
app_style.apply_style()
- self.login_layout = LoginLayout("App login", self.login_flow, self.login_success_flow)
- self.dashboard_layout = DashboardLayout(self.refresh_flow)
- self.manage_data_layout = ManageDataLayout()
-
- self.sidebar = SidebarComponent("sidebar", self.router)
-
- def run(self, entry_route:str = None):
- self.sidebar.render()
-
- def get_routes():
- routes = defaultdict(lambda: self.login_layout)
- routes["login"] = self.login_layout
-
- if self.session_store.has_property("access_token"):
- routes.update(
- {
- "dashboard": self.dashboard_layout,
- "manage_data": self.manage_data_layout,
- },
- )
-
- return routes
-
- self.router.register_routes(get_routes())
- if entry_route:
- self.router.route(entry_route)
- self.router.run()
+ self.login_layout = LoginLayout.create(self.session_store, "App login")
+ self.dashboard_layout = DashboardLayout.create(self.refresh_flow)
+ self.manage_data_layout = ManageDataLayout.create()
+
+ # register event handlers
+ self.login_layout.login_dialog.on_login(self.login_flow)
+ self.dashboard_layout.long_running_button.on_click(self.long_running_flow)
+ self.dashboard_layout.logout_button.on_click(self.logout_flow)
+ self.dashboard_layout.avg_rating.on_refresh(self.refresh_flow)
+
+ self.app = StSteroids()
+
+ self.app.on_app_run_once(self.setup_flow)
+
+ self.app.default_route(self.login_layout)
+
+ self.app.route("login").to(self.login_layout).register()
+ self.app.route("dashboard").to(self.dashboard_layout).when(lambda: self.session_store.has_property("access_token")).register()
+ self.app.route("manage_data").to(self.manage_data_layout).when(lambda: self.session_store.has_property("access_token")).register()
diff --git a/example/src/pages/dashboard.py b/example/src/pages/dashboard.py
index 6d68026..ac7a751 100644
--- a/example/src/pages/dashboard.py
+++ b/example/src/pages/dashboard.py
@@ -2,4 +2,4 @@
app = MainApp()
-app.run("dashboard")
+app.app.run("dashboard")
diff --git a/example/src/pages/manage.py b/example/src/pages/manage.py
index 7ddadba..858c941 100644
--- a/example/src/pages/manage.py
+++ b/example/src/pages/manage.py
@@ -2,4 +2,4 @@
app = MainApp()
-app.run("manage_data")
+app.app.run("manage_data")
diff --git a/example/src/shared.py b/example/src/shared.py
index 3340653..7e202fe 100644
--- a/example/src/shared.py
+++ b/example/src/shared.py
@@ -4,3 +4,7 @@ class ComponentIDs:
data_viewer = "data_view"
total_movies = "total_movies"
avg_rating = "avg_rating"
+ toast = "toast"
+ logout = "logout"
+ spinner = "spinner"
+ long_running = "long_running"
diff --git a/pyproject.toml b/pyproject.toml
index 502c3e5..f8051f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
[project]
name = "ststeroids"
-version = "0.1.17"
+version = "1.0.0"
description = "A framework supercharging Streamlit for building advanced multi-page applications"
readme = "README.md"
authors = [{ name = "ponsoc"}]
-license = { file = "LICENSE" }
+license = "MIT"
requires-python = ">=3.11.0"
dependencies = [
"streamlit>=1.41.0"
diff --git a/src/ststeroids/__init__.py b/src/ststeroids/__init__.py
index 10e7a93..edc8193 100644
--- a/src/ststeroids/__init__.py
+++ b/src/ststeroids/__init__.py
@@ -3,6 +3,19 @@
from .style import Style
from .store import Store
from .layout import Layout
-from .router import Router
+from .main import StSteroids
+from .flow_context import FlowContext
+from .fragment import Fragment
+from .dialog import Dialog
-__all__ = ["Component", "Layout", "Flow", "Style", "Store", "Router"]
+__all__ = [
+ "Component",
+ "Layout",
+ "Flow",
+ "Style",
+ "Store",
+ "StSteroids",
+ "FlowContext",
+ "Fragment",
+ "Dialog",
+]
diff --git a/src/ststeroids/component.py b/src/ststeroids/component.py
index dda12d5..2a4ce23 100644
--- a/src/ststeroids/component.py
+++ b/src/ststeroids/component.py
@@ -1,216 +1,139 @@
-from typing import Any, Literal
+from typing import Any
+from abc import ABC, abstractmethod
import streamlit as st
from .store import ComponentStore
from .flow import Flow
-from functools import wraps
+from .flow_context import FlowContext
# pylint: disable=too-few-public-methods
-class Component:
+class Component(ABC):
"""
- Base class for a component that interacts with the state and the store.
+ Base class for a component that interacts with the the store.
Attributes:
id (str): The unique identifier for the component.
- state (State): The state associated with the component.
+ visible (bool) Controls if the component is visible or not.
"""
- def __new__(cls, *args, **kwargs):
- """Creates an new instance of the component or returns it from the session."""
- component_id = kwargs.get("component_id") or (args[0] if args else None)
- if component_id is None:
- raise KeyError("component_id is required")
+ id: str
+ _events: dict[str, Flow]
- cls.__store = ComponentStore()
- component_instance_exists = cls.__store.has_property(component_id)
- if component_instance_exists:
- return cls.__store.get_component(component_id)
- return super().__new__(cls)
+ @classmethod
+ def create(cls, component_id: str, *args, **kwargs):
+ """
+ Create a new component instance or return it from the store.
+
+ :param component_id: A unique identifier for the instance of the component
- def __init_subclass__(cls, **kwargs):
- """Wrap subclass __init__ so it only runs once."""
- super().__init_subclass__(**kwargs)
- orig_init = cls.__init__
+ """
+ cls._store = ComponentStore.create("components")
- @wraps(orig_init)
- def wrapped_init(self, *args, **kwargs):
- if getattr(self, "_sub_initialized", False):
- return
- orig_init(self, *args, **kwargs)
- self._sub_initialized = True
+ if cls._store.has_property(component_id):
+ return cls._store.get_component(component_id)
+ try:
+ instance = cls(*args, **kwargs)
+ instance.id = component_id
+ if not hasattr(instance, "visible"):
+ instance.visible = True
+ instance._events = {}
- cls.__init__ = wrapped_init
+ cls._store.init_component(instance)
+ return instance
+ except TypeError as e:
+ raise TypeError(
+ f"{str(e)}. This usually happens when you are trying to get a component without creating it first."
+ )
- def __init__(self, component_id: str, initial_state: dict = None):
+ @classmethod
+ def get(cls, component_id: str):
"""
- Initializes the component with a unique ID and initial state.
+ Alias for create() — note that create has to be called first.
- :param component_id: The unique identifier for the component.
- :param initial_state: Initial state for the component. Defaults to an empty dictionary.
+ :param component_id: The unique identifier for the instance of the component to return.
"""
- self.id = component_id
- self.state = State(
- self.id, self.__store, initial_state if initial_state else {}
- )
- self.__store.init_component(self)
- def register_element(self, element_name: str):
+ return cls.create(component_id)
+
+ def register_element(self, element_name: str) -> str:
"""
Generates a unique key for an element based on the instance ID.
- Args:
- element_name (str): The name of the element to register.
+ param: element_name: The name of the element to register.
- Returns:
- str: A unique key for the element.
+ return: A unique key for the element.
"""
key = f"{self.id}_{element_name}"
return key
- def get_element(self, element_name: str):
+ def get_element(self, element_name: str) -> Any:
"""
Retrieves the value of a registered element from the session state.
- Args:
- element_name (str): The name of the element to retrieve.
-
- Returns:
- Any: The value of the element if it exists in the session state, otherwise None.
+ param: element_name: The name of the element to retrieve.
+ return: The value of the element if it exists in the session state, otherwise None.
"""
key = f"{self.id}_{element_name}"
if key not in st.session_state:
return None
return st.session_state[key]
- def set_element(self, element_name: str, element_value):
+ def set_element(self, element_name: str, element_value) -> None:
"""
Sets the value of a registered element in the session state.
- Args:
- element_name (str): The name of the element to set.
- element_value (Any): The value to assign to the element.
-
- Returns:
- None
+ param: element_name: The name of the element to set.
+ param: element_value: The value to assign to the element.
+ return: None
"""
key = f"{self.id}_{element_name}"
st.session_state[key] = element_value
- def _render_dialog(self, title: str):
+ def on(self, event_name: str, callback: Flow) -> None:
"""
- Internal method for rendering the component as a dialog.
-
- This wraps the component's core render logic in a Streamlit dialog with the given title.
+ Register a Flow callback for a named event on this component.
- :param title: The title to display at the top of the dialog.
+ :param event_name: The unique name of the event to bind the callback to.
+ Should ideally be a class-level constant to enable autocomplete.
+ :param callback: The Flow instance to execute when this event is triggered.
+ :return: None
"""
+ self._events[event_name] = callback
- @st.dialog(title)
- def _render():
- self.render()
-
- _render()
-
- def _render_fragment(self, refresh_interval: str = None, refresh_flow: Flow = None):
+ def trigger(self, event_name: str) -> None:
"""
- Internal method for rendering the component as a fragment.
+ Trigger a previously registered event callback.
- This sets up a Streamlit fragment that automatically re-runs at the given interval.
- It internally calls the __render_fragment method.
-
- This method is not meant to be overridden. Subclasses should implement the render()
- method to define the rendering behavior.
-
- :param refresh_interval: The interval at which the fragment should refresh (e.g., "5s").
- :param refresh_flow: Optional flow object to pass into the rendering logic.
+ :param event_name: The name of the event to trigger.
+ :return: None
"""
+ callback = self._events.get(event_name, None)
+ if callback:
+ callback.dispatch(FlowContext("component", self.id))
- @st.fragment(run_every=refresh_interval)
- def _render():
- self.__render_fragment(refresh_flow)
-
- _render()
-
- def __render_fragment(self, refresh_flow: Flow = None):
- self.render()
- if refresh_flow:
- refresh_flow.execute_run()
-
- def execute_render(
+ def render(
self,
- render_as: Literal["normal", "dialog", "fragment"] = "normal",
- options: dict = {},
- ):
- """
- Executes the render method implemented in the subclasses, additionaly providing extra configuration based on the `render_as` parameter
- """
- match render_as:
- case "normal":
- return self.render()
- case "dialog":
- return self._render_dialog(**options)
- case "fragment":
- return self._render_fragment(**options)
- raise ValueError(f"Unexpected render_as value: {render_as}")
-
- def render(self) -> None:
+ ) -> None:
"""
- Placeholder method for rendering the component.
-
- This method should be implemented by subclasses to define how the component is rendered.
-
- :raises NotImplementedError: If called directly without being implemented in a subclass.
+ Executes the render method implemented in the subclasses.
"""
- raise NotImplementedError("Subclasses should implement this method.")
+ if not self.visible:
+ return
+ self.display()
-class State:
- """
- Manages the state of a component, storing and retrieving properties
- through the associated store.
+ def show(self):
+ self.visible = True
- Attributes:
- __id (str): The unique identifier for the component.
- __store (ComponentStore): The store instance that holds the component's state.
- """
-
- def __init__(self, component_id: str, store: ComponentStore, initial_state: dict):
- """
- Initializes the state for a component, setting up the store and component ID.
+ def hide(self):
+ self.visible = False
- :param component_id: The unique identifier for the component.
- :param store: The store instance where the state is stored.
- :param initial_state: Initial state data for the component.
+ @abstractmethod
+ def display(self) -> None:
"""
- super().__setattr__(
- "_State__id", component_id
- ) # Directly set private attributes
- super().__setattr__("_State__store", store) # Avoid recursion
- store.init_component_state(component_id, initial_state)
+ Abstract method for displaying the component.
- def __getattr__(self, name) -> Any:
- """
- Retrieves a property of the component from the store.
-
- :param name: The name of the property to retrieve.
- :return: The value of the property from the store.
-
- :raises AttributeError: If the requested property is not found.
- """
- if not name.startswith("__"):
- return self.__store.get_property(self.__id, name)
-
- def __setattr__(self, name, value):
- """
- Sets a property of the component in the store.
-
- :param name: The name of the property to set.
- :param value: The value to set for the property.
-
- This method avoids recursion for special attributes and handles normal properties.
+ This method should be implemented by subclasses to define how the component is rendered.
"""
- if not name.startswith("__"):
- self.__store.set_property(self.__id, name, value)
- else:
- super().__setattr__(name, value) # Avoid recursion for special attributes
+ pass
diff --git a/src/ststeroids/dialog.py b/src/ststeroids/dialog.py
new file mode 100644
index 0000000..e6fecab
--- /dev/null
+++ b/src/ststeroids/dialog.py
@@ -0,0 +1,45 @@
+import streamlit as st
+from .component import Component
+
+
+class Dialog(Component):
+ """
+ Base class for dialog components.
+
+ Dialog components wrap their content inside a Streamlit dialog and provide
+ a dedicated method for rendering as a dialog.
+
+ Attributes:
+ title (str): The title of the dialog.
+ """
+
+ title: str
+
+ @classmethod
+ def create(cls, component_id: str, title: str = "title", *args, **kwargs):
+ """
+ Create a new Dialog instance or return it from the store.
+
+ :param component_id: Unique identifier for this dialog component.
+ :param title: Dialog title.
+ """
+ instance = super().create(component_id, *args, **kwargs)
+ if not hasattr(instance, "title"):
+ instance.title = title
+ return instance
+
+ def render(self) -> None:
+ """
+ Renders the component inside a Streamlit dialog context.
+ Calls the `display` method to render the contents.
+
+ :return: None
+ """
+ if not self.visible:
+ return
+
+ @st.dialog(self.title)
+ def _dialog():
+ self.display()
+
+ _dialog()
diff --git a/src/ststeroids/flow.py b/src/ststeroids/flow.py
index 29d3e05..5716900 100644
--- a/src/ststeroids/flow.py
+++ b/src/ststeroids/flow.py
@@ -1,33 +1,41 @@
-from .store import ComponentStore
+from abc import ABC, abstractmethod
+from .flow_context import FlowContext
# pylint: disable=too-few-public-methods
-class Flow:
+class Flow(ABC):
"""
- Base class for a flow that can interact with the component store
+ Base class for a flow
"""
- def __init__(self):
+ @classmethod
+ def create(cls, *args, **kwargs):
"""
- Initializes the Flow class and creates a ComponentStore instance.
+ Creates a new flow instance.
"""
- self.component_store = ComponentStore()
+ return cls(*args, **kwargs)
- def execute_run(self, *args, **kwargs):
+ def dispatch(self, ctx: FlowContext) -> None:
"""
- Executes the run method implemented in the subclasses.
+ Dispatches the flow execution.
+
+ This method triggers the flow and forwards the context of the
+ source that caused the execution.
+
+ :param ctx: The `context` provides contextual information about what triggered the flow.
+ :return: None
"""
- return self.run(*args, **kwargs)
+ self.run(ctx)
- def run(self, *args, **kwargs):
+ @abstractmethod
+ def run(self, ctx: FlowContext) -> None:
"""
Executes the flow logic.
- Each derived class should implement its own `run` method.
+ This method must be implemented by subclasses and contains the
+ orchestration and business logic for the flow.
- :param args: Positional arguments for the run method.
- :param kwargs: Keyword arguments for the run method.
+ :param ctx: The `context` provides contextual information about what triggered the flow. Can be useful when you want to reuse a flow for different instances of the same component.
:return: None
- :raises NotImplementedError: If the method is not implemented in a subclass.
"""
- raise NotImplementedError("Subclasses must implement the run method.")
+ pass
diff --git a/src/ststeroids/flow_context.py b/src/ststeroids/flow_context.py
new file mode 100644
index 0000000..fa78ecd
--- /dev/null
+++ b/src/ststeroids/flow_context.py
@@ -0,0 +1,30 @@
+import streamlit as st
+
+class FlowContext:
+ """
+ Encapsulates the context of why a flow is being executed.
+
+ Attributes:
+ type: The type of the trigger ("component", "route", "app").
+ identifier: Optional identifier, e.g., component id, route name.
+ """
+
+ type: str
+ identifier: str = None
+
+ def __init__(self, type: str, identifier: str):
+ self.type = type
+ self.identifier = identifier
+
+ def schedule_and_rerun(self, fn, *args, **kwargs):
+ self.schedule(fn, *args, **kwargs)
+ st.rerun()
+
+ def schedule(self, fn, *args, **kwargs):
+ st.session_state["_schedule_rerun"] = {
+ "fn": fn,
+ "args": args,
+ "kwargs": kwargs,
+ "type": self.type,
+ "identifier": self.identifier,
+ }
\ No newline at end of file
diff --git a/src/ststeroids/fragment.py b/src/ststeroids/fragment.py
new file mode 100644
index 0000000..acd9fd2
--- /dev/null
+++ b/src/ststeroids/fragment.py
@@ -0,0 +1,58 @@
+from .component import Component
+from .flow import Flow
+import streamlit as st
+
+
+class Fragment(Component):
+ """
+ Base class for components that render as Streamlit fragments and provide
+ a built-in `refresh` event for decoupled flows.
+
+ Attributes:
+ refresh_interval (str|None): The refresh interval
+ """
+
+ refresh_interval: str | None
+ EVENT_REFRESH = "_refresh" # class-level constant for autocomplete
+
+ @classmethod
+ def create(
+ cls, component_id: str, refresh_interval: str | None = None, *args, **kwargs
+ ):
+ """
+ Create a new Fragment instance or return it from the store.
+
+ :param component_id: Unique identifier for this dialog component.
+ :param refresh_interval: The interval for the on_refresh event.
+ """
+ instance = super().create(component_id, *args, **kwargs)
+ if not hasattr(instance, "refresh_interval"):
+ instance.refresh_interval = refresh_interval
+ return instance
+
+ def render(self) -> None:
+ """
+ Render the component as a fragment and trigger the `on_refresh` event
+ on each rerun/refresh.
+
+ :return: None
+ """
+ if not self.visible:
+ return
+
+ @st.fragment(run_every=self.refresh_interval)
+ def _fragment():
+ if self.EVENT_REFRESH in self._events:
+ self.trigger(self.EVENT_REFRESH)
+ self.display()
+
+ _fragment()
+
+ def on_refresh(self, callback: Flow) -> None:
+ """
+ Register a flow to be executed when the the fragment refreshes
+
+ :param: callback the flow to dispatch on the refresh event.
+ :return: None
+ """
+ self.on(self.EVENT_REFRESH, callback)
diff --git a/src/ststeroids/layout.py b/src/ststeroids/layout.py
index fb8d380..afac9b1 100644
--- a/src/ststeroids/layout.py
+++ b/src/ststeroids/layout.py
@@ -1,23 +1,23 @@
-class Layout:
+from abc import ABC, abstractmethod
+
+
+class Layout(ABC):
"""
Base class for a layout
"""
- def __call__(self):
- self.render()
-
- def execute_render(self):
+ @classmethod
+ def create(cls, *args, **kwargs):
"""
- Executes the render method implemented in the subclasses.
+ Creates a new layout instance.
"""
- self.render()
+ return cls(*args, **kwargs)
+ @abstractmethod
def render(self) -> None:
"""
- Placeholder method for rendering the layout.
+ Abstract method for rendering the layout.
This method should be implemented by subclasses to define how the layout is rendered.
-
- :raises NotImplementedError: If called directly without being implemented in a subclass.
"""
- raise NotImplementedError("Subclasses should implement this method.")
+ pass
diff --git a/src/ststeroids/main.py b/src/ststeroids/main.py
new file mode 100644
index 0000000..3e24596
--- /dev/null
+++ b/src/ststeroids/main.py
@@ -0,0 +1,115 @@
+from .route import Route
+from .route_builder import RouteBuilder
+from .router import Router
+from .layout import Layout
+from .flow import Flow
+import streamlit as st
+from .flow_context import FlowContext
+
+
+class StSteroids:
+ """
+ The main application class for managing routes and navigation.
+
+ StSteroids handles registration of routes, setting a default route,
+ and running the router to navigate to the appropriate page or layout.
+ """
+
+ def __init__(self):
+ """
+ Initializes the StSteroids application instance.
+ """
+ self._router = Router()
+ self._routes: dict[str, Route] = {}
+ self._default: Route | None = None
+ self._on_app_run_once = None
+ self._on_app_run = None
+
+ def _get_active_routes(self) -> dict[str, "Route"]:
+ """
+ Returns a dictionary of active routes filtered by their conditions.
+ Includes the default route if defined.
+ """
+ routes = {route.name: route for route in self._routes.values()
+ if not route.condition or route.condition()}
+
+ if self._default:
+ routes["__default__"] = self._default
+
+ return routes
+
+ def _trigger_run_once_event(self) -> None:
+ """
+ Executes the `_on_app_run_once` event once per session if defined.
+ """
+ if "_on_app_run_once_done" not in st.session_state and self._on_app_run_once:
+ st.session_state["_on_app_run_once_done"] = True
+ self._on_app_run_once.dispatch(FlowContext("app", "run_once"))
+
+ def _handle_scheduled_rerun(self) -> None:
+ """
+ Executes a scheduled rerun if present in session state.
+ """
+ scheduled = st.session_state.pop("_schedule_rerun", None)
+ if scheduled:
+ scheduled["fn"](*scheduled["args"], **scheduled["kwargs"])
+ st.rerun()
+
+ def route(self, name: str) -> RouteBuilder:
+ """
+ Creates a RouteBuilder for defining a new route.
+
+ Example usage:
+ app.route("home").to(HomeLayout).when(user_is_logged_in).register()
+
+ :param name: The unique name of the route.
+ :return: RouteBuilder instance to define target and condition before registering.
+ """
+ return RouteBuilder(self, name)
+
+ def default_route(self, target: Layout) -> None:
+ """
+ Sets the default route for the application.
+
+ The default route is used if no other route is specified when running the app.
+
+ :param target: The target layout for the default route.
+ :return: None
+ """
+ self._default = Route("__default__", target)
+
+ def register(self, route: Route) -> None:
+ """
+ Registers a route in the application.
+
+ :param route: The Route instance to register.
+ :return: None
+ """
+ self._routes[route.name] = route
+
+ def on_app_run_once(self, callback: Flow) -> None:
+ """
+ Registers a flow to be executed once when the application starts.
+
+ This flow will be triggered only the first time the app runs.
+ Subsequent reruns of the app will not re-execute this flow.
+
+ :param callback: The Flow instance to execute on the first app run.
+ :raises RuntimeError: If an on_app_run_once flow has already been registered.
+ :return: None
+ """
+ if self._on_app_run_once:
+ raise RuntimeError("on_app_run_once already registered.")
+ self._on_app_run_once = callback
+
+ def run(self, entry_route: str | None = None) -> None:
+ """Run the application router and handle scheduled tasks."""
+ self._trigger_run_once_event()
+ routes = self._get_active_routes()
+ self._router.register_routes(routes)
+
+ if entry_route:
+ self._router.route(entry_route)
+
+ self._router.run()
+ self._handle_scheduled_rerun()
\ No newline at end of file
diff --git a/src/ststeroids/route.py b/src/ststeroids/route.py
new file mode 100644
index 0000000..40b9870
--- /dev/null
+++ b/src/ststeroids/route.py
@@ -0,0 +1,40 @@
+from .layout import Layout
+from .flow import Flow
+
+
+class Route:
+ """
+ Represents a single route in the application.
+
+ A route defines:
+ - a name (unique identifier),
+ - a target (the layout or callable to navigate to),
+ - an optional flow to dispatch when the route is entered,
+ - an optional condition that determines if the route is active.
+
+ Attributes:
+ name (str): Unique name of the route.
+ target (layout): The target layout to render.
+ on_enter (flow): The flow to dispatch when the route is entered.
+ condition (callable, optional): If provided, the route is active only when this callable returns True.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ target: Layout,
+ on_enter: Flow = None,
+ condition: callable = None,
+ ):
+ """
+ Initializes a Route instance.
+
+ :param name: Unique name of the route.
+ :param target: Layout to render when the route is triggered.
+ :param on_enter: Flow to dispatch when the route is entered.
+ :param condition: Optional callable returning a boolean. If provided, determines if the route is active.
+ """
+ self.name = name
+ self.target = target
+ self.on_enter = on_enter
+ self.condition = condition
diff --git a/src/ststeroids/route_builder.py b/src/ststeroids/route_builder.py
new file mode 100644
index 0000000..96fa9e9
--- /dev/null
+++ b/src/ststeroids/route_builder.py
@@ -0,0 +1,68 @@
+from .layout import Layout
+from .route import Route
+from .flow import Flow
+
+
+class RouteBuilder:
+ """
+ A builder class for defining and registering routes in the application.
+
+ Allows chaining of target and condition definitions before registering the route.
+
+ Example usage:
+ RouteBuilder(app, "home").to(HomeLayout).when(user_is_logged_in).register()
+ """
+
+ def __init__(self, app, name: str):
+ """
+ Initializes the RouteBuilder.
+
+ :param app: The application instance where the route will be registered.
+ :param name: Unique name of the route.
+ """
+ self.app = app
+ self._name = name
+ self._target = None
+ self._condition = None
+ self._on_enter = None
+
+ def to(self, target: Layout) -> "RouteBuilder":
+ """
+ Sets the target for this route.
+
+ :param target: Layout class or callable to execute when the route is triggered.
+ :return: Self, to allow method chaining.
+ """
+ self._target = target
+ return self
+
+ def when(self, condition: callable) -> "RouteBuilder":
+ """
+ Sets a condition for this route.
+
+ The route will only be active if the condition evaluates to True.
+
+ :param condition: Callable returning a boolean.
+ :return: Self, to allow method chaining.
+ """
+ self._condition = condition
+ return self
+
+ def on_enter(self, callback: Flow):
+ self._on_enter = callback
+ return self
+
+ def register(self) -> None:
+ """
+ Registers the route in the application with the specified target and condition.
+
+ Raises:
+ ValueError: If no target has been set for the route.
+ """
+ if self._target is None:
+ raise ValueError(
+ f"Route '{self._name}' cannot be registered without a target."
+ )
+ self.app.register(
+ Route(self._name, self._target, self._on_enter, self._condition)
+ )
diff --git a/src/ststeroids/router.py b/src/ststeroids/router.py
index e94e78b..e50269a 100644
--- a/src/ststeroids/router.py
+++ b/src/ststeroids/router.py
@@ -1,55 +1,70 @@
-import streamlit as st
-from .layout import Layout
+from .route import Route
+from .flow_context import FlowContext
class Router:
"""
- A routing system for Streamlit applications, allowing navigation between different pages.
+ Central routing system responsible for selecting and rendering layouts.
+
+ The Router maintains a set of registered routes, determines which route
+ is currently active, optionally dispatches route lifecycle flows, and
+ renders the corresponding layout.
"""
- def __init__(self, default: str = "home"):
+ def __init__(self, default: str = "__default__"):
"""
- Initializes the Router instance with a default page.
+ Initialize the Router.
- :param default: The default page to load when the app starts. Defaults to "home".
+ :param default: The name of the default route to use when no explicit
+ route has been selected.
"""
- self.routes = {}
- if "ststeroids_current_route" not in st.session_state:
- st.session_state["ststeroids_current_route"] = default
+ self._routes: dict[str, Route] = {}
+ self._current: str | None = None
+ self._default: str = default
- def run(self):
+ def register_routes(self, routes: dict[str, Route]) -> None:
"""
- Executes the function associated with the currently active route.
+ Register the available routes.
+
+ This replaces any previously registered routes.
+ :param routes: A mapping of route names to Route instances.
:return: None
"""
- try:
- route = self.routes[st.session_state["ststeroids_current_route"]]
- except KeyError as exc:
- raise KeyError(
- f"The current route '{st.session_state['ststeroids_current_route']}' is not a registered route."
- ) from exc
- route()
+ self._routes = routes
- def route(self, route_name: str):
+ def route(self, route_name: str) -> None:
"""
- Updates the current page in the session state.
+ Set the current route to navigate to.
+
+ The route will be resolved and rendered on the next call to `run()`.
- :param route_name: The name of the route to navigate to.
+ :param route_name: The name of the route to activate.
:return: None
"""
- st.session_state["ststeroids_current_route"] = route_name
+ self._current = route_name
- def register_routes(self, routes: dict[str, Layout]):
+ def run(self) -> None:
"""
- Registers a dictionary of routes where keys are route names and values are layouts.
+ Resolve and render the active route.
- :param routes: A dictionary mapping route names to layouts.
+ The router selects the current route if set, otherwise falls back
+ to the default route. If the route defines an `on_enter` flow, it
+ will be dispatched before rendering the target layout.
+
+ :raises RuntimeError: If no valid route can be resolved.
:return: None
"""
- self.routes = routes
+ if self._current in self._routes:
+ route = self._routes[self._current]
+ elif self._default in self._routes:
+ route = self._routes[self._default]
+ else:
+ raise RuntimeError(
+ "No current route selected and no default route registered."
+ )
+
+ if route.on_enter:
+ route.on_enter.dispatch(FlowContext("route", route.name))
- def get_current_route(self):
- if "ststeroids_current_route" in st.session_state:
- return st.session_state["ststeroids_current_route"]
- return None
+ route.target.render()
diff --git a/src/ststeroids/store.py b/src/ststeroids/store.py
index 817b746..f974c46 100644
--- a/src/ststeroids/store.py
+++ b/src/ststeroids/store.py
@@ -7,20 +7,29 @@ class Store:
Class for creating a session store.
This class manages storing and retrieving properties in Streamlit's session state.
- It initializes a store with a unique name and allows properties to be set and retrieved.
- :param store_name: The name of the store to create in session state.
+ Attributes:
+ name (str): Unique name of the store.
"""
def __init__(self, store_name: str):
"""
- Initializes the session store with the given name.
+ Initializes a store with a unique name and allows properties to be set and retrieved. Do not use directly, use create instead.
:param store_name: The name of the store to create in session state.
"""
self.name = store_name
if store_name not in st.session_state:
- st.session_state[self.name] = {}
+ st.session_state[store_name] = {}
+
+ @classmethod
+ def create(cls, store_name: str):
+ """
+ Creates a new instance of the store with a unique name and allows properties to be set and retrieved.
+
+ :param store_name: The name of the store to create in session state.
+ """
+ return cls(store_name)
def has_property(self, property_name: str) -> bool:
"""
@@ -65,75 +74,20 @@ def del_property(self, property_name: str) -> None:
class ComponentStore(Store):
"""
- Class that creates a component session store. This can be passed to component instances.
+ Class that creates a component session store.
- :param component_id: The unique identifier for the component.
- :param initial_state: The initial state of the component.
"""
- _instance = None
-
- def __new__(cls, *args, **kwargs):
- if cls._instance is None:
- cls._instance = super().__new__(cls)
- return cls._instance
-
- def __init__(self):
- """
- Initializes the component store with the name 'components'.
-
- This store is used specifically for storing component-related state in the session.
- """
- super().__init__("components")
-
def init_component(self, component: object) -> None:
"""
Initializes a component in the session store with its ID
- :param component_id: The unique identifier for the component.
+ :param component: The component instance.
:return: None
"""
if not self.has_property(component.id):
super().set_property(component.id, component)
- def init_component_state(self, component_id: str, initial_state: dict) -> None:
- """
- Initializes a component state in the session store with its ID and initial state.
-
- :param component_id: The unique identifier for the component.
- :param initial_state: The initial state to set for the component.
- :return: None
- """
- if not self.has_property(f"{component_id}_state"):
- super().set_property(f"{component_id}_state", initial_state)
-
- def get_property( # pylint: disable=arguments-differ
- self, component_id: str, property_name: str
- ) -> Any:
- """
- Retrieves the value of a property from a component's state.
-
- :param component_id: The unique identifier for the component.
- :param property_name: The name of the property to retrieve.
- :return: The value of the property from the component's state.
- """
- return super().get_property(f"{component_id}_state")[property_name]
-
- def set_property( # pylint: disable=arguments-differ
- self, component_id: str, property_name: str, property_value: Any
- ) -> None:
- """
- Sets the value of a property in a component's state.
-
- :param component_id: The unique identifier for the component.
- :param property_name: The name of the property to set.
- :param property_value: The value to set for the property.
- :return: None
- """
- component_state = super().get_property(f"{component_id}_state")
- component_state[property_name] = property_value
- super().set_property(f"{component_id}_state", component_state)
-
def get_component(self, component_id: str):
"""
Retrieves the current state or properties of a component.
diff --git a/src/ststeroids/style.py b/src/ststeroids/style.py
index 935acf5..e794dc7 100644
--- a/src/ststeroids/style.py
+++ b/src/ststeroids/style.py
@@ -5,6 +5,9 @@
class Style:
"""
A class for applying custom CSS styles to a Streamlit app.
+
+ Attributes:
+ style_file (str): Path to CSS file for this instance.
"""
def __init__(self, style_file: str):
@@ -15,7 +18,7 @@ def __init__(self, style_file: str):
"""
self.style_file = style_file
- def apply_style(self):
+ def apply_style(self) -> None:
"""
Reads the CSS file and applies its styles to the Streamlit app.
diff --git a/tests/test_component.py b/tests/test_component.py
index df0afef..21caac4 100644
--- a/tests/test_component.py
+++ b/tests/test_component.py
@@ -1,149 +1,103 @@
-from unittest.mock import MagicMock, patch
import pytest
+from unittest.mock import MagicMock, patch
+
+from ststeroids.component import Component
+from ststeroids.flow import Flow
from ststeroids.store import ComponentStore
-from ststeroids.component import Component, State
+from ststeroids.flow_context import FlowContext
@pytest.fixture
def mock_session_state():
- with patch("streamlit.session_state", new={}) as mock_state:
- yield mock_state
+ with patch("streamlit.session_state", new={}) as state:
+ yield state
-@pytest.fixture(scope="session")
+@pytest.fixture
def mock_store():
- # Mocking the ComponentStore for testing purposes
- store = MagicMock(spec=ComponentStore)
+ store = MagicMock(spec_set=ComponentStore)
+ store.has_property.return_value = False
return store
-@pytest.fixture(scope="session")
+@pytest.fixture
def component(mock_store):
- # Creating a sample component for testing
- component = Component(component_id="test_component", initial_state={"key": "value"})
- component._Component__store = (
- mock_store # Injecting the mock store into the component
- )
- return component
-
-def test_component_creation_without_id():
- with pytest.raises(KeyError):
- component = Component(initial_state={"key": "value"})
-
-def test_component_singleton():
- first_instance = Component(component_id="test_component", initial_state={"key": "value"})
- second_instance = Component(component_id="test_component", initial_state={"key": "value"})
- assert first_instance is second_instance
-
-def test_subclass_init_runs_only_once():
- calls = {"count": 0}
-
- class Sub(Component):
- def __init__(self, value):
- calls["count"] += 1
- self.value = value
-
- obj = Sub(42)
- assert obj.value == 42
- assert calls["count"] == 1 # __init__ ran once
-
- # Call __init__ again explicitly
- obj.__init__(99)
- assert obj.value == 42 # value didn't change
- assert calls["count"] == 1 # __init__ not called again
-
-
-def test_component_initialization(component):
- # Test that the component is initialized correctly
- assert component.id == "test_component"
- assert isinstance(component.state, State)
-
-
-def test_state_initialization(mock_store):
- # Test that the state is initialized with the component ID and store
- state = State(
- component_id="test_component", store=mock_store, initial_state={"key": "value"}
- )
- mock_store.init_component_state.assert_called_once_with(
- "test_component", {"key": "value"}
- )
- assert state._State__id == "test_component"
- assert state._State__store == mock_store
-
-
-def test_getattr(component):
- # Test that attributes are retrieved correctly from the store
- assert component.state.key == "value"
-
-
-def test_setattr(component):
- # Test that attributes are set correctly in the store
- component.state.key = "new_value"
- assert component.state.key == "new_value"
-
+ with patch("ststeroids.store.ComponentStore.create", return_value=mock_store):
-def test_render_not_implemented(component):
- # Test that calling render raises NotImplementedError
- with pytest.raises(NotImplementedError):
- component.render()
+ class MyComponent(Component):
+ def display(self):
+ pass
+ return MyComponent.create("test_component")
-def test_register_element(component):
- element_name = "button"
- expected_key = "test_component_button"
- assert component.register_element(element_name) == expected_key
+def test_component_create_returns_same_instance(mock_store):
+ with patch("ststeroids.store.ComponentStore.create", return_value=mock_store):
-def test_get_element_not_set(component):
- element_name = "non_existent"
- assert component.get_element(element_name) is None
+ class MyComponent(Component):
+ def display(self):
+ pass
+ comp1 = MyComponent.create("comp")
+ MyComponent.create("comp")
+ # Same store call, simulating singleton
+ mock_store.has_property.return_value = True
+ mock_store.get_component.return_value = comp1
+ comp3 = MyComponent.create("comp")
+ assert comp1 is comp3
-def test_get_element_set(component, mock_session_state):
- element_name = "input"
- key = component.register_element(element_name)
- mock_session_state[key] = "Test Value"
- assert component.get_element(element_name) == "Test Value"
-
-def test_set_element(component, mock_session_state):
- element_name = "input"
- key = component.register_element(element_name)
- mock_session_state[key] = "nothing"
- component.set_element(element_name, "something")
- assert component.get_element(element_name) == "something"
-
-
-def test__render_fragment_with_flow(component):
- mock_flow = MagicMock()
- component.render = MagicMock()
-
- component._Component__render_fragment(refresh_flow=mock_flow)
-
- mock_flow.execute_run.assert_called_once()
- component.render.assert_called_once()
-
-
-def test_execute_render_normal(component):
- component.render = MagicMock(return_value="normal_rendered")
- result = component.execute_render(render_as="normal")
- component.render.assert_called_once()
- assert result == "normal_rendered"
-
-
-def test_execute_render_dialog(component):
- component._render_dialog = MagicMock(return_value="dialog_rendered")
- result = component.execute_render(render_as="dialog", options={"title": "bar"})
- component._render_dialog.assert_called_once_with(title="bar")
- assert result == "dialog_rendered"
-
-
-def test_execute_render_fragment(component):
- component._render_fragment = MagicMock(return_value="fragment_rendered")
- result = component.execute_render(render_as="fragment", options={"x": 1})
- component._render_fragment.assert_called_once_with(x=1)
- assert result == "fragment_rendered"
-
-def test_execute_render_raises_an_error_with_an_invalid_render_as(component):
- with pytest.raises(ValueError):
- component.execute_render(render_as="something", options={"x": 1})
\ No newline at end of file
+def test_component_attributes(component):
+ assert component.id == "test_component"
+ assert hasattr(component, "_events")
+ assert component.visible is True
+
+
+def test_register_element_returns_key(component):
+ key = component.register_element("button")
+ assert key == "test_component_button"
+
+
+def test_get_element_and_set_element(mock_session_state, component):
+ key = component.register_element("input")
+ # Initially None
+ assert component.get_element("input") is None
+ component.set_element("input", "value")
+ assert mock_session_state[key] == "value"
+ assert component.get_element("input") == "value"
+
+
+def test_on_and_trigger_calls_flow(component):
+ flow = MagicMock(spec_set=Flow)
+ component.on("click", flow)
+ component.trigger("click")
+ flow.dispatch.assert_called_once()
+ args, _ = flow.dispatch.call_args
+ ctx = args[0]
+ assert isinstance(ctx, FlowContext)
+ assert ctx.identifier == component.id
+ assert ctx.type == "component"
+
+def test_on_and_trigger_does_not_call_flow_when_not_registered(component):
+ flow = MagicMock(spec_set=Flow)
+ component.trigger("click")
+ flow.dispatch.assert_not_called()
+
+def test_render_calls_display(component):
+ component.display = MagicMock()
+ component.render()
+ component.display.assert_called_once()
+
+
+def test_render_skips_if_not_visible(component):
+ component.display = MagicMock()
+ component.hide()
+ component.render()
+ component.display.assert_not_called()
+
+
+def test_show_and_hide(component):
+ component.hide()
+ assert component.visible is False
+ component.show()
+ assert component.visible is True
diff --git a/tests/test_dialog.py b/tests/test_dialog.py
new file mode 100644
index 0000000..8a0a21b
--- /dev/null
+++ b/tests/test_dialog.py
@@ -0,0 +1,55 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from ststeroids.dialog import Dialog
+
+
+class MyDialog(Dialog):
+ def display(self):
+ pass
+
+
+@pytest.fixture
+def mock_dialog():
+ with patch("streamlit.dialog") as mock:
+ yield mock
+
+
+@pytest.fixture
+def dialog_instance():
+ return MyDialog.create("my_dialog", title="My Title")
+
+
+def test_create_sets_title(dialog_instance):
+ assert dialog_instance.title == "My Title"
+
+
+def test_get_does_not_set_title(dialog_instance):
+ MyDialog.get("my_dialog")
+ assert dialog_instance.title == "My Title"
+
+def test_render_calls_display_inside_dialog(dialog_instance, mock_dialog):
+ # Mock display
+ dialog_instance.display = MagicMock()
+
+ # st.dialog returns a decorator that immediately calls the wrapped function
+ def fake_decorator(func):
+ def wrapper():
+ func()
+
+ return wrapper
+
+ mock_dialog.side_effect = lambda title: fake_decorator
+
+ dialog_instance.render()
+ dialog_instance.display.assert_called_once()
+ mock_dialog.assert_called_once_with("My Title")
+
+
+def test_render_skips_if_not_visible(dialog_instance, mock_dialog):
+ dialog_instance.display = MagicMock()
+ dialog_instance.hide()
+ dialog_instance.render()
+ # display should not be called
+ dialog_instance.display.assert_not_called()
+ mock_dialog.assert_not_called()
diff --git a/tests/test_flow.py b/tests/test_flow.py
index 170df45..af252cf 100644
--- a/tests/test_flow.py
+++ b/tests/test_flow.py
@@ -1,24 +1,31 @@
import pytest
+from unittest.mock import MagicMock
+
from ststeroids.flow import Flow
-from ststeroids.store import ComponentStore
-def test_flow_initializes_component_store():
- flow = Flow()
- assert isinstance(flow.component_store, ComponentStore)
+def test_flow_cannot_instantiate_directly():
+ with pytest.raises(TypeError):
+ Flow()
+
+def test_subclass_run_called_by_dispatch():
+ class MyFlow(Flow):
+ def run(self, ctx):
+ pass
-def test_flow_run_raises_not_implemented_error():
- flow = Flow()
- with pytest.raises(NotImplementedError):
- flow.execute_run()
+ flow = MyFlow.create()
+ flow.run = MagicMock()
+ flow.dispatch(None)
+ flow.run.assert_called_once_with(None)
-def test_subclass_run_called_by__run():
+def test_flow_create_classmethod():
class MyFlow(Flow):
- def run(self, x):
- return x * 2
+ def run(self, ctx):
+ pass
- flow = MyFlow()
- result = flow.execute_run(3)
- assert result == 6
+ flow = MyFlow.create()
+ assert isinstance(flow, MyFlow)
+ result = flow.run(None)
+ assert result is None
diff --git a/tests/test_fragement.py b/tests/test_fragement.py
new file mode 100644
index 0000000..9c4bc4b
--- /dev/null
+++ b/tests/test_fragement.py
@@ -0,0 +1,64 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from ststeroids.fragment import Fragment
+from ststeroids.flow import Flow
+
+class MyFragment(Fragment):
+ def display(self):
+ pass
+
+@pytest.fixture
+def mock_fragment():
+ with patch("streamlit.fragment") as mock:
+ yield mock
+
+
+@pytest.fixture
+def fragment_instance():
+ return MyFragment.create("frag1", refresh_interval="5s")
+
+
+def test_create_sets_refresh_interval(fragment_instance):
+ assert fragment_instance.refresh_interval == "5s"
+
+def test_get_does_not_set_refresh_interval(fragment_instance):
+ MyFragment.get("my_dialog")
+ assert fragment_instance.refresh_interval == "5s"
+
+
+def test_on_refresh_registers_flow(fragment_instance):
+ flow = MagicMock(spec_set=Flow)
+ fragment_instance.on_refresh(flow)
+ assert fragment_instance._events[fragment_instance.EVENT_REFRESH] == flow
+
+
+def test_render_calls_display_inside_fragment(fragment_instance, mock_fragment):
+ # Mock display and trigger
+ fragment_instance.display = MagicMock()
+ fragment_instance.trigger = MagicMock()
+
+ # st.fragment returns a decorator that calls the wrapped function immediately
+ def fake_decorator(func):
+ def wrapper():
+ func()
+
+ return wrapper
+
+ mock_fragment.side_effect = lambda run_every=None: fake_decorator
+
+ fragment_instance.render()
+
+ mock_fragment.assert_called_once_with(run_every="5s")
+ fragment_instance.trigger.assert_called_once_with(fragment_instance.EVENT_REFRESH)
+ fragment_instance.display.assert_called_once()
+
+
+def test_render_skips_if_not_visible(fragment_instance, mock_fragment):
+ fragment_instance.display = MagicMock()
+ fragment_instance.trigger = MagicMock()
+ fragment_instance.hide()
+ fragment_instance.render()
+ fragment_instance.display.assert_not_called()
+ fragment_instance.trigger.assert_not_called()
+ mock_fragment.assert_not_called()
diff --git a/tests/test_layout.py b/tests/test_layout.py
index 94795ab..b73143e 100644
--- a/tests/test_layout.py
+++ b/tests/test_layout.py
@@ -3,18 +3,29 @@
from unittest.mock import MagicMock
-def test_layout_render_raises_not_implemented_error():
- layout = Layout()
- with pytest.raises(NotImplementedError):
- layout.render()
+def test_layout_cannot_instantiate_directly():
+ # Abstract classes cannot be instantiated
+ with pytest.raises(TypeError):
+ Layout()
-def test_subclass_run_called_by__run():
+def test_subclass_render_called():
class MyLayout(Layout):
def render(self):
- return ""
+ return "rendered"
layout = MyLayout()
- layout.render = MagicMock()
- layout.execute_render()
+ layout.render = MagicMock(return_value="rendered")
+ result = layout.render()
layout.render.assert_called_once()
+ assert result == "rendered"
+
+
+def test_layout_create_classmethod():
+ class MyLayout(Layout):
+ def render(self):
+ return "ok"
+
+ layout = MyLayout.create()
+ assert isinstance(layout, MyLayout)
+ assert layout.render() == "ok"
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..ebd0792
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,180 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from ststeroids import StSteroids
+from ststeroids.route import Route
+from ststeroids.layout import Layout
+from ststeroids.flow import Flow
+
+
+@pytest.fixture
+def mock_session_state():
+ # Patch Streamlit session_state and return the dict for inspection
+ with patch("streamlit.session_state", new={}) as state:
+ yield state
+
+
+@pytest.fixture
+def app():
+ return StSteroids()
+
+
+# ----------------------------
+# Unit tests for helpers
+# ----------------------------
+
+def test_get_active_routes_filters_correctly(app):
+ class MyLayout(Layout):
+ def render(self): pass
+
+ route_true = Route("true_route", MyLayout(), condition=lambda: True)
+ route_false = Route("false_route", MyLayout(), condition=lambda: False)
+ app.register(route_true)
+ app.register(route_false)
+ app._default = Route("__default__", MyLayout())
+
+ active_routes = app._get_active_routes()
+ assert "true_route" in active_routes
+ assert "false_route" not in active_routes
+ assert "__default__" in active_routes
+
+
+def test_trigger_run_once_event_only_runs_once(app, mock_session_state):
+ flow = MagicMock(spec=Flow)
+ app._on_app_run_once = flow
+
+ # First call triggers dispatch
+ app._trigger_run_once_event()
+ flow.dispatch.assert_called_once()
+ assert mock_session_state["_on_app_run_once_done"] is True
+
+ # Second call does not trigger dispatch
+ flow.reset_mock()
+ app._trigger_run_once_event()
+ flow.dispatch.assert_not_called()
+
+
+def test_handle_scheduled_rerun_executes_and_reruns(app, mock_session_state):
+ fn = MagicMock()
+ mock_session_state["_schedule_rerun"] = {"fn": fn, "args": (1,), "kwargs": {"a": 2}}
+
+ # Patch st.rerun to avoid actually rerunning
+ with patch("streamlit.rerun") as mock_rerun:
+ app._handle_scheduled_rerun()
+
+ fn.assert_called_once_with(1, a=2)
+ mock_rerun.assert_called_once()
+ assert "_schedule_rerun" not in mock_session_state
+
+
+# ----------------------------
+# Integration tests for run
+# ----------------------------
+
+def test_run_triggers_on_app_run_once_only_once(app, mock_session_state):
+ flow = MagicMock(spec=Flow)
+ app.on_app_run_once(flow)
+
+ app._router.register_routes = MagicMock()
+ app._router.route = MagicMock()
+ app._router.run = MagicMock()
+
+ # First run
+ app.run()
+ flow.dispatch.assert_called_once()
+ assert "_on_app_run_once_done" in mock_session_state
+ assert mock_session_state["_on_app_run_once_done"] is True
+
+ # Second run
+ flow.reset_mock()
+ app.run()
+ flow.dispatch.assert_not_called()
+
+
+def test_run_calls_helpers_and_router_methods(app, mock_session_state):
+ flow = MagicMock(spec=Flow)
+ app.on_app_run_once(flow)
+
+ route = Route("home", MagicMock())
+ app.register(route)
+
+ app._router.register_routes = MagicMock()
+ app._router.route = MagicMock()
+ app._router.run = MagicMock()
+
+ # Patch the helpers to ensure they are called
+ with patch.object(app, "_trigger_run_once_event") as mock_trigger, \
+ patch.object(app, "_get_active_routes") as mock_routes, \
+ patch.object(app, "_handle_scheduled_rerun") as mock_rerun:
+
+ mock_routes.return_value = {"home": route}
+
+ app.run(entry_route="home")
+
+ mock_trigger.assert_called_once()
+ mock_routes.assert_called_once()
+ mock_rerun.assert_called_once()
+ app._router.register_routes.assert_called_once_with({"home": route})
+ app._router.route.assert_called_once_with("home")
+ app._router.run.assert_called_once()
+
+
+def test_run_filters_routes_by_condition(app):
+ class MyLayout(Layout):
+ def render(self): pass
+
+ route_true = Route("true_route", MyLayout(), condition=lambda: True)
+ route_false = Route("false_route", MyLayout(), condition=lambda: False)
+ app.register(route_true)
+ app.register(route_false)
+
+ app._router.register_routes = MagicMock()
+ app._router.route = MagicMock()
+ app._router.run = MagicMock()
+
+ app.run()
+ routes_passed = app._router.register_routes.call_args[0][0]
+
+ assert "true_route" in routes_passed
+ assert "false_route" not in routes_passed
+
+
+# ----------------------------
+# Other basic tests
+# ----------------------------
+
+def test_route_returns_routebuilder(app):
+ rb = app.route("home")
+ from ststeroids.route_builder import RouteBuilder
+ assert isinstance(rb, RouteBuilder)
+
+
+def test_default_route_sets_default(app):
+ class MyLayout(Layout):
+ def render(self): pass
+
+ layout = MyLayout()
+ app.default_route(layout)
+ assert app._default.name == "__default__"
+ assert app._default.target == layout
+
+
+def test_register_adds_route(app):
+ class MyLayout(Layout):
+ def render(self): pass
+
+ layout = MyLayout()
+ route = Route("home", layout)
+ app.register(route)
+
+ assert "home" in app._routes
+ assert app._routes["home"] == route
+
+
+def test_on_app_run_once_registers_flow(app):
+ flow = MagicMock(spec=Flow)
+ app.on_app_run_once(flow)
+ assert app._on_app_run_once == flow
+
+ with pytest.raises(RuntimeError):
+ app.on_app_run_once(flow)
\ No newline at end of file
diff --git a/tests/test_route.py b/tests/test_route.py
new file mode 100644
index 0000000..357eb37
--- /dev/null
+++ b/tests/test_route.py
@@ -0,0 +1,27 @@
+from ststeroids.route import Route
+from ststeroids.layout import Layout
+from ststeroids.flow import Flow
+from unittest.mock import MagicMock
+
+
+def test_route_initialization_defaults():
+ layout = MagicMock(spec_set=Layout)
+ route = Route(name="home", target=layout)
+
+ assert route.name == "home"
+ assert route.target == layout
+ assert route.on_enter is None
+ assert route.condition is None
+
+
+def test_route_initialization_with_all_arguments():
+ layout = MagicMock(spec_set=Layout)
+ flow = MagicMock(spec_set=Flow)
+ condition = lambda: True
+
+ route = Route(name="dashboard", target=layout, on_enter=flow, condition=condition)
+
+ assert route.name == "dashboard"
+ assert route.target == layout
+ assert route.on_enter == flow
+ assert route.condition == condition
diff --git a/tests/test_route_builder.py b/tests/test_route_builder.py
new file mode 100644
index 0000000..952d33f
--- /dev/null
+++ b/tests/test_route_builder.py
@@ -0,0 +1,55 @@
+import pytest
+from unittest.mock import MagicMock
+
+from ststeroids.route_builder import RouteBuilder
+from ststeroids.route import Route
+from ststeroids.layout import Layout
+from ststeroids.flow import Flow
+
+
+def test_to_when_on_enter_chain_returns_self():
+ app = MagicMock()
+ builder = RouteBuilder(app, "home")
+
+ class DummyLayout(Layout):
+ def render(self):
+ pass
+
+ flow = MagicMock(spec_set=Flow)
+ condition = lambda: True
+
+ # Each method should return self for chaining
+ assert builder.to(DummyLayout) is builder
+ assert builder.when(condition) is builder
+ assert builder.on_enter(flow) is builder
+
+
+def test_register_without_target_raises():
+ app = MagicMock()
+ builder = RouteBuilder(app, "home")
+
+ with pytest.raises(ValueError, match="cannot be registered without a target"):
+ builder.register()
+
+
+def test_register_calls_app_register_with_route():
+ app = MagicMock()
+ builder = RouteBuilder(app, "home")
+
+ class DummyLayout(Layout):
+ def render(self):
+ pass
+
+ flow = MagicMock(spec_set=Flow)
+ condition = lambda: True
+
+ builder.to(DummyLayout).when(condition).on_enter(flow).register()
+
+ # Ensure app.register was called once with a Route instance
+ assert app.register.call_count == 1
+ route_arg = app.register.call_args[0][0]
+ assert isinstance(route_arg, Route)
+ assert route_arg.name == "home"
+ assert route_arg.target == DummyLayout
+ assert route_arg.on_enter == flow
+ assert route_arg.condition == condition
diff --git a/tests/test_router.py b/tests/test_router.py
index 38e7dbe..f133b0a 100644
--- a/tests/test_router.py
+++ b/tests/test_router.py
@@ -1,8 +1,9 @@
-from collections import defaultdict
import pytest
-import streamlit as st
from unittest.mock import MagicMock
-from ststeroids import Router
+
+from ststeroids.router import Router
+from ststeroids.route import Route
+from ststeroids.flow_context import FlowContext
@pytest.fixture
@@ -10,79 +11,95 @@ def router():
return Router()
-@pytest.fixture
-def mock_session_state(mocker):
- mocker.patch.object(st, "session_state", {}, create=True)
+def make_route(name="home", on_enter=None):
+ target = MagicMock()
+ target.render = MagicMock()
+ return Route(
+ name=name,
+ target=target,
+ on_enter=on_enter,
+ )
-def test_router_initialization(mock_session_state, router):
- assert "ststeroids_current_route" in st.session_state
- assert st.session_state["ststeroids_current_route"] == "home"
+def test_router_initialization(router):
+ assert router._current is None
+ assert router._default == "__default__"
+ assert router._routes == {}
-def test_router_initialization_with_custom_default(mock_session_state):
- Router(default="dashboard")
- assert st.session_state["ststeroids_current_route"] == "dashboard"
+def test_router_initialization_with_custom_default():
+ router = Router(default="dashboard")
+ assert router._default == "dashboard"
-def test_register_routes(mock_session_state, router):
- mock_layout = MagicMock()
- routes = {"home": mock_layout, "dashboard": mock_layout}
+def test_register_routes(router):
+ route = make_route("home")
+ routes = {"home": route}
+
router.register_routes(routes)
- assert router.routes == routes
+
+ assert router._routes == routes
-def test_route_changes_current_route(mock_session_state, router):
+def test_route_sets_current_route(router):
router.route("dashboard")
- assert st.session_state["ststeroids_current_route"] == "dashboard"
+ assert router._current == "dashboard"
+
+def test_run_calls_current_route(router):
+ route = make_route("home")
-def test_run_calls_current_route(mock_session_state, router):
- mock_function = MagicMock()
- router.register_routes({"home": mock_function})
+ router.register_routes({"home": route})
+ router.route("home")
router.run()
- mock_function.assert_called_once()
+ route.target.render.assert_called_once()
-def test_run_calls_current_route_that_raises_an_exception(mock_session_state, router):
- mock_function = MagicMock(side_effect=KeyError("Missing key"))
- router.register_routes({"home": mock_function})
- with pytest.raises(KeyError, match="Missing key"):
- router.route("home")
- router.run()
+def test_run_calls_on_enter_if_present(router):
+ on_enter = MagicMock()
+ on_enter.dispatch = MagicMock()
-def test_run_calls_invalid_current_route(mock_session_state, router):
- mock_function = MagicMock()
- router.register_routes({"home": mock_function})
- with pytest.raises(
- KeyError, match="The current route 'invalid' is not a registered route."
- ):
- router.route("invalid")
- router.run()
+ route = make_route("home", on_enter=on_enter)
+ router.register_routes({"home": route})
+ router.route("home")
+ router.run()
-def test_run_calls_with_defaultdict(mock_session_state, router):
- mock_function = MagicMock()
- default_function = MagicMock()
+ on_enter.dispatch.assert_called_once()
+ args, kwargs = on_enter.dispatch.call_args
+ flow_context = args[0]
+ assert isinstance(flow_context, FlowContext)
+ assert flow_context.identifier == "home"
+ assert flow_context.type == "route"
- # Use defaultdict to return default_function for any missing keys
- router.register_routes(
- defaultdict(lambda: default_function, {"home": mock_function})
- )
+ route.target.render.assert_called_once()
- router.route(
- "invalid"
- ) # This will now return default_function instead of raising KeyError
+
+def test_run_falls_back_to_default_route(router):
+ default_route = make_route("__default__")
+
+ router.register_routes({"__default__": default_route})
router.run()
- # Ensure the default function is called
- default_function.assert_called_once()
- # Ensure the "home" function is not called
- mock_function.assert_not_called()
+ default_route.target.render.assert_called_once()
-def test_get_current_route(mock_session_state, router):
- assert router.get_current_route() == "home"
- router.route("dashboard")
- assert router.get_current_route() == "dashboard"
+def test_run_raises_if_no_current_and_no_default(router):
+ router.register_routes({})
+
+ with pytest.raises(
+ RuntimeError,
+ match="No current route selected and no default route registered.",
+ ):
+ router.run()
+
+
+def test_run_uses_default_when_current_is_invalid(router):
+ default_route = make_route("__default__")
+
+ router.register_routes({"__default__": default_route})
+ router.route("invalid")
+ router.run()
+
+ default_route.target.render.assert_called_once()
diff --git a/tests/test_store.py b/tests/test_store.py
index 5eac546..8fd0ddf 100644
--- a/tests/test_store.py
+++ b/tests/test_store.py
@@ -1,43 +1,56 @@
import pytest
from unittest.mock import patch
+
from ststeroids import Store
from ststeroids.store import ComponentStore
@pytest.fixture
def mock_session_state():
- with patch("streamlit.session_state", new={}) as mock_state:
- yield mock_state
+ # Patch Streamlit session_state and return the dict for inspection
+ with patch("streamlit.session_state", new={}) as state:
+ yield state
+
+
+# =========================
+# Store tests
+# =========================
def test_store_initialization(mock_session_state):
- Store("test_store")
+ store = Store.create("test_store")
+
assert "test_store" in mock_session_state
assert mock_session_state["test_store"] == {}
+ assert store.name == "test_store"
def test_store_set_property(mock_session_state):
- store = Store("test_store")
+ store = Store.create("test_store")
store.set_property("key", "value")
+
assert mock_session_state["test_store"]["key"] == "value"
def test_store_get_property(mock_session_state):
- store = Store("test_store")
+ store = Store.create("test_store")
store.set_property("key", "value")
+
assert store.get_property("key") == "value"
def test_store_del_property(mock_session_state):
- store = Store("test_store")
+ store = Store.create("test_store")
store.set_property("key", "value")
store.del_property("key")
- with pytest.raises(KeyError, match="'key' doesn't"):
+
+ with pytest.raises(KeyError, match="'key' doesn't exist in store 'test_store'."):
store.get_property("key")
def test_store_get_property_key_error(mock_session_state):
- store = Store("test_store")
+ store = Store.create("test_store")
+
with pytest.raises(
KeyError, match="'missing_key' doesn't exist in store 'test_store'."
):
@@ -45,40 +58,58 @@ def test_store_get_property_key_error(mock_session_state):
def test_store_has_property(mock_session_state):
- store = Store("test_store")
+ store = Store.create("test_store")
store.set_property("key", "value")
+
assert store.has_property("key") is True
assert store.has_property("missing_key") is False
-def test_component_store_singleton():
- first_instance = ComponentStore()
- second_instance = ComponentStore()
- assert first_instance is second_instance
+
+# =========================
+# ComponentStore tests
+# =========================
+
def test_component_store_initialization(mock_session_state):
- ComponentStore()
+ component_store = ComponentStore.create("components")
+
assert "components" in mock_session_state
assert mock_session_state["components"] == {}
+ assert component_store.name == "components"
def test_component_store_init_component(mock_session_state):
class MockComponent:
id = "comp1"
- component_store = ComponentStore()
+ component_store = ComponentStore.create("components")
component = MockComponent()
+
component_store.init_component(component)
- assert component_store.get_component("comp1") == component
+ assert component_store.get_component("comp1") is component
+
+
+def test_component_store_does_not_override_existing_component(mock_session_state):
+ class MockComponent:
+ id = "comp1"
+
+ component_store = ComponentStore.create("components")
-def test_component_store_init_component_state(mock_session_state):
- component_store = ComponentStore()
- component_store.init_component_state("comp1", {"state_key": "state_value"})
- assert component_store.get_property("comp1", "state_key") == "state_value"
+ first = MockComponent()
+ second = MockComponent()
+ component_store.init_component(first)
+ component_store.init_component(second)
-def test_component_store_set_get_property(mock_session_state):
- component_store = ComponentStore()
- component_store.init_component_state("comp1", {})
- component_store.set_property("comp1", "prop", "value")
- assert component_store.get_property("comp1", "prop") == "value"
+ # Should keep the first component
+ assert component_store.get_component("comp1") is first
+
+
+def test_component_store_get_missing_component_raises(mock_session_state):
+ component_store = ComponentStore.create("components")
+
+ with pytest.raises(
+ KeyError, match="'missing' doesn't exist in store 'components'."
+ ):
+ component_store.get_component("missing")